Browse Source

Introduced "in" operator for named queries

Stefan Schmidt 6 years ago
parent
commit
9d9908b20d
4 changed files with 47 additions and 10 deletions
  1. 4 0
      code/Core/Controller.py
  2. 14 4
      code/Core/QueryParser.py
  3. 28 5
      code/Core/StatsDatabase.py
  4. 1 1
      code/Test/test_Queries.py

+ 4 - 0
code/Core/Controller.py

@@ -8,6 +8,7 @@ import Core.LabelManager as LabelManager
 import Core.Statistics as Statistics
 import Core.Statistics as Statistics
 import ID2TLib.PcapFile as PcapFile
 import ID2TLib.PcapFile as PcapFile
 import ID2TLib.Utility as Util
 import ID2TLib.Utility as Util
+import Core.StatsDatabase as StatsDB
 
 
 
 
 class Controller:
 class Controller:
@@ -291,6 +292,9 @@ class Controller:
                         for i in range(1, e.col):
                         for i in range(1, e.col):
                             sys.stderr.write(" ")
                             sys.stderr.write(" ")
                         sys.stderr.write("^\n\n")
                         sys.stderr.write("^\n\n")
+                    except StatsDB.QueryExecutionException as e:
+                        sys.stderr.write("An error occured: ")
+                        sys.stderr.write(e.args[0] + "\n")
                 buffer = ""
                 buffer = ""
 
 
         readline.set_history_length(1000)
         readline.set_history_length(1000)

+ 14 - 4
code/Core/QueryParser.py

@@ -3,13 +3,23 @@ import pyparsing as pp
 
 
 class QueryParser:
 class QueryParser:
     def __init__(self):
     def __init__(self):
+        # TODO: Try to disallow invalid combinations
+        # TODO: Have tests for invalid combinations
+        # TODO: allow lists as input, like: ipaddress(macaddress in [1,2,3])
         extractor = pp.Keyword("random") ^ pp.Keyword("first") ^ pp.Keyword("last")
         extractor = pp.Keyword("random") ^ pp.Keyword("first") ^ pp.Keyword("last")
         selector = pp.Keyword("most_used") ^ pp.Keyword("least_used") ^ pp.Keyword("avg") ^ pp.Keyword("all")
         selector = pp.Keyword("most_used") ^ pp.Keyword("least_used") ^ pp.Keyword("avg") ^ pp.Keyword("all")
-        attribute = pp.Keyword("ipaddress") ^ pp.Keyword("macaddress") ^ pp.Keyword("portnumber") ^ pp.Keyword("protocolname") ^ pp.Keyword("ttlvalue") ^ pp.Keyword("mssvalue") ^ pp.Keyword("winsize") ^ pp.Keyword("ipclass") ^ pp.Keyword("pktssent") ^ pp.Keyword("pktsreceived") ^ pp.Keyword("mss") ^ pp.Keyword("kbytesreceived") ^ pp.Keyword("kbytessent")
+        attribute = pp.Keyword("ipaddress") ^ pp.Keyword("macaddress") ^ pp.Keyword("portnumber") ^\
+                    pp.Keyword("protocolname") ^ pp.Keyword("ttlvalue") ^ pp.Keyword("mssvalue") ^\
+                    pp.Keyword("winsize") ^ pp.Keyword("ipclass") ^ pp.Keyword("pktssent") ^\
+                    pp.Keyword("pktsreceived") ^ pp.Keyword("mss") ^ pp.Keyword("kbytesreceived") ^\
+                    pp.Keyword("kbytessent")
         simple_selector_query = selector + pp.Suppress("(") + attribute + pp.Suppress(")")
         simple_selector_query = selector + pp.Suppress("(") + attribute + pp.Suppress(")")
 
 
-        param_selectors = pp.Keyword("ipaddress").setParseAction(pp.replaceWith("ipaddress_param")) ^ pp.Keyword("macaddress").setParseAction(pp.replaceWith("macaddress_param"))
-        operators = pp.Literal("=") ^ pp.Literal("<=") ^ pp.Literal("<") ^ pp.Literal(">=") ^ pp.Literal(">")
+        param_selectors = pp.Keyword("ipaddress").setParseAction(pp.replaceWith("ipaddress_param")) ^\
+                          pp.Keyword("macaddress").setParseAction(pp.replaceWith("macaddress_param"))
+
+        operators = pp.Literal("<=") ^ pp.Literal("<") ^ pp.Literal("=") ^\
+                    pp.Literal(">=") ^ pp.Literal(">") ^ pp.CaselessLiteral("in")
         expr = pp.Forward()
         expr = pp.Forward()
         comparison = pp.Group(attribute + operators + (pp.Word(pp.alphanums + ".:") ^ expr))
         comparison = pp.Group(attribute + operators + (pp.Word(pp.alphanums + ".:") ^ expr))
         parameterized_query = param_selectors + pp.Suppress("(") + pp.Group(pp.delimitedList(comparison)) + pp.Suppress(")")
         parameterized_query = param_selectors + pp.Suppress("(") + pp.Group(pp.delimitedList(comparison)) + pp.Suppress(")")
@@ -22,5 +32,5 @@ class QueryParser:
         expr << pp.Group(named_query)
         expr << pp.Group(named_query)
         self.full_query = named_query + pp.Suppress(";")
         self.full_query = named_query + pp.Suppress(";")
 
 
-    def parse_query(self, querystring):
+    def parse_query(self, querystring: str) -> pp.ParseResults:
         return self.full_query.parseString(querystring)
         return self.full_query.parseString(querystring)

+ 28 - 5
code/Core/StatsDatabase.py

@@ -25,6 +25,10 @@ def dict_gen(curs: sqlite3.Cursor):
             yield dict(zip(field_names, row))
             yield dict(zip(field_names, row))
 
 
 
 
+class QueryExecutionException(Exception):
+    pass
+
+
 class StatsDatabase:
 class StatsDatabase:
     def __init__(self, db_path: str):
     def __init__(self, db_path: str):
         """
         """
@@ -170,16 +174,35 @@ class StatsDatabase:
         for key, op, value in param_op_val:
         for key, op, value in param_op_val:
             if isinstance(value, pp.ParseResults):
             if isinstance(value, pp.ParseResults):
                 # If we have another query instead of a direct value, execute and replace it
                 # If we have another query instead of a direct value, execute and replace it
-                value = self._execute_query_list(value)[0][0]
+                rvalue = self._execute_query_list(value)
+
+                # Do we have a comparison operator with a multiple-result query?
+                if op is not "in" and value[0] in ['most_used', 'least_used', 'all']:
+                    raise QueryExecutionException("The extractor '" + value[0] + "' may return more than one result!")
+
+                # Make value contain a simple list with the results of the query
+                value = map(lambda x: str(x[0]), rvalue)
+            else:
+                # Make sure value is a list now to simplify handling
+                value = [value]
+
             # this makes sure that TEXT fields are queried by strings,
             # this makes sure that TEXT fields are queried by strings,
             # e.g. ipAddress=192.168.178.1 --is-converted-to--> ipAddress='192.168.178.1'
             # e.g. ipAddress=192.168.178.1 --is-converted-to--> ipAddress='192.168.178.1'
             if field_types.get(key) == 'TEXT':
             if field_types.get(key) == 'TEXT':
-                if not str(value).startswith("'") and not str(value).startswith('"'):
-                    value = "'" + value + "'"
+                def ensure_string(x):
+                    if not str(x).startswith("'") and not str(x).startswith('"'):
+                        return "'" + x + "'"
+                    else:
+                        return x
+                value = map(ensure_string, value)
+
+            # If we have more than one value, join them together, separated by commas
+            value = ",".join(map(str, value))
+
             # this replacement is required to remove ambiguity in SQL query
             # this replacement is required to remove ambiguity in SQL query
             if key == 'ipAddress':
             if key == 'ipAddress':
                 key = 'ip_mac.ipAddress'
                 key = 'ip_mac.ipAddress'
-            conditions.append(key + op + str(value))
+            conditions.append(key + " " + op + " (" + str(value) + ")")
 
 
         where_clause = " AND ".join(conditions)
         where_clause = " AND ".join(conditions)
         query += where_clause
         query += where_clause
@@ -246,7 +269,7 @@ class StatsDatabase:
     def _execute_query_list(self, query_list):
     def _execute_query_list(self, query_list):
         """
         """
         Recursively executes a list of named queries. They are of the following form:
         Recursively executes a list of named queries. They are of the following form:
-        ['macaddress_param', [['ipaddress', '=', ['most_used', 'ipaddress']]]]
+        ['macaddress_param', [['ipaddress', 'in', ['most_used', 'ipaddress']]]]
         :param query_list: The query statement list obtained from the query parser
         :param query_list: The query statement list obtained from the query parser
         :return: The result of the query (either a single result or a list).
         :return: The result of the query (either a single result or a list).
         """
         """

+ 1 - 1
code/Test/test_Queries.py

@@ -235,4 +235,4 @@ class TestQueries(unittest.TestCase):
         self.assertEqual(controller.statistics.process_db_query('all(protocolname)'), ['IPv4', 'TCP', 'UDP'])
         self.assertEqual(controller.statistics.process_db_query('all(protocolname)'), ['IPv4', 'TCP', 'UDP'])
 
 
     def test_nested_query(self):
     def test_nested_query(self):
-        self.assertEqual(controller.statistics.process_db_query('macaddress(ipaddress=most_used(ipaddress))'), '08:00:27:a3:83:43')
+        self.assertEqual(controller.statistics.process_db_query('macaddress(ipaddress in most_used(ipaddress))'), '08:00:27:a3:83:43')