123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- import os.path
- import random as rnd
- import typing
- import sqlite3
- import sys
- import ID2TLib.libpcapreader as pr
- import Core.QueryParser as qp
- import pyparsing as pp
- def dict_gen(curs: sqlite3.Cursor):
- """
- Generates a dictionary of a sqlite3.Cursor object by fetching the query's results.
- Taken from Python Essential Reference by David Beazley.
- """
- field_names = [d[0] for d in curs.description]
- while True:
- rows = curs.fetchmany()
- if not rows:
- return
- for row in rows:
- yield dict(zip(field_names, row))
- class QueryExecutionException(Exception):
- pass
- class StatsDatabase:
- def __init__(self, db_path: str):
- """
- Creates a new StatsDatabase.
- :param db_path: The path to the database file
- """
- self.query_parser = qp.QueryParser()
- self.existing_db = os.path.exists(db_path)
- self.database = sqlite3.connect(db_path)
- self.cursor = self.database.cursor()
-
- if self.existing_db:
- if self.get_db_outdated():
- print('Statistics database outdated. Recreating database at: ', db_path)
- else:
- print('Located statistics database at: ', db_path)
- else:
- print('Statistics database not found. Creating new database at: ', db_path)
- def get_file_info(self):
- """
- Retrieves general file statistics from the database. This includes:
- - packetCount : Number of packets in the PCAP file
- - captureDuration : Duration of the packet capture in seconds
- - timestampFirstPacket : Timestamp of the first captured packet
- - timestampLastPacket : Timestamp of the last captured packet
- - avgPacketRate : Average packet rate
- - avgPacketSize : Average packet size
- - avgPacketsSentPerHost : Average number of packets sent per host
- - avgBandwidthIn : Average incoming bandwidth
- - avgBandwidthOut : Average outgoing bandwidth
- :return: a dictionary of keys (see above) and their respective values
- """
- return [r for r in dict_gen(
- self.cursor.execute('SELECT * FROM file_statistics'))][0]
- def get_db_exists(self):
- """
- :return: True if the database was already existent, otherwise False
- """
- return self.existing_db
- def get_db_outdated(self):
- """
- Retrieves the database version from the database and compares it to the version
- it should have to check whether the database is outdated and needs to be recreated.
- :return: True if the versions match, otherwise False
- """
- self.cursor.execute('PRAGMA user_version;')
- return self.cursor.fetchall()[0][0] != pr.pcap_processor.get_db_version()
- @staticmethod
- def _get_selector_keywords():
- """
- :return: a list of selector keywords
- """
- return ['most_used', 'least_used', 'avg', 'all']
- @staticmethod
- def _get_parametrized_selector_keywords():
- """
- :return: a list of parameterizable selector keywords
- """
- return ['ipaddress', 'macaddress']
- @staticmethod
- def _get_extractor_keywords():
- """
- :return: a list of extractor keywords
- """
- return ['random', 'first', 'last']
- def get_all_named_query_keywords(self):
- """
- :return: a list of all named query keywords, used to identify named queries
- """
- return (
- self._get_selector_keywords() + self._get_parametrized_selector_keywords() + self._get_extractor_keywords())
- @staticmethod
- def get_all_sql_query_keywords():
- """
- :return: a list of all supported SQL keywords, used to identify SQL queries
- """
- return ["select", "insert"]
- def process_user_defined_query(self, query_string: str, query_parameters: tuple = None):
- """
- Takes as input a SQL query query_string and optional a tuple of parameters which are marked by '?' in the query
- and later substituted.
- :param query_string: The query to execute
- :param query_parameters: The tuple of parameters to inject into the query
- :return: the results of the query
- """
- if query_parameters is not None:
- self.cursor.execute(query_string, query_parameters)
- else:
- self.cursor.execute(query_string)
- self.database.commit()
- return self.cursor.fetchall()
- def get_field_types(self, *table_names):
- """
- Creates a dictionary whose keys are the fields of the given table(s) and whose values are the appropriate field
- types, like TEXT for strings and REAL for float numbers.
- :param table_names: The name of table(s)
- :return: a dictionary of {field_name : field_type} for fields of all tables
- """
- dic = {}
- for table in table_names:
- self.cursor.execute("PRAGMA table_info('%s')" % table)
- results = self.cursor.fetchall()
- for field in results:
- dic[field[1].lower()] = field[2]
- return dic
- def named_query_parameterized(self, keyword: str, param_op_val: list):
- """
- Executes a parameterizable named query.
- :param keyword: The query to be executed, like ipaddress or macadress
- :param param_op_val: A list consisting of triples with (parameter, operator, value)
- :return: the results of the executed query
- """
- named_queries = {
- "ipaddress": "SELECT DISTINCT ip_statistics.ipAddress from ip_statistics INNER JOIN ip_mac, ip_ttl, "
- "ip_ports, ip_protocols ON ip_statistics.ipAddress=ip_mac.ipAddress AND "
- "ip_statistics.ipAddress=ip_ttl.ipAddress AND ip_statistics.ipAddress=ip_ports.ipAddress "
- "AND ip_statistics.ipAddress=ip_protocols.ipAddress WHERE ",
- "macaddress": "SELECT DISTINCT macAddress from ip_mac WHERE "}
- query = named_queries.get(keyword)
- field_types = self.get_field_types('ip_mac', 'ip_ttl', 'ip_ports', 'ip_protocols', 'ip_statistics', 'ip_mac')
- conditions = []
- for key, op, value in param_op_val:
-
- if isinstance(value, pp.ParseResults):
- if value[0] == "list":
-
- value = value[1:]
-
- if op is not "in":
- raise QueryExecutionException("List values require the usage of the 'in' operator!")
- else:
-
- rvalue = self._execute_query_list(value)
-
- if op is not "in" and value[0] in ['most_used', 'least_used', 'all', 'ipaddress_param',
- 'macaddress_param']:
- raise QueryExecutionException("The extractor '" + value[0] +
- "' may return more than one result!")
-
- value = map(lambda x: str(x[0]), rvalue)
- else:
-
- value = [value]
-
-
- if field_types.get(key) == 'TEXT':
- def ensure_string(x):
- if not str(x).startswith("'") and not str(x).startswith('"'):
- return "'" + x + "'"
- else:
- return x
- value = map(ensure_string, value)
-
- value = ",".join(map(str, value))
-
- if key == 'ipAddress':
- key = 'ip_mac.ipAddress'
- conditions.append(key + " " + op + " (" + str(value) + ")")
- where_clause = " AND ".join(conditions)
- query += where_clause
- self.cursor.execute(query)
- return self.cursor.fetchall()
- named_queries = {
- "most_used.ipaddress": "SELECT ipAddress FROM ip_statistics WHERE (pktsSent+pktsReceived) == "
- "(SELECT MAX(pktsSent+pktsReceived) from ip_statistics) ORDER BY ipAddress ASC",
- "most_used.macaddress": "SELECT macAddress FROM (SELECT macAddress, COUNT(*) as occ from ip_mac GROUP BY "
- "macAddress) WHERE occ=(SELECT COUNT(*) as occ from ip_mac GROUP BY macAddress "
- "ORDER BY occ DESC LIMIT 1) ORDER BY macAddress ASC",
- "most_used.portnumber": "SELECT portNumber FROM ip_ports GROUP BY portNumber HAVING COUNT(portNumber)="
- "(SELECT MAX(cntPort) from (SELECT portNumber, COUNT(portNumber) as cntPort FROM "
- "ip_ports GROUP BY portNumber)) ORDER BY portNumber ASC",
- "most_used.protocolname": "SELECT protocolName FROM ip_protocols GROUP BY protocolName HAVING "
- "COUNT(protocolCount)=(SELECT COUNT(protocolCount) as cnt FROM ip_protocols "
- "GROUP BY protocolName ORDER BY cnt DESC LIMIT 1) ORDER BY protocolName ASC",
- "most_used.ttlvalue": "SELECT ttlValue FROM (SELECT ttlValue, SUM(ttlCount) as occ FROM ip_ttl GROUP BY "
- "ttlValue) WHERE occ=(SELECT SUM(ttlCount) as occ FROM ip_ttl GROUP BY ttlValue "
- "ORDER BY occ DESC LIMIT 1) ORDER BY ttlValue ASC",
- "most_used.mssvalue": "SELECT mssValue FROM (SELECT mssValue, SUM(mssCount) as occ FROM tcp_mss GROUP BY "
- "mssValue) WHERE occ=(SELECT SUM(mssCount) as occ FROM tcp_mss GROUP BY mssValue "
- "ORDER BY occ DESC LIMIT 1) ORDER BY mssValue ASC",
- "most_used.winsize": "SELECT winSize FROM (SELECT winSize, SUM(winCount) as occ FROM tcp_win GROUP BY "
- "winSize) WHERE occ=(SELECT SUM(winCount) as occ FROM tcp_win GROUP BY winSize ORDER "
- "BY occ DESC LIMIT 1) ORDER BY winSize ASC",
- "most_used.ipclass": "SELECT ipClass FROM (SELECT ipClass, COUNT(*) as occ from ip_statistics GROUP BY "
- "ipClass ORDER BY occ DESC) WHERE occ=(SELECT COUNT(*) as occ from ip_statistics "
- "GROUP BY ipClass ORDER BY occ DESC LIMIT 1) ORDER BY ipClass ASC",
- "least_used.ipaddress": "SELECT ipAddress FROM ip_statistics WHERE (pktsSent+pktsReceived) == (SELECT "
- "MIN(pktsSent+pktsReceived) from ip_statistics) ORDER BY ipAddress ASC",
- "least_used.macaddress": "SELECT macAddress FROM (SELECT macAddress, COUNT(*) as occ from ip_mac GROUP "
- "BY macAddress) WHERE occ=(SELECT COUNT(*) as occ from ip_mac GROUP BY macAddress "
- "ORDER BY occ ASC LIMIT 1) ORDER BY macAddress ASC",
- "least_used.portnumber": "SELECT portNumber FROM ip_ports GROUP BY portNumber HAVING COUNT(portNumber)="
- "(SELECT MIN(cntPort) from (SELECT portNumber, COUNT(portNumber) as cntPort FROM "
- "ip_ports GROUP BY portNumber)) ORDER BY portNumber ASC",
- "least_used.protocolname": "SELECT protocolName FROM ip_protocols GROUP BY protocolName HAVING "
- "COUNT(protocolCount)=(SELECT COUNT(protocolCount) as cnt FROM ip_protocols "
- "GROUP BY protocolName ORDER BY cnt ASC LIMIT 1) ORDER BY protocolName ASC",
- "least_used.ttlvalue": "SELECT ttlValue FROM (SELECT ttlValue, SUM(ttlCount) as occ FROM ip_ttl GROUP BY "
- "ttlValue) WHERE occ=(SELECT SUM(ttlCount) as occ FROM ip_ttl GROUP BY ttlValue "
- "ORDER BY occ ASC LIMIT 1) ORDER BY ttlValue ASC",
- "least_used.mssvalue": "SELECT mssValue FROM (SELECT mssValue, SUM(mssCount) as occ FROM tcp_mss GROUP BY "
- "mssValue) WHERE occ=(SELECT SUM(mssCount) as occ FROM tcp_mss GROUP BY mssValue "
- "ORDER BY occ ASC LIMIT 1) ORDER BY mssValue ASC",
- "least_used.winsize": "SELECT winSize FROM (SELECT winSize, SUM(winCount) as occ FROM tcp_win GROUP BY "
- "winSize) WHERE occ=(SELECT SUM(winCount) as occ FROM tcp_win GROUP BY winSize "
- "ORDER BY occ ASC LIMIT 1) ORDER BY winSize ASC",
- "least_used.ipclass": "SELECT ipClass FROM (SELECT ipClass, COUNT(*) as occ from ip_statistics GROUP BY "
- "ipClass ORDER BY occ DESC) WHERE occ=(SELECT COUNT(*) as occ from ip_statistics "
- "GROUP BY ipClass ORDER BY occ ASC LIMIT 1) ORDER BY ipClass ASC",
- "avg.pktsreceived": "SELECT avg(pktsReceived) from ip_statistics",
- "avg.pktssent": "SELECT avg(pktsSent) from ip_statistics",
- "avg.kbytesreceived": "SELECT avg(kbytesReceived) from ip_statistics",
- "avg.kbytessent": "SELECT avg(kbytesSent) from ip_statistics",
- "avg.ttlvalue": "SELECT avg(ttlValue) from ip_ttl",
- "avg.mss": "SELECT avg(mssValue) from tcp_mss",
- "all.ipaddress": "SELECT ipAddress from ip_statistics ORDER BY ipAddress ASC",
- "all.ttlvalue": "SELECT DISTINCT ttlValue from ip_ttl ORDER BY ttlValue ASC",
- "all.mss": "SELECT DISTINCT mssValue from tcp_mss ORDER BY mssValue ASC",
- "all.macaddress": "SELECT DISTINCT macAddress from ip_mac ORDER BY macAddress ASC",
- "all.portnumber": "SELECT DISTINCT portNumber from ip_ports ORDER BY portNumber ASC",
- "all.protocolname": "SELECT DISTINCT protocolName from ip_protocols ORDER BY protocolName ASC",
- "all.winsize": "SELECT DISTINCT winSize FROM tcp_win ORDER BY winSize ASC",
- "all.ipclass": "SELECT DISTINCT ipClass FROM ip_statistics ORDER BY ipClass ASC"}
- def _execute_query_list(self, query_list):
- """
- Recursively executes a list of named queries. They are of the following form:
- ['macaddress_param', [['ipaddress', 'in', ['most_used', 'ipaddress']]]]
- :param query_list: The query statement list obtained from the query parser
- :return: The result of the query (either a single result or a list).
- """
- if query_list[0] == "random":
- return [rnd.choice(self._execute_query_list(query_list[1:]))]
- elif query_list[0] == "first":
- return [self._execute_query_list(query_list[1:])[0]]
- elif query_list[0] == "last":
- return [self._execute_query_list(query_list[1:])[-1]]
- elif query_list[0] == "macaddress_param":
- return self.named_query_parameterized("macaddress", query_list[1])
- elif query_list[0] == "ipaddress_param":
- return self.named_query_parameterized("ipaddress", query_list[1])
- else:
- query = self.named_queries.get(query_list[0] + "." + query_list[1])
- if query is None:
- raise QueryExecutionException("The requested query '" + query_list[0] + "(" + query_list[1] +
- ")' was not found in the internal query list!")
- self.cursor.execute(str(query))
- last_result = self.cursor.fetchall()
- return last_result
- def process_db_query(self, query_string_in: str, print_results=False, sql_query_parameters: tuple = None):
- """
- Processes a database query. This can either be a standard SQL query or a named query (predefined query).
- :param query_string_in: The string containing the query
- :param print_results: Indicated whether the results should be printed to terminal (True) or not (False)
- :param sql_query_parameters: Parameters for the SQL query (optional)
- :return: the results of the query
- """
- named_query_keywords = self.get_all_named_query_keywords()
-
- query_string = query_string_in.lower().lstrip()
-
- result = None
- if sql_query_parameters is not None or query_string.startswith("select") or query_string.startswith("insert"):
- result = self.process_user_defined_query(query_string, sql_query_parameters)
-
- elif any(k in query_string for k in named_query_keywords) and all(k in query_string for k in ['(', ')']):
- if query_string[-1] != ";":
- query_string += ";"
- query_list = self.query_parser.parse_query(query_string)
- result = self._execute_query_list(query_list)
- else:
- sys.stderr.write(
- "Query invalid. Only named queries and SQL SELECT/INSERT allowed. Please check the query's syntax!\n")
- return
-
- requires_extraction = (isinstance(result, list) or isinstance(result, tuple)) and len(result) == 1 and \
- (not isinstance(result[0], tuple) or len(result[0]) == 1)
- while requires_extraction:
- if isinstance(result, list) or isinstance(result, tuple):
- result = result[0]
- else:
- requires_extraction = False
-
-
- if (isinstance(result, tuple) or isinstance(result, list)) and all(len(val) == 1 for val in result):
- result = [c for c in result for c in c]
-
- if print_results:
- if isinstance(result, list) and len(result) == 1:
- result = result[0]
- print("Query returned 1 record:\n")
- for i in range(0, len(result)):
- print(str(self.cursor.description[i][0]) + ": " + str(result[i]))
- else:
- self._print_query_results(query_string_in, result if isinstance(result, list) else [result])
- return result
- def _print_query_results(self, query_string_in: str, result: typing.List[typing.Union[str, float, int]]) -> None:
- """
- Prints the results of a query.
- Based on http://stackoverflow.com/a/20383011/3017719.
- :param query_string_in: The query the results belong to
- :param result: The list of query results
- """
-
- if len(result) == 1:
- print("Query returned 1 record:\n")
- else:
- print("Query returned " + str(len(result)) + " records:\n")
-
- if query_string_in.lstrip().upper().startswith(
- "SELECT") and result is not None and self.cursor.description is not None:
- widths = []
- columns = []
- tavnit = '|'
- separator = '+'
- for index, cd in enumerate(self.cursor.description):
- max_col_length = 0
- if len(result) > 0:
- max_col_length = max(list(map(lambda x:
- len(str(x[index] if len(self.cursor.description) > 1 else x)),
- result)))
- widths.append(max(len(cd[0]), max_col_length))
- columns.append(cd[0])
- for w in widths:
- tavnit += " %-" + "%ss |" % (w,)
- separator += '-' * w + '--+'
- print(separator)
- print(tavnit % tuple(columns))
- print(separator)
- for row in result:
- print(tavnit % row)
- print(separator)
- else:
- print(result)
|