Sfoglia il codice sorgente

Moved parsing of query mode exclusive commands to separate function

Stefan Schmidt 6 anni fa
parent
commit
ced8858d1c
1 ha cambiato i file con 42 aggiunte e 24 eliminazioni
  1. 42 24
      code/Core/Controller.py

+ 42 - 24
code/Core/Controller.py

@@ -1,6 +1,7 @@
 import os
 import readline
 import sys
+import re
 
 import pyparsing as pp
 import Core.AttackController as atkCtrl
@@ -241,6 +242,46 @@ class Controller:
             print("Unknown keyword '" + param + "', try 'help;' to get a list of allowed keywords'")
             print()
 
+    def internal_command(self, query: str) -> bool:
+        # Strip off semicolon, split into command and parameters
+        query = query.strip(";").split(" ", 1)
+        cmd = query[0].strip().lower()
+        if len(query) > 1:
+            params = [p for p in re.split("(,|\\\".*?\\\"|'.*?')", query[1]) if p.strip(",").strip()]
+            params = list(map(lambda x: x.strip().strip("\"'"), params))
+        else:
+            params = []
+
+        if cmd == "help":
+            self.process_help(params)
+            return True
+        elif cmd == "labels":
+            if not self.label_manager.labels:
+                print("No labels found.")
+            else:
+                print("Attacks listed in the label file:")
+                print()
+                for i, label in enumerate(self.label_manager.labels):
+                    print("Attack number:   " + str(i))
+                    print("Attack name:     " + str(label.attack_name))
+                    print("Attack note:     " + str(label.attack_note))
+                    print("Start timestamp: " + str(label.timestamp_start))
+                    print("End timestamp:   " + str(label.timestamp_end))
+                    print()
+            print()
+            return True
+        elif cmd == "tables":
+            self.statisticsDB.process_db_query("SELECT name FROM sqlite_master WHERE type='table';", True)
+            return True
+        elif cmd == "columns":
+            self.statisticsDB.process_db_query("SELECT * FROM " + params[0].lower(), False)
+            columns = self.statisticsDB.get_field_types(params[0].lower())
+            for column in columns:
+                print(column + ": " + columns[column])
+            return True
+
+        return False
+
     def enter_query_mode(self):
         """
         Enters into the query mode. This is a read-eval-print-loop, where the user can input named queries or SQL
@@ -285,30 +326,7 @@ class Controller:
             import sqlite3
             if sqlite3.complete_statement(buffer):
                 buffer = buffer.strip()
-                if buffer.lower().startswith('help'):
-                    buffer = buffer.strip(';')
-                    self.process_help(buffer.split(' ')[1:])
-                elif buffer.lower().strip() == 'labels;':
-                    if not self.label_manager.labels:
-                        print("No labels found.")
-                    else:
-                        print("Attacks listed in the label file:")
-                        print()
-                        for label in self.label_manager.labels:
-                            print("Attack name:     " + str(label.attack_name))
-                            print("Attack note:     " + str(label.attack_note))
-                            print("Start timestamp: " + str(label.timestamp_start))
-                            print("End timestamp:   " + str(label.timestamp_end))
-                            print()
-                    print()
-                elif buffer.lower().strip() == 'tables;':
-                    self.statisticsDB.process_db_query("SELECT name FROM sqlite_master WHERE type='table';", True)
-                elif buffer.lower().strip().startswith('columns '):
-                    self.statisticsDB.process_db_query("SELECT * FROM " + buffer.lower()[8:], False)
-                    columns = self.statisticsDB.get_field_types(buffer.lower()[8:].strip(";"))
-                    for column in columns:
-                        print(column + ": " + columns[column])
-                else:
+                if not self.internal_command(buffer):
                     try:
                         self.statisticsDB.process_db_query(buffer, True)
                     except sqlite3.Error as e: