Browse Source

Only calculate the needed interval statistics when extra tests are enabled

Jens Keim 5 years ago
parent
commit
93cb2aa0e9

+ 1 - 1
code/Attack/BaseAttack.py

@@ -417,7 +417,7 @@ class BaseAttack(metaclass=abc.ABCMeta):
                 print('Error: Statistics-dependent attack parameter added without setting a statistics object first.')
                 exit(1)
 
-            ts = pr.pcap_processor(self.statistics.pcap_filepath, "False", Util.RESOURCE_DIR).get_timestamp_mu_sec(int(value))
+            ts = pr.pcap_processor(self.statistics.pcap_filepath, "False", Util.RESOURCE_DIR, "").get_timestamp_mu_sec(int(value))
             if 0 <= int(value) <= self.statistics.get_packet_count() and ts >= 0:
                 is_valid = True
                 param_name = atkParam.Parameter.INJECT_AT_TIMESTAMP

+ 1 - 1
code/CLI.py

@@ -170,7 +170,7 @@ class CLI(object):
         # Load PCAP statistics
         recalculate_intervals = None
         if self.args.recalculate_delete:
-            self.args.recalculate = True
+            recalculate_intervals = True
         elif self.args.recalculate_yes:
             recalculate_intervals = True
             self.args.recalculate = True

+ 12 - 7
code/Core/Statistics.py

@@ -58,11 +58,12 @@ class Statistics:
                                                                       "database:")
             i = 0
             if output:
-                print("ID".ljust(3) + " | " + "interval in seconds".ljust(30) + " | is_default")
+                print("ID".ljust(3) + " | " + "interval in seconds".ljust(30) + " | is_default" + " | extra_tests")
             for table in previous_interval_tables:
                 seconds = float(table[0][len("interval_statistics_"):])/1000000
                 if output:
-                    print(str(i).ljust(3) + " | " + str(seconds).ljust(30) + " | " + str(table[1]))
+                    print(str(i).ljust(3) + " | " + str(seconds).ljust(30) + " | " + str(table[1]).ljust(
+                        len("is_default")) + " | " + str(table[2]))
                 previous_intervals.append(seconds)
                 i = i + 1
         return previous_intervals
@@ -107,12 +108,15 @@ class Statistics:
             # Get interval statistics tables which already exist
             previous_intervals = self.list_previous_interval_statistic_tables()
 
-            self.pcap_proc = pr.pcap_processor(self.pcap_filepath, str(self.do_extra_tests), Util.RESOURCE_DIR)
+            self.pcap_proc = pr.pcap_processor(self.pcap_filepath, str(self.do_extra_tests), Util.RESOURCE_DIR,
+                                               self.path_db)
 
-            recalc_intervals = None
             if previous_intervals:
-                recalc_intervals = recalculate_intervals
-                while (recalc_intervals is None and not delete) or self.stats_db.get_db_outdated():
+                if delete:
+                    recalc_intervals = False
+                else:
+                    recalc_intervals = recalculate_intervals
+                while recalc_intervals is None:
                     user_input = input("Do you want to recalculate them as well? (y)es|(n)o|(d)elete: ")
                     if user_input.lower() == "yes" or user_input.lower() == "y":
                         recalc_intervals = True
@@ -144,7 +148,8 @@ class Statistics:
             if not flag_print_statistics and not flag_non_verbose:
                 self.stats_summary_new_db()
         elif intervals is not None and intervals != []:
-                self.pcap_proc = pr.pcap_processor(self.pcap_filepath, str(self.do_extra_tests), Util.RESOURCE_DIR)
+                self.pcap_proc = pr.pcap_processor(self.pcap_filepath, str(self.do_extra_tests), Util.RESOURCE_DIR,
+                                                   self.path_db)
 
                 # Get interval statistics tables which already exist
                 previous_intervals = self.list_previous_interval_statistic_tables(output=False)

+ 6 - 2
code/Core/StatsDatabase.py

@@ -181,8 +181,12 @@ class StatsDatabase:
         for current_interval in current_intervals:
             if current_interval == 0.0:
                 table_name = self.process_db_query("SELECT name FROM interval_tables WHERE is_default=1")
-                print("No user specified interval found. Using default interval: " +
-                      str(float(table_name[len("interval_statistics_"):])/1000000) + "s")
+                if table_name != []:
+                    substr = "Using default interval: " + str(float(table_name[len("interval_statistics_"):])/1000000) \
+                             + "s"
+                else:
+                    substr = "The default interval will be used after it is calculated."
+                print("No user specified interval found. " + substr)
             else:
                 self.current_interval_statistics_tables.append("interval_statistics_" +
                                                                str(int(current_interval*1000000)))

+ 1 - 1
code/ID2TLib/PcapFile.py

@@ -21,7 +21,7 @@ class PcapFile(object):
         :param attack_pcap_path: The path to the PCAP file to merge with the PCAP at pcap_file_path
         :return: The file path of the resulting PCAP file
         """
-        pcap = pr.pcap_processor(self.pcap_file_path, "False", Util.RESOURCE_DIR)
+        pcap = pr.pcap_processor(self.pcap_file_path, "False", Util.RESOURCE_DIR, "")
         file_out_path = pcap.merge_pcaps(attack_pcap_path)
         return file_out_path
 

+ 35 - 14
code_boost/src/cxx/pcap_processor.cpp

@@ -1,6 +1,3 @@
-#include <pybind11/pybind11.h>
-namespace py = pybind11;
-
 #include "pcap_processor.h"
 
 using namespace Tins;
@@ -9,8 +6,10 @@ using namespace Tins;
  * Creates a new pcap_processor object.
  * @param path The path where the PCAP to get analyzed is locatated.
  */
-pcap_processor::pcap_processor(std::string path, std::string extraTests, std::string resourcePath) : stats(resourcePath) {
+pcap_processor::pcap_processor(std::string path, std::string extraTests, std::string resource_path, std::string database_path) : stats(resource_path) {
     filePath = path;
+    resourcePath = resource_path;
+    databasePath = database_path;
     hasUnrecognized = false;
     if(extraTests == "True")
         stats.setDoExtraTests(true);
@@ -157,7 +156,7 @@ bool pcap_processor::read_pcap_info(const std::string &filePath, std::size_t &to
  * Collect statistics of the loaded PCAP file. Calls for each packet the method process_packets.
  * param: user specified interval in seconds
  */
-void pcap_processor::collect_statistics(const py::list& intervals) {
+void pcap_processor::collect_statistics(py::list& intervals) {
     // Only process PCAP if file exists
     if (file_exists(filePath)) {
         std::cout << "Loading pcap..." << std::endl;
@@ -179,7 +178,12 @@ void pcap_processor::collect_statistics(const py::list& intervals) {
         std::vector<std::chrono::duration<int, std::micro>> timeIntervals;
         std::vector<std::chrono::microseconds> barriers;
 
-        if (intervals.size() == 0 || intervals[0].cast<double>() == 0) {
+        std::vector<double> intervals_vec;
+        for (auto interval: intervals) {
+            intervals_vec.push_back(interval.cast<double>());
+        }
+
+        if (intervals_vec.size() == 0 || intervals_vec[0] == 0) {
             int timeIntervalsNum = 100;
             std::chrono::microseconds lastTimestamp = stats.getTimestampLastPacket();
             std::chrono::microseconds captureDuration = lastTimestamp - firstTimestamp;
@@ -195,9 +199,12 @@ void pcap_processor::collect_statistics(const py::list& intervals) {
             timeIntervals.push_back(timeInterval);
             barriers.push_back(barrier);
         } else {
-            for (auto interval: intervals) {
-                double interval_double = interval.cast<double>();
-                timeInterval_microsec = static_cast<long>(interval_double * 1000000);
+            if (stats.getDoExtraTests()) {
+                statistics_db stats_db(databasePath, resourcePath);
+                stats_db.getNoneExtraTestsInveralStats(intervals_vec);
+            }
+            for (auto interval: intervals_vec) {
+                timeInterval_microsec = static_cast<long>(interval * 1000000);
                 intervalStartTimestamp.push_back(firstTimestamp);
                 std::chrono::duration<int, std::micro> timeInterval(timeInterval_microsec);
                 std::chrono::microseconds barrier = timeInterval;
@@ -406,9 +413,16 @@ void pcap_processor::process_packets(const Packet &pkt) {
  */
 void pcap_processor::write_to_database(std::string database_path, const py::list& intervals, bool del) {
     std::vector<std::chrono::duration<int, std::micro>> timeIntervals;
+    std::vector<double> intervals_vec;
     for (auto interval: intervals) {
-        double interval_double = interval.cast<double>();
-        std::chrono::duration<int, std::micro> timeInterval(static_cast<long>(interval_double * 1000000));
+        intervals_vec.push_back(interval.cast<double>());
+    }
+    if (stats.getDoExtraTests()) {
+        statistics_db stats_db(databasePath, resourcePath);
+        stats_db.getNoneExtraTestsInveralStats(intervals_vec);
+    }
+    for (auto interval: intervals_vec) {
+        std::chrono::duration<int, std::micro> timeInterval(static_cast<long>(interval * 1000000));
         timeIntervals.push_back(timeInterval);
     }
     stats.writeToDatabase(database_path, timeIntervals, del);
@@ -416,9 +430,16 @@ void pcap_processor::write_to_database(std::string database_path, const py::list
 
 void pcap_processor::write_new_interval_statistics(std::string database_path, const py::list& intervals) {
     std::vector<std::chrono::duration<int, std::micro>> timeIntervals;
+    std::vector<double> intervals_vec;
     for (auto interval: intervals) {
-        double interval_double = interval.cast<double>();
-        std::chrono::duration<int, std::micro> timeInterval(static_cast<long>(interval_double * 1000000));
+        intervals_vec.push_back(interval.cast<double>());
+    }
+    if (stats.getDoExtraTests()) {
+        statistics_db stats_db(databasePath, resourcePath);
+        stats_db.getNoneExtraTestsInveralStats(intervals_vec);
+    }
+    for (auto interval: intervals_vec) {
+        std::chrono::duration<int, std::micro> timeInterval(static_cast<long>(interval * 1000000));
         timeIntervals.push_back(timeInterval);
     }
     stats.writeIntervalsToDatabase(database_path, timeIntervals, false);
@@ -465,7 +486,7 @@ bool inline pcap_processor::file_exists(const std::string &filePath) {
  */
 PYBIND11_MODULE (libpcapreader, m) {
     py::class_<pcap_processor>(m, "pcap_processor")
-            .def(py::init<std::string, std::string, std::string>())
+            .def(py::init<std::string, std::string, std::string, std::string>())
             .def("merge_pcaps", &pcap_processor::merge_pcaps)
             .def("collect_statistics", &pcap_processor::collect_statistics)
             .def("get_timestamp_mu_sec", &pcap_processor::get_timestamp_mu_sec)

+ 4 - 2
code_boost/src/cxx/pcap_processor.h

@@ -27,13 +27,15 @@ public:
     /*
     * Class constructor
     */
-    pcap_processor(std::string path, std::string extraTests, std::string resource_path);
+    pcap_processor(std::string path, std::string extraTests, std::string resource_path, std::string database_path);
 
     /*
      * Attributes
      */
     statistics stats;
     std::string filePath;
+    std::string databasePath;
+    std::string resourcePath;
     bool hasUnrecognized;
     std::chrono::duration<int, std::micro> timeInterval;
 
@@ -50,7 +52,7 @@ public:
 
     bool read_pcap_info(const std::string &filePath, std::size_t &totalPakets);
 
-    void collect_statistics(const py::list& intervals);
+    void collect_statistics(py::list& intervals);
 
     void write_to_database(std::string database_path, const py::list& intervals, bool del);
 

+ 3 - 3
code_boost/src/cxx/statistics.cpp

@@ -480,7 +480,7 @@ void statistics::incrementUnrecognizedPDUCount(const std::string &srcMac, const
 /**
  * Creates a new statistics object.
  */
-statistics::statistics(std::string resourcePath) {
+statistics::statistics(std::string resourcePath) {;
     this->resourcePath = resourcePath;
 }
 
@@ -790,7 +790,7 @@ void statistics::writeToDatabase(std::string database_path, std::vector<std::chr
         db.writeStatisticsWin(win_distribution);
         db.writeStatisticsConv(conv_statistics);
         db.writeStatisticsConvExt(conv_statistics_extended);
-        db.writeStatisticsInterval(interval_statistics, timeIntervals, del, this->default_interval);
+        db.writeStatisticsInterval(interval_statistics, timeIntervals, del, this->default_interval, this->getDoExtraTests());
         db.writeDbVersion();
         db.writeStatisticsUnrecognizedPDUs(unrecognized_PDUs);
     }
@@ -803,5 +803,5 @@ void statistics::writeToDatabase(std::string database_path, std::vector<std::chr
 
 void statistics::writeIntervalsToDatabase(std::string database_path, std::vector<std::chrono::duration<int, std::micro>> timeIntervals, bool del) {
     statistics_db db(database_path, resourcePath);
-    db.writeStatisticsInterval(interval_statistics, timeIntervals, del, this->default_interval);
+    db.writeStatisticsInterval(interval_statistics, timeIntervals, del, this->default_interval, this->getDoExtraTests());
 }

+ 62 - 7
code_boost/src/cxx/statistics_db.cpp

@@ -27,6 +27,35 @@ statistics_db::statistics_db(std::string database_path, std::string resourcePath
     readPortServicesFromNmap();
 }
 
+void statistics_db::getNoneExtraTestsInveralStats(std::vector<double>& intervals){
+    try {
+        //SQLite::Statement query(*db, "SELECT name FROM sqlite_master WHERE type='table' AND name='interval_tables';");
+        std::vector<std::string> tables;
+        try {
+            SQLite::Statement query(*db, "SELECT name FROM interval_tables WHERE extra_tests=1;");
+            while (query.executeStep()) {
+                tables.push_back(query.getColumn(0));
+            }
+        } catch (std::exception &e) {
+            std::cerr << "Exception in statistics_db::" << __func__ << ": " << e.what() << std::endl;
+        }
+        if (tables.size() != 0) {
+            std::string table_name;
+            double interval;
+            for (auto table = tables.begin(); table != tables.end(); table++) {
+                table_name = table->substr(std::string("interval_statistics_").length());
+                interval = static_cast<double>(::atof(table_name.c_str()))/1000000;
+                auto found = std::find(intervals.begin(), intervals.end(), interval);
+                if (found != intervals.end()) {
+                    intervals.erase(found, found);
+                }
+            }
+        }
+    } catch (std::exception &e) {
+        std::cerr << "Exception in statistics_db::" << __func__ << ": " << e.what() << std::endl;
+    }
+}
+
 /**
  * Writes the IP statistics into the database.
  * @param ipStatistics The IP statistics from class statistics.
@@ -564,11 +593,12 @@ void statistics_db::writeStatisticsConvExt(std::unordered_map<convWithProt, entr
  * Writes the interval statistics into the database.
  * @param intervalStatistics The interval entries from class statistics.
  */
-void statistics_db::writeStatisticsInterval(const std::unordered_map<std::string, entry_intervalStat> &intervalStatistics, std::vector<std::chrono::duration<int, std::micro>> timeIntervals, bool del, int defaultInterval){
+void statistics_db::writeStatisticsInterval(const std::unordered_map<std::string, entry_intervalStat> &intervalStatistics, std::vector<std::chrono::duration<int, std::micro>> timeIntervals, bool del, int defaultInterval, bool extraTests){
     try {
         // remove old tables produced by prior database versions
         db->exec("DROP TABLE IF EXISTS interval_statistics");
 
+        // delete all former interval statistics, if requested
         if (del) {
             SQLite::Statement query(*db, "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'interval_statistics_%';");
             std::vector<std::string> previous_tables;
@@ -581,28 +611,53 @@ void statistics_db::writeStatisticsInterval(const std::unordered_map<std::string
             db->exec("DROP TABLE IF EXISTS interval_tables");
         }
 
-        db->exec("CREATE TABLE IF NOT EXISTS interval_tables (name TEXT, is_default INTEGER);");
-        std::string is_default = "0";
+        // create interval table index
+        db->exec("CREATE TABLE IF NOT EXISTS interval_tables (name TEXT, is_default INTEGER, extra_tests INTEGER);");
+
+        std::string default_table_name = "";
+        // get name for default table
+        try {
+            SQLite::Statement query(*db, "SELECT name FROM interval_tables WHERE is_default=1;");
+            query.executeStep();
+            default_table_name = query.getColumn(0).getString();
 
+        } catch (std::exception &e) {
+            std::cerr << "Exception in statistics_db::" << __func__ << ": " << e.what() << std::endl;
+        }
+
+        // handle default interval only runs
+        std::string is_default = "0";
+        std::chrono::duration<int, std::micro> defaultTimeInterval(defaultInterval);
         if (defaultInterval != 0.0) {
             is_default = "1";
-            std::chrono::duration<int, std::micro> defaultTimeInterval(defaultInterval);
             if (timeIntervals.empty() || timeIntervals[0].count() == 0) {
                 timeIntervals.clear();
                 timeIntervals.push_back(defaultTimeInterval);
             }
         }
 
+        // extra tests handling
+        std::string extra = "0";
+        if (extraTests) {
+            extra = "1";
+        }
+
         for (auto timeInterval: timeIntervals) {
+            // get interval statistics table name
             std::ostringstream strs;
             strs << timeInterval.count();
             std::string table_name = "interval_statistics_" + strs.str();
 
+            // check for recalculation of default table
+            if (table_name == default_table_name || timeInterval == defaultTimeInterval) {
+                is_default = "1";
+            } else {
+                is_default = "0";
+            }
+
             // add interval_tables entry
             db->exec("DELETE FROM interval_tables WHERE name = '" + table_name + "';");
-            db->exec("INSERT INTO interval_tables VALUES ('" + table_name + "', '" + is_default + "');");
-
-            is_default = "0";
+            db->exec("INSERT INTO interval_tables VALUES ('" + table_name + "', '" + is_default + "', '" + extra + "');");
 
             // new interval statistics implementation
             db->exec("DROP TABLE IF EXISTS " + table_name);

+ 10 - 3
code_boost/src/cxx/statistics_db.h

@@ -9,9 +9,12 @@
 #include <memory>
 #include <string>
 #include "statistics.h"
+#include <pybind11/pybind11.h>
 #include <SQLiteCpp/SQLiteCpp.h>
 #include <unordered_map>
 
+namespace py = pybind11;
+
 class statistics_db {
 public:
     /*
@@ -22,7 +25,12 @@ public:
     /*
      * Database version: Increment number on every change in the C++ code!
      */
-    static const int DB_VERSION = 14;
+    static const int DB_VERSION = 16;
+
+    /*
+     * Methods to read from database
+     */
+    void getNoneExtraTestsInveralStats(std::vector<double>& intervals);
 
     /*
      * Methods for writing values into database
@@ -54,7 +62,7 @@ public:
 
     void writeStatisticsConvExt(std::unordered_map<convWithProt, entry_convStatExt> &conv_statistics_extended);
 
-    void writeStatisticsInterval(const std::unordered_map<std::string, entry_intervalStat> &intervalStatistics, std::vector<std::chrono::duration<int, std::micro>> timeInterval, bool del, int defaultInterval);
+    void writeStatisticsInterval(const std::unordered_map<std::string, entry_intervalStat> &intervalStatistics, std::vector<std::chrono::duration<int, std::micro>> timeInterval, bool del, int defaultInterval, bool extraTests);
 
     void writeDbVersion();
 
@@ -62,7 +70,6 @@ public:
 
     void writeStatisticsUnrecognizedPDUs(const std::unordered_map<unrecognized_PDU, unrecognized_PDU_stat> &unrecognized_PDUs);
 
-
 private:
     // Pointer to the SQLite database
     std::unique_ptr<SQLite::Database> db;