Browse Source

Merge branch 'set_attack_note' of stefan.schmidt/ID2T-toolkit into master

Carlos Garcia 6 years ago
parent
commit
9aa9cd270c
2 changed files with 68 additions and 38 deletions
  1. 52 24
      code/Core/Controller.py
  2. 16 14
      code/Core/StatsDatabase.py

+ 52 - 24
code/Core/Controller.py

@@ -1,6 +1,7 @@
 import os
 import readline
 import sys
+import re
 
 import pyparsing as pp
 import Core.AttackController as atkCtrl
@@ -241,6 +242,53 @@ class Controller:
             print("Unknown keyword '" + param + "', try 'help;' to get a list of allowed keywords'")
             print()
 
+    def internal_command(self, query: str) -> bool:
+        # Strip off semicolon, split into command and parameters
+        query = query.strip(";").split(" ", 1)
+        cmd = query[0].strip().lower()
+        if len(query) > 1:
+            params = [p for p in re.split("(,|\\\".*?\\\"|'.*?')", query[1]) if p.strip(",").strip()]
+            params = list(map(lambda x: x.strip().strip("\"'"), params))
+        else:
+            params = []
+
+        if cmd == "help":
+            self.process_help(params)
+            return True
+        elif cmd == "labels":
+            if not self.label_manager.labels:
+                print("No labels found.")
+            else:
+                print("Attacks listed in the label file:")
+                print()
+                for i, label in enumerate(self.label_manager.labels):
+                    print("Attack number:   " + str(i))
+                    print("Attack name:     " + str(label.attack_name))
+                    print("Attack note:     " + str(label.attack_note))
+                    print("Attack seed:     " + str(label.seed))
+                    print("Start timestamp: " + str(label.timestamp_start))
+                    print("End timestamp:   " + str(label.timestamp_end))
+                    print()
+            print()
+            return True
+        elif cmd == "set":
+            if len(params) == 3:
+                if params[0].lower() == "attack_note":
+                    i = int(params[1])
+                    self.label_manager.labels[i].attack_note = params[2]
+                return True
+        elif cmd == "tables":
+            self.statisticsDB.process_db_query("SELECT name FROM sqlite_master WHERE type='table';", True)
+            return True
+        elif cmd == "columns":
+            self.statisticsDB.process_db_query("SELECT * FROM " + params[0].lower(), False)
+            columns = self.statisticsDB.get_field_types(params[0].lower())
+            for column in columns:
+                print(column + ": " + columns[column])
+            return True
+
+        return False
+
     def enter_query_mode(self):
         """
         Enters into the query mode. This is a read-eval-print-loop, where the user can input named queries or SQL
@@ -285,30 +333,7 @@ class Controller:
             import sqlite3
             if sqlite3.complete_statement(buffer):
                 buffer = buffer.strip()
-                if buffer.lower().startswith('help'):
-                    buffer = buffer.strip(';')
-                    self.process_help(buffer.split(' ')[1:])
-                elif buffer.lower().strip() == 'labels;':
-                    if not self.label_manager.labels:
-                        print("No labels found.")
-                    else:
-                        print("Attacks listed in the label file:")
-                        print()
-                        for label in self.label_manager.labels:
-                            print("Attack name:     " + str(label.attack_name))
-                            print("Attack note:     " + str(label.attack_note))
-                            print("Start timestamp: " + str(label.timestamp_start))
-                            print("End timestamp:   " + str(label.timestamp_end))
-                            print()
-                    print()
-                elif buffer.lower().strip() == 'tables;':
-                    self.statisticsDB.process_db_query("SELECT name FROM sqlite_master WHERE type='table';", True)
-                elif buffer.lower().strip().startswith('columns '):
-                    self.statisticsDB.process_db_query("SELECT * FROM " + buffer.lower()[8:], False)
-                    columns = self.statisticsDB.get_field_types(buffer.lower()[8:].strip(";"))
-                    for column in columns:
-                        print(column + ": " + columns[column])
-                else:
+                if not self.internal_command(buffer):
                     try:
                         self.statisticsDB.process_db_query(buffer, True)
                     except sqlite3.Error as e:
@@ -328,6 +353,9 @@ class Controller:
         readline.set_history_length(1000)
         readline.write_history_file(history_file)
 
+        # Save the label file, in case content has changed
+        self.label_manager.write_label_file(self.pcap_src_path)
+
     def create_statistics_plot(self, params: str, entropy: bool):
         """
         Plots the statistics to a file by using the given customization parameters.

+ 16 - 14
code/Core/StatsDatabase.py

@@ -1,6 +1,6 @@
 import os.path
 import random as rnd
-import re
+import typing
 import sqlite3
 import sys
 
@@ -360,23 +360,23 @@ class StatsDatabase:
                 for i in range(0, len(result)):
                     print(str(self.cursor.description[i][0]) + ": " + str(result[i]))
             else:
-                self._print_query_results(query_string_in, result)
+                self._print_query_results(query_string_in, result if isinstance(result, list) else [result])
 
         return result
 
-    def _print_query_results(self, query_string_in: str, result):
+    def _print_query_results(self, query_string_in: str, result: typing.List[typing.Union[str, float, int]]) -> None:
         """
         Prints the results of a query.
         Based on http://stackoverflow.com/a/20383011/3017719.
 
         :param query_string_in: The query the results belong to
-        :param result: The results of the query
+        :param result: The list of query results
         """
         # Print number of results according to type of result
-        if isinstance(result, list):
-            print("Query returned " + str(len(result)) + " records:\n")
-        else:
+        if len(result) == 1:
             print("Query returned 1 record:\n")
+        else:
+            print("Query returned " + str(len(result)) + " records:\n")
 
         # Print query results
         if query_string_in.lstrip().upper().startswith(
@@ -385,8 +385,13 @@ class StatsDatabase:
             columns = []
             tavnit = '|'
             separator = '+'
-            for cd in self.cursor.description:
-                widths.append(len(cd) + 10)
+            for index, cd in enumerate(self.cursor.description):
+                max_col_length = 0
+                if len(result) > 0:
+                    max_col_length = max(list(map(lambda x:
+                                                  len(str(x[index] if len(self.cursor.description) > 1 else x)),
+                                                  result)))
+                widths.append(max(len(cd[0]), max_col_length))
                 columns.append(cd[0])
             for w in widths:
                 tavnit += " %-" + "%ss |" % (w,)
@@ -394,11 +399,8 @@ class StatsDatabase:
             print(separator)
             print(tavnit % tuple(columns))
             print(separator)
-            if isinstance(result, list):
-                for row in result:
-                    print(tavnit % row)
-            else:
-                print(tavnit % result)
+            for row in result:
+                print(tavnit % row)
             print(separator)
         else:
             print(result)