MapInputCSVToIDs.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from Attack.MembersMgmtCommAttack import MessageType
  2. # needed because of machine inprecision. E.g A time difference of 0.1s is stored as >0.1s
  3. EPS_TOLERANCE = 1e-13 # works for a difference of 0.1, no less
  4. def find_interval_with_most_comm(packets: list, number_ids: int, max_int_time: float):
  5. """
  6. Finds a time interval of the given seconds where the given number of ids communicate among themselves the most.
  7. :param packets: The packets containing the communication
  8. :param number_ids: The number of ids that are to be considered
  9. :param max_int_time: A short description of the attack.
  10. :return: A triple consisting of the ids, as well as start and end idx with respect to the given packets.
  11. """
  12. def get_nez_msg_counts(msg_counts: dict):
  13. """
  14. Filters out all msg_counts that have 0 as value
  15. """
  16. nez_msg_counts = dict()
  17. for msg in msg_counts.keys():
  18. count = msg_counts[msg]
  19. if count > 0:
  20. nez_msg_counts[msg] = count
  21. return nez_msg_counts
  22. def greater_than(a: float, b: float):
  23. """
  24. A greater than operator desgined to handle slight machine inprecision up to EPS_TOLERANCE.
  25. :return: True if a > b, otherwise False
  26. """
  27. return b - a < -EPS_TOLERANCE
  28. def change_msg_counts(msg_counts: dict, idx: int, add=True):
  29. """
  30. Changes the value of the message count of the message occuring in the packet specified by the given index.
  31. Adds 1 if add is True and subtracts 1 otherwise.
  32. """
  33. change = 1 if add else -1
  34. id_src, id_dst = packets[idx]["Src"], packets[idx]["Dst"]
  35. src_to_dst = "{0}-{1}".format(id_src, id_dst)
  36. dst_to_src = "{0}-{1}".format(id_dst, id_src)
  37. if src_to_dst in msg_counts.keys():
  38. msg_counts[src_to_dst] += change
  39. elif dst_to_src in msg_counts.keys():
  40. msg_counts[dst_to_src] += change
  41. elif add:
  42. msg_counts[src_to_dst] = 1
  43. def count_ids_in_msg_counts(msg_counts: dict):
  44. """
  45. Counts all ids that are involved in messages with a non zero message count
  46. """
  47. ids = set()
  48. for msg in msg_counts.keys():
  49. src, dst = msg.split("-")
  50. ids.add(dst)
  51. ids.add(src)
  52. return len(ids)
  53. def get_msg_count_first_ids(msg_counts: list):
  54. """
  55. Finds the ids that communicate among themselves the most with respect to the given message counts.
  56. :param msg_counts: a sorted list of message counts where each entry is a tuple of key and value
  57. :return: The picked ids and their total message count as a tuple
  58. """
  59. # if order of most messages is important, use an additional list
  60. picked_ids = set()
  61. total_msg_count = 0
  62. # iterate over every message count
  63. for i, msg in enumerate(msg_counts):
  64. count_picked_ids = len(picked_ids)
  65. id_one, id_two = msg[0].split("-")
  66. # if enough ids have been found, stop
  67. if count_picked_ids >= number_ids:
  68. break
  69. # if two ids can be added without exceeding the desired number of ids, add them
  70. if count_picked_ids - 2 <= number_ids:
  71. picked_ids.add(id_one)
  72. picked_ids.add(id_two)
  73. total_msg_count += msg[1]
  74. # if there is only room for one more id to be added,
  75. # find one that is already contained in the picked ids
  76. else:
  77. for j, msg in enumerate(msg_counts[i:]):
  78. id_one, id_two = msg[0].split("-")
  79. if id_one in picked_ids:
  80. picked_ids.add(id_two)
  81. total_msg_count += msg[1]
  82. break
  83. elif id_two in picked_ids:
  84. picked_ids.add(id_one)
  85. total_msg_count += msg[1]
  86. break
  87. break
  88. return picked_ids, total_msg_count
  89. def get_indv_id_counts_and_comms(picked_ids: dict, msg_counts: dict):
  90. """
  91. Retrieves the total mentions of one ID in the communication pattern
  92. and all communication entries that include only picked IDs.
  93. """
  94. indv_id_counts = {}
  95. id_comms = set()
  96. for msg in msg_counts:
  97. ids = msg.split("-")
  98. if ids[0] in picked_ids and ids[1] in picked_ids:
  99. msg_other_dir = "{}-{}".format(ids[1], ids[0])
  100. if (not msg in id_comms) and (not msg_other_dir in id_comms):
  101. id_comms.add(msg)
  102. for id_ in ids:
  103. if id_ in indv_id_counts:
  104. indv_id_counts[id_] += msg_counts[msg]
  105. else:
  106. indv_id_counts[id_] = msg_counts[msg]
  107. return indv_id_counts, id_comms
  108. # first find all possible intervals that contain enough IDs that communicate among themselves
  109. idx_low, idx_high = 0, 0
  110. msg_counts = dict()
  111. possible_intervals = []
  112. # Iterate over all packets from start to finish and process the info of each packet
  113. # If time of packet within time interval, update the message count for this communication
  114. # If time of packet exceeds time interval, substract from the message count for this communication
  115. while True:
  116. if idx_high < len(packets):
  117. cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
  118. # if current interval time exceeds time interval, save the message counts if appropriate, or stop if no more packets
  119. if greater_than(cur_int_time, max_int_time) or idx_high >= len(packets):
  120. # get all message counts for communications that took place in the current intervall
  121. nez_msg_counts = get_nez_msg_counts(msg_counts)
  122. # if we have enough ids as specified by the caller, mark as possible interval
  123. if count_ids_in_msg_counts(nez_msg_counts) >= number_ids:
  124. # possible_intervals.append((nez_msg_counts, packets[idx_low]["Time"], packets[idx_high-1]["Time"]))
  125. possible_intervals.append((nez_msg_counts, idx_low, idx_high - 1))
  126. if idx_high >= len(packets):
  127. break
  128. # let idx_low "catch up" so that the current interval time fits into the interval time specified by the caller
  129. while greater_than(cur_int_time, max_int_time):
  130. change_msg_counts(msg_counts, idx_low, add=False)
  131. idx_low += 1
  132. cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
  133. # consume the new packet at idx_high and process its information
  134. change_msg_counts(msg_counts, idx_high)
  135. idx_high += 1
  136. # now find the interval in which as many ids as specified communicate the most in the given time interval
  137. summed_intervals = []
  138. sum_intervals_idxs = []
  139. cur_highest_sum = 0
  140. # for every interval compute the sum of msg_counts of the first most communicative ids and eventually find
  141. # the interval(s) with most communication and its ids
  142. for j, interval in enumerate(possible_intervals):
  143. msg_counts = interval[0].items()
  144. sorted_msg_counts = sorted(msg_counts, key=lambda x: x[1], reverse=True)
  145. picked_ids, msg_sum = get_msg_count_first_ids(sorted_msg_counts)
  146. if msg_sum == cur_highest_sum:
  147. summed_intervals.append({"IDs": picked_ids, "MsgSum": msg_sum, "Start": interval[1], "End": interval[2]})
  148. sum_intervals_idxs.append(j)
  149. elif msg_sum > cur_highest_sum:
  150. summed_intervals = []
  151. sum_intervals_idxs = [j]
  152. summed_intervals.append({"IDs": picked_ids, "MsgSum": msg_sum, "Start": interval[1], "End": interval[2]})
  153. cur_highest_sum = msg_sum
  154. for j, interval in enumerate(summed_intervals):
  155. idx = sum_intervals_idxs[j]
  156. msg_counts_picked = possible_intervals[idx][0]
  157. indv_id_counts, id_comms = get_indv_id_counts_and_comms(interval["IDs"], msg_counts_picked)
  158. interval["IDs"] = indv_id_counts
  159. interval["Comms"] = id_comms
  160. return summed_intervals
  161. def determine_id_roles(packets: list, all_ids: set):
  162. """
  163. Determine the role of every mapped ID. The role can be initiator, responder or both.
  164. :param packets: the mapped section of abstract packets
  165. :param all_ids: all IDs that were mapped/chosen
  166. :return: a dict that for every ID contains its role
  167. """
  168. init_ids, respnd_ids, both_ids = set(), set(), set()
  169. def process_initiator(id_: str):
  170. if id_ in both_ids:
  171. pass
  172. elif not id_ in respnd_ids:
  173. init_ids.add(id_)
  174. elif id_ in respnd_ids:
  175. respnd_ids.remove(id_)
  176. both_ids.add(id_)
  177. def process_responder(id_: str):
  178. if id_ in both_ids:
  179. pass
  180. elif not id_ in init_ids:
  181. respnd_ids.add(id_)
  182. elif id_ in init_ids:
  183. init_ids.remove(id_)
  184. both_ids.add(id_)
  185. for packet in packets:
  186. id_src, id_dst, msg_type = packet["Src"], packet["Dst"], packet["Type"]
  187. if (not id_src in all_ids) or (not id_dst in all_ids):
  188. continue
  189. if int(msg_type) in {MessageType.SALITY_HELLO.value, MessageType.SALITY_NL_REQUEST.value}:
  190. process_initiator(id_src)
  191. process_responder(id_dst)
  192. elif int(msg_type) in {MessageType.SALITY_HELLO_REPLY.value, MessageType.SALITY_NL_REPLY.value}:
  193. process_initiator(id_dst)
  194. process_responder(id_src)
  195. return init_ids, respnd_ids, both_ids