StatsDatabase.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. import os.path
  2. import random as rnd
  3. import typing
  4. import sqlite3
  5. import sys
  6. # TODO: double check this import
  7. # does it complain because libpcapreader is not a .py?
  8. import ID2TLib.libpcapreader as pr
  9. import Core.QueryParser as qp
  10. import pyparsing as pp
  def dict_gen(curs: sqlite3.Cursor):
      """
      Lazily yield every row of an already-executed query as a dictionary.

      Rows are pulled from the cursor batch by batch with fetchmany(); each row becomes a
      {column_name: value} mapping, using the column names from curs.description.
      Adapted from Python Essential Reference by David Beazley.

      :param curs: a cursor on which a query has already been executed
      :return: a generator yielding one dictionary per result row
      """
      columns = [column[0] for column in curs.description]
      batch = curs.fetchmany()
      while batch:
          for record in batch:
              yield dict(zip(columns, record))
          batch = curs.fetchmany()
  class QueryExecutionException(Exception):
      """Raised when a named or parameterized statistics query cannot be executed."""
  25. class StatsDatabase:
  def __init__(self, db_path: str):
      """
      Open (or create) the statistics database at the given path.

      Whether the database file already existed is recorded in self.existing_db before
      connecting, because sqlite3.connect creates a missing file. A status message describing
      the situation (new, outdated or reusable database) is printed.

      :param db_path: The path to the database file
      """
      self.query_parser = qp.QueryParser()
      # Must be checked before sqlite3.connect, which creates the file if it does not exist.
      self.existing_db = os.path.exists(db_path)
      self.database = sqlite3.connect(db_path)
      self.cursor = self.database.cursor()
      self.current_interval_statistics_tables = []

      # Only the state is reported here; this constructor does not create or drop any tables.
      if not self.existing_db:
          print('Statistics database not found. Creating new database at: ', db_path)
      elif self.get_db_outdated():
          print('Statistics database outdated. Recreating database at: ', db_path)
      else:
          print('Located statistics database at: ', db_path)
  def get_file_info(self):
      """
      Read the general file statistics stored in the file_statistics table. This includes:
      - packetCount : Number of packets in the PCAP file
      - captureDuration : Duration of the packet capture in seconds
      - timestampFirstPacket : Timestamp of the first captured packet
      - timestampLastPacket : Timestamp of the last captured packet
      - avgPacketRate : Average packet rate
      - avgPacketSize : Average packet size
      - avgPacketsSentPerHost : Average number of packets sent per host
      - avgBandwidthIn : Average incoming bandwidth
      - avgBandwidthOut : Average outgoing bandwidth
      :return: the first row of file_statistics as a dictionary of keys (see above) and their
               respective values; an empty table raises IndexError
      """
      rows = list(dict_gen(self.cursor.execute('SELECT * FROM file_statistics')))
      return rows[0]
  60. def get_db_exists(self):
  61. """
  62. :return: True if the database was already existent, otherwise False
  63. """
  64. return self.existing_db
  def get_db_outdated(self):
      """
      Check whether the database is outdated and needs to be recreated.

      The stored schema version is read with "PRAGMA user_version" and compared with the
      version reported by pr.pcap_processor.get_db_version().

      :return: True if the stored version does NOT match the expected version (i.e. the
               database is outdated), otherwise False
      """
      self.cursor.execute('PRAGMA user_version;')
      # PRAGMA user_version yields exactly one row with one column.
      stored_version = self.cursor.fetchone()[0]
      return stored_version != pr.pcap_processor.get_db_version()
  73. @staticmethod
  74. def _get_selector_keywords():
  75. """
  76. :return: a list of selector keywords
  77. """
  78. return ['most_used', 'least_used', 'avg', 'all']
  79. @staticmethod
  80. def _get_parametrized_selector_keywords():
  81. """
  82. :return: a list of parameterizable selector keywords
  83. """
  84. return ['ipaddress', 'macaddress']
  85. @staticmethod
  86. def _get_extractor_keywords():
  87. """
  88. :return: a list of extractor keywords
  89. """
  90. return ['random', 'first', 'last']
  91. def get_all_named_query_keywords(self):
  92. """
  93. :return: a list of all named query keywords, used to identify named queries
  94. """
  95. return (
  96. self._get_selector_keywords() + self._get_parametrized_selector_keywords() + self._get_extractor_keywords())
  97. @staticmethod
  98. def get_all_sql_query_keywords():
  99. """
  100. :return: a list of all supported SQL keywords, used to identify SQL queries
  101. """
  102. return ["select", "insert"]
  103. def process_user_defined_query(self, query_string: str, query_parameters: tuple = None):
  104. """
  105. Takes as input a SQL query query_string and optional a tuple of parameters which are marked by '?' in the query
  106. and later substituted.
  107. :param query_string: The query to execute
  108. :param query_parameters: The tuple of parameters to inject into the query
  109. :return: the results of the query
  110. """
  111. if query_parameters is not None:
  112. self.cursor.execute(query_string, query_parameters)
  113. else:
  114. self.cursor.execute(query_string)
  115. self.database.commit()
  116. return self.cursor.fetchall()
  117. def get_field_types(self, *table_names):
  118. """
  119. Creates a dictionary whose keys are the fields of the given table(s) and whose values are the appropriate field
  120. types, like TEXT for strings and REAL for float numbers.
  121. :param table_names: The name of table(s)
  122. :return: a dictionary of {field_name : field_type} for fields of all tables
  123. """
  124. dic = {}
  125. for table in table_names:
  126. self.cursor.execute("PRAGMA table_info('%s')" % table)
  127. results = self.cursor.fetchall()
  128. for field in results:
  129. dic[field[1].lower()] = field[2]
  130. return dic
  131. def get_current_interval_statistics_table(self):
  132. """
  133. :return: the current interval statistics table used for internal calculations
  134. """
  135. if len(self.current_interval_statistics_tables) > 0:
  136. return self.current_interval_statistics_tables[0]
  137. else:
  138. return ""
  139. def get_all_current_interval_statistics_tables(self):
  140. """
  141. :return: the list of all current interval statistics tables
  142. """
  143. if len(self.current_interval_statistics_tables) == 0:
  144. return [self.process_db_query("SELECT name FROM interval_tables WHERE is_default=1")]
  145. return self.current_interval_statistics_tables
  146. def set_current_interval_statistics_tables(self, current_intervals: list):
  147. """
  148. Sets the current interval statistics table, which should be used for internal calculations.
  149. :param current_intervals: a list of current intervals in seconds, first of which should be used for internal
  150. calculations
  151. """
  152. for current_interval in current_intervals:
  153. if current_interval == 0.0:
  154. table_name = self.process_db_query("SELECT name FROM interval_tables WHERE is_default=1")
  155. if table_name != []:
  156. substr = "Using default interval: " + str(float(table_name[len("interval_statistics_"):])/1000000) \
  157. + "s"
  158. else:
  159. substr = "The default interval will used after it is calculated."
  160. print("No user specified interval found. " + substr)
  161. else:
  162. self.current_interval_statistics_tables.append("interval_statistics_" +
  163. str(int(current_interval*1000000)))
  164. if current_interval == current_intervals[0]:
  165. print("User specified interval(s) found. Using first interval length given for internal "
  166. "calculations: " + str(current_interval) + "s")
  def named_query_parameterized(self, keyword: str, param_op_val: list):
      """
      Executes a parameterizable named query.

      Each (parameter, operator, value) triple becomes one condition of the WHERE clause, and the
      conditions are joined with AND. A value may be a plain value, a "list" ParseResults (only
      allowed with the 'in' operator) or a nested named query, which is executed first and whose
      results (first column of each row) are substituted.

      :param keyword: The query to be executed, like ipaddress or macaddress
      :param param_op_val: A list consisting of triples with (parameter, operator, value)
      :return: the results of the executed query (cursor.fetchall())
      :raises QueryExecutionException: if keyword is unknown, a list value is used without 'in',
                                       or a possibly multi-result query is used without 'in'
      """
      named_queries = {
          "ipaddress": "SELECT DISTINCT ip_statistics.ipAddress from ip_statistics INNER JOIN ip_mac, ip_ttl, "
                       "ip_ports, ip_protocols ON ip_statistics.ipAddress=ip_mac.ipAddress AND "
                       "ip_statistics.ipAddress=ip_ttl.ipAddress AND ip_statistics.ipAddress=ip_ports.ipAddress "
                       "AND ip_statistics.ipAddress=ip_protocols.ipAddress WHERE ",
          "macaddress": "SELECT DISTINCT macAddress from ip_mac WHERE "}
      query = named_queries.get(keyword)
      if query is None:
          raise QueryExecutionException("Unknown parameterized named query '" + str(keyword) + "'!")
      # Keys of field_types are lower-case field names (see get_field_types).
      field_types = self.get_field_types('ip_mac', 'ip_ttl', 'ip_ports', 'ip_protocols', 'ip_statistics')

      def ensure_string(x):
          # TEXT fields must be compared against SQL strings,
          # e.g. ipAddress=192.168.178.1 --is-converted-to--> ipAddress='192.168.178.1'
          text = str(x)
          if not text.startswith("'") and not text.startswith('"'):
              return "'" + text + "'"
          return text

      conditions = []
      for key, op, value in param_op_val:
          # Check whether the value is not a simple value, but another query (or list)
          if isinstance(value, pp.ParseResults):
              if value[0] == "list":
                  # We have a list, cut the token off and use the remaining elements
                  value = value[1:]
                  # Lists can only be used with "in"
                  if op != "in":
                      raise QueryExecutionException("List values require the usage of the 'in' operator!")
              else:
                  # If we have another query instead of a direct value, execute and replace it
                  rvalue = self._execute_query_list(value)
                  # Do we have a comparison operator with a multiple-result query?
                  if op != "in" and value[0] in ['most_used', 'least_used', 'all', 'ipaddress_param',
                                                 'macaddress_param']:
                      raise QueryExecutionException("The extractor '" + value[0] +
                                                    "' may return more than one result!")
                  # Make value contain a simple list with the results of the query
                  value = [str(row[0]) for row in rvalue]
          else:
              # Make sure value is a list now to simplify handling
              value = [value]
          # Compare case-insensitively: the key may arrive lower-cased (process_db_query lower-cases
          # the whole query string before parsing) or in its original mixed case.
          key_lower = str(key).lower()
          if field_types.get(key_lower) == 'TEXT':
              value = [ensure_string(v) for v in value]
          # If we have more than one value, join them together, separated by commas
          joined_values = ",".join(str(v) for v in value)
          # ipAddress exists in several joined tables; qualify it to remove ambiguity in the SQL query
          if key_lower == 'ipaddress':
              key = 'ip_mac.ipAddress'
          conditions.append(key + " " + op + " (" + joined_values + ")")
      # NOTE(review): values are inlined into the SQL text rather than bound as parameters.
      self.cursor.execute(query + " AND ".join(conditions))
      return self.cursor.fetchall()
  # Predefined named queries, keyed by "<selector>.<field>" in lower case. _execute_query_list
  # looks up query_list[0] + "." + query_list[1] in this dict and executes the stored SQL.
  #   most_used.* / least_used.*: every value tied for the highest / lowest usage count, sorted
  #       ascending (so several rows may be returned);
  #   avg.*: a single average value;
  #   all.*: every (mostly DISTINCT) value, sorted ascending.
  # NOTE(review): the *.protocolname queries compare COUNT(protocolCount) (number of rows per
  # protocol) whereas the ttl/mss/winsize queries compare SUM(...Count); confirm this is intended.
  named_queries = {
      # --- most used values ---
      "most_used.ipaddress": "SELECT ipAddress FROM ip_statistics WHERE (pktsSent+pktsReceived) == "
                             "(SELECT MAX(pktsSent+pktsReceived) from ip_statistics) ORDER BY ipAddress ASC",
      "most_used.macaddress": "SELECT macAddress FROM (SELECT macAddress, COUNT(*) as occ from ip_mac GROUP BY "
                              "macAddress) WHERE occ=(SELECT COUNT(*) as occ from ip_mac GROUP BY macAddress "
                              "ORDER BY occ DESC LIMIT 1) ORDER BY macAddress ASC",
      "most_used.portnumber": "SELECT portNumber FROM ip_ports GROUP BY portNumber HAVING COUNT(portNumber)="
                              "(SELECT MAX(cntPort) from (SELECT portNumber, COUNT(portNumber) as cntPort FROM "
                              "ip_ports GROUP BY portNumber)) ORDER BY portNumber ASC",
      "most_used.protocolname": "SELECT protocolName FROM ip_protocols GROUP BY protocolName HAVING "
                                "COUNT(protocolCount)=(SELECT COUNT(protocolCount) as cnt FROM ip_protocols "
                                "GROUP BY protocolName ORDER BY cnt DESC LIMIT 1) ORDER BY protocolName ASC",
      "most_used.ttlvalue": "SELECT ttlValue FROM (SELECT ttlValue, SUM(ttlCount) as occ FROM ip_ttl GROUP BY "
                            "ttlValue) WHERE occ=(SELECT SUM(ttlCount) as occ FROM ip_ttl GROUP BY ttlValue "
                            "ORDER BY occ DESC LIMIT 1) ORDER BY ttlValue ASC",
      "most_used.mssvalue": "SELECT mssValue FROM (SELECT mssValue, SUM(mssCount) as occ FROM tcp_mss GROUP BY "
                            "mssValue) WHERE occ=(SELECT SUM(mssCount) as occ FROM tcp_mss GROUP BY mssValue "
                            "ORDER BY occ DESC LIMIT 1) ORDER BY mssValue ASC",
      "most_used.winsize": "SELECT winSize FROM (SELECT winSize, SUM(winCount) as occ FROM tcp_win GROUP BY "
                           "winSize) WHERE occ=(SELECT SUM(winCount) as occ FROM tcp_win GROUP BY winSize ORDER "
                           "BY occ DESC LIMIT 1) ORDER BY winSize ASC",
      "most_used.ipclass": "SELECT ipClass FROM (SELECT ipClass, COUNT(*) as occ from ip_statistics GROUP BY "
                           "ipClass ORDER BY occ DESC) WHERE occ=(SELECT COUNT(*) as occ from ip_statistics "
                           "GROUP BY ipClass ORDER BY occ DESC LIMIT 1) ORDER BY ipClass ASC",
      # --- least used values ---
      "least_used.ipaddress": "SELECT ipAddress FROM ip_statistics WHERE (pktsSent+pktsReceived) == (SELECT "
                              "MIN(pktsSent+pktsReceived) from ip_statistics) ORDER BY ipAddress ASC",
      "least_used.macaddress": "SELECT macAddress FROM (SELECT macAddress, COUNT(*) as occ from ip_mac GROUP "
                               "BY macAddress) WHERE occ=(SELECT COUNT(*) as occ from ip_mac GROUP BY macAddress "
                               "ORDER BY occ ASC LIMIT 1) ORDER BY macAddress ASC",
      "least_used.portnumber": "SELECT portNumber FROM ip_ports GROUP BY portNumber HAVING COUNT(portNumber)="
                               "(SELECT MIN(cntPort) from (SELECT portNumber, COUNT(portNumber) as cntPort FROM "
                               "ip_ports GROUP BY portNumber)) ORDER BY portNumber ASC",
      "least_used.protocolname": "SELECT protocolName FROM ip_protocols GROUP BY protocolName HAVING "
                                 "COUNT(protocolCount)=(SELECT COUNT(protocolCount) as cnt FROM ip_protocols "
                                 "GROUP BY protocolName ORDER BY cnt ASC LIMIT 1) ORDER BY protocolName ASC",
      "least_used.ttlvalue": "SELECT ttlValue FROM (SELECT ttlValue, SUM(ttlCount) as occ FROM ip_ttl GROUP BY "
                             "ttlValue) WHERE occ=(SELECT SUM(ttlCount) as occ FROM ip_ttl GROUP BY ttlValue "
                             "ORDER BY occ ASC LIMIT 1) ORDER BY ttlValue ASC",
      "least_used.mssvalue": "SELECT mssValue FROM (SELECT mssValue, SUM(mssCount) as occ FROM tcp_mss GROUP BY "
                             "mssValue) WHERE occ=(SELECT SUM(mssCount) as occ FROM tcp_mss GROUP BY mssValue "
                             "ORDER BY occ ASC LIMIT 1) ORDER BY mssValue ASC",
      "least_used.winsize": "SELECT winSize FROM (SELECT winSize, SUM(winCount) as occ FROM tcp_win GROUP BY "
                            "winSize) WHERE occ=(SELECT SUM(winCount) as occ FROM tcp_win GROUP BY winSize "
                            "ORDER BY occ ASC LIMIT 1) ORDER BY winSize ASC",
      "least_used.ipclass": "SELECT ipClass FROM (SELECT ipClass, COUNT(*) as occ from ip_statistics GROUP BY "
                            "ipClass ORDER BY occ DESC) WHERE occ=(SELECT COUNT(*) as occ from ip_statistics "
                            "GROUP BY ipClass ORDER BY occ ASC LIMIT 1) ORDER BY ipClass ASC",
      # --- averages (single value) ---
      "avg.pktsreceived": "SELECT avg(pktsReceived) from ip_statistics",
      "avg.pktssent": "SELECT avg(pktsSent) from ip_statistics",
      "avg.kbytesreceived": "SELECT avg(kbytesReceived) from ip_statistics",
      "avg.kbytessent": "SELECT avg(kbytesSent) from ip_statistics",
      "avg.ttlvalue": "SELECT avg(ttlValue) from ip_ttl",
      "avg.mss": "SELECT avg(mssValue) from tcp_mss",
      # --- all values, sorted ascending ---
      "all.ipaddress": "SELECT ipAddress from ip_statistics ORDER BY ipAddress ASC",
      "all.ttlvalue": "SELECT DISTINCT ttlValue from ip_ttl ORDER BY ttlValue ASC",
      "all.mss": "SELECT DISTINCT mssValue from tcp_mss ORDER BY mssValue ASC",
      "all.macaddress": "SELECT DISTINCT macAddress from ip_mac ORDER BY macAddress ASC",
      "all.portnumber": "SELECT DISTINCT portNumber from ip_ports ORDER BY portNumber ASC",
      "all.protocolname": "SELECT DISTINCT protocolName from ip_protocols ORDER BY protocolName ASC",
      "all.winsize": "SELECT DISTINCT winSize FROM tcp_win ORDER BY winSize ASC",
      "all.ipclass": "SELECT DISTINCT ipClass FROM ip_statistics ORDER BY ipClass ASC"}
  285. def _execute_query_list(self, query_list):
  286. """
  287. Recursively executes a list of named queries. They are of the following form:
  288. ['macaddress_param', [['ipaddress', 'in', ['most_used', 'ipaddress']]]]
  289. :param query_list: The query statement list obtained from the query parser
  290. :return: The result of the query (either a single result or a list).
  291. """
  292. if query_list[0] == "random":
  293. return [rnd.choice(self._execute_query_list(query_list[1:]))]
  294. elif query_list[0] == "first":
  295. return [self._execute_query_list(query_list[1:])[0]]
  296. elif query_list[0] == "last":
  297. return [self._execute_query_list(query_list[1:])[-1]]
  298. elif query_list[0] == "macaddress_param":
  299. return self.named_query_parameterized("macaddress", query_list[1])
  300. elif query_list[0] == "ipaddress_param":
  301. return self.named_query_parameterized("ipaddress", query_list[1])
  302. else:
  303. query = self.named_queries.get(query_list[0] + "." + query_list[1])
  304. if query is None:
  305. raise QueryExecutionException("The requested query '" + query_list[0] + "(" + query_list[1] +
  306. ")' was not found in the internal query list!")
  307. self.cursor.execute(str(query))
  308. # TODO: fetch query on demand
  309. last_result = self.cursor.fetchall()
  310. return last_result
  311. def process_db_query(self, query_string_in: str, print_results=False, sql_query_parameters: tuple = None):
  312. """
  313. Processes a database query. This can either be a standard SQL query or a named query (predefined query).
  314. :param query_string_in: The string containing the query
  315. :param print_results: Indicated whether the results should be printed to terminal (True) or not (False)
  316. :param sql_query_parameters: Parameters for the SQL query (optional)
  317. :return: the results of the query
  318. """
  319. named_query_keywords = self.get_all_named_query_keywords()
  320. # Clean query_string
  321. query_string = query_string_in.lower().lstrip()
  322. # query_string is a user-defined SQL query
  323. result = None
  324. if sql_query_parameters is not None or query_string.startswith("select") or query_string.startswith("insert"):
  325. result = self.process_user_defined_query(query_string, sql_query_parameters)
  326. # query string is a named query -> parse it and pass it to statisticsDB
  327. elif any(k in query_string for k in named_query_keywords) and all(k in query_string for k in ['(', ')']):
  328. if query_string[-1] != ";":
  329. query_string += ";"
  330. query_list = self.query_parser.parse_query(query_string)
  331. result = self._execute_query_list(query_list)
  332. else:
  333. sys.stderr.write(
  334. "Query invalid. Only named queries and SQL SELECT/INSERT allowed. Please check the query's syntax!\n")
  335. return
  336. # If result is tuple/list with single element, extract value from list
  337. requires_extraction = (isinstance(result, list) or isinstance(result, tuple)) and len(result) == 1 and \
  338. (not isinstance(result[0], tuple) or len(result[0]) == 1)
  339. while requires_extraction:
  340. if isinstance(result, list) or isinstance(result, tuple):
  341. result = result[0]
  342. else:
  343. requires_extraction = False
  344. # If tuple of tuples or list of tuples, each consisting of single element is returned,
  345. # then convert it into list of values, because the returned column is clearly specified by the given query
  346. if (isinstance(result, tuple) or isinstance(result, list)) and all(len(val) == 1 for val in result):
  347. result = [c for c in result for c in c]
  348. # Print results if option print_results is True
  349. if print_results:
  350. if isinstance(result, list) and len(result) == 1:
  351. result = result[0]
  352. print("Query returned 1 record:\n")
  353. for i in range(0, len(result)):
  354. print(str(self.cursor.description[i][0]) + ": " + str(result[i]))
  355. else:
  356. self._print_query_results(query_string_in, result if isinstance(result, list) else [result])
  357. return result
  358. def process_interval_statistics_query(self, query_string_in: str, table_param: str=""):
  359. """
  360. :param query_string_in: a query to be executed over the current internal interval statistics table
  361. :param table_param: a name of a specific interval statistics table
  362. :return: the result of the query
  363. """
  364. if table_param != "":
  365. table_name = table_param
  366. elif self.get_current_interval_statistics_table() != "":
  367. table_name = self.get_current_interval_statistics_table()
  368. else:
  369. table_name = self.process_db_query("SELECT name FROM interval_tables WHERE is_default=1")
  370. return self.process_user_defined_query(query_string_in % table_name)
  371. def _print_query_results(self, query_string_in: str, result: typing.List[typing.Union[str, float, int]]) -> None:
  372. """
  373. Prints the results of a query.
  374. Based on http://stackoverflow.com/a/20383011/3017719.
  375. :param query_string_in: The query the results belong to
  376. :param result: The list of query results
  377. """
  378. # Print number of results according to type of result
  379. if len(result) == 1:
  380. print("Query returned 1 record:\n")
  381. else:
  382. print("Query returned " + str(len(result)) + " records:\n")
  383. # Print query results
  384. if query_string_in.lstrip().upper().startswith(
  385. "SELECT") and result is not None and self.cursor.description is not None:
  386. widths = []
  387. columns = []
  388. tavnit = '|'
  389. separator = '+'
  390. for index, cd in enumerate(self.cursor.description):
  391. max_col_length = 0
  392. if len(result) > 0:
  393. max_col_length = max(list(map(lambda x:
  394. len(str(x[index] if len(self.cursor.description) > 1 else x)),
  395. result)))
  396. widths.append(max(len(cd[0]), max_col_length))
  397. columns.append(cd[0])
  398. for w in widths:
  399. tavnit += " %-" + "%ss |" % (w,)
  400. separator += '-' * w + '--+'
  401. print(separator)
  402. print(tavnit % tuple(columns))
  403. print(separator)
  404. for row in result:
  405. print(tavnit % row)
  406. print(separator)
  407. else:
  408. print(result)