CommunicationProcessor.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. from lea import Lea
  2. from Attack.MembersMgmtCommAttack import MessageType
  3. from Attack.MembersMgmtCommAttack import Message
  4. # needed because of machine inprecision. E.g A time difference of 0.1s is stored as >0.1s
  5. EPS_TOLERANCE = 1e-13 # works for a difference of 0.1, no less
  6. class CommunicationProcessor():
  7. def __init__(self, packets):
  8. self.packets = packets
  9. def set_mapping(self, packets, mapped_ids, id_comms):
  10. self.packets = packets
  11. self.ids = mapped_ids.keys()
  12. self.id_comms = id_comms
  13. self.indv_id_counts = mapped_ids
  14. def find_interval_with_most_comm(self, number_ids: int, max_int_time: float):
  15. """
  16. Finds a time interval of the given seconds where the given number of ids communicate among themselves the most.
  17. :param packets: The packets containing the communication
  18. :param number_ids: The number of ids that are to be considered
  19. :param max_int_time: A short description of the attack.
  20. :return: A triple consisting of the ids, as well as start and end idx with respect to the given packets.
  21. """
  22. packets = self.packets
  23. def get_nez_msg_counts(msg_counts: dict):
  24. """
  25. Filters out all msg_counts that have 0 as value
  26. """
  27. nez_msg_counts = dict()
  28. for msg in msg_counts.keys():
  29. count = msg_counts[msg]
  30. if count > 0:
  31. nez_msg_counts[msg] = count
  32. return nez_msg_counts
  33. def greater_than(a: float, b: float):
  34. """
  35. A greater than operator desgined to handle slight machine inprecision up to EPS_TOLERANCE.
  36. :return: True if a > b, otherwise False
  37. """
  38. return b - a < -EPS_TOLERANCE
  39. def change_msg_counts(msg_counts: dict, idx: int, add=True):
  40. """
  41. Changes the value of the message count of the message occuring in the packet specified by the given index.
  42. Adds 1 if add is True and subtracts 1 otherwise.
  43. """
  44. change = 1 if add else -1
  45. id_src, id_dst = packets[idx]["Src"], packets[idx]["Dst"]
  46. src_to_dst = "{0}-{1}".format(id_src, id_dst)
  47. dst_to_src = "{0}-{1}".format(id_dst, id_src)
  48. if src_to_dst in msg_counts.keys():
  49. msg_counts[src_to_dst] += change
  50. elif dst_to_src in msg_counts.keys():
  51. msg_counts[dst_to_src] += change
  52. elif add:
  53. msg_counts[src_to_dst] = 1
  54. def count_ids_in_msg_counts(msg_counts: dict):
  55. """
  56. Counts all ids that are involved in messages with a non zero message count
  57. """
  58. ids = set()
  59. for msg in msg_counts.keys():
  60. src, dst = msg.split("-")
  61. ids.add(dst)
  62. ids.add(src)
  63. return len(ids)
  64. def get_msg_count_first_ids(msg_counts: list):
  65. """
  66. Finds the ids that communicate among themselves the most with respect to the given message counts.
  67. :param msg_counts: a sorted list of message counts where each entry is a tuple of key and value
  68. :return: The picked ids and their total message count as a tuple
  69. """
  70. # if order of most messages is important, use an additional list
  71. picked_ids = set()
  72. total_msg_count = 0
  73. # iterate over every message count
  74. for i, msg in enumerate(msg_counts):
  75. count_picked_ids = len(picked_ids)
  76. id_one, id_two = msg[0].split("-")
  77. # if enough ids have been found, stop
  78. if count_picked_ids >= number_ids:
  79. break
  80. # if two ids can be added without exceeding the desired number of ids, add them
  81. if count_picked_ids - 2 <= number_ids:
  82. picked_ids.add(id_one)
  83. picked_ids.add(id_two)
  84. total_msg_count += msg[1]
  85. # if there is only room for one more id to be added,
  86. # find one that is already contained in the picked ids
  87. else:
  88. for j, msg in enumerate(msg_counts[i:]):
  89. id_one, id_two = msg[0].split("-")
  90. if id_one in picked_ids:
  91. picked_ids.add(id_two)
  92. total_msg_count += msg[1]
  93. break
  94. elif id_two in picked_ids:
  95. picked_ids.add(id_one)
  96. total_msg_count += msg[1]
  97. break
  98. break
  99. return picked_ids, total_msg_count
  100. def get_indv_id_counts_and_comms(picked_ids: dict, msg_counts: dict):
  101. """
  102. Retrieves the total mentions of one ID in the communication pattern
  103. and all communication entries that include only picked IDs.
  104. """
  105. indv_id_counts = {}
  106. id_comms = set()
  107. for msg in msg_counts:
  108. ids = msg.split("-")
  109. if ids[0] in picked_ids and ids[1] in picked_ids:
  110. msg_other_dir = "{}-{}".format(ids[1], ids[0])
  111. if (not msg in id_comms) and (not msg_other_dir in id_comms):
  112. id_comms.add(msg)
  113. for id_ in ids:
  114. if id_ in indv_id_counts:
  115. indv_id_counts[id_] += msg_counts[msg]
  116. else:
  117. indv_id_counts[id_] = msg_counts[msg]
  118. return indv_id_counts, id_comms
  119. # first find all possible intervals that contain enough IDs that communicate among themselves
  120. idx_low, idx_high = 0, 0
  121. msg_counts = dict()
  122. possible_intervals = []
  123. # Iterate over all packets from start to finish and process the info of each packet
  124. # If time of packet within time interval, update the message count for this communication
  125. # If time of packet exceeds time interval, substract from the message count for this communication
  126. while True:
  127. if idx_high < len(packets):
  128. cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
  129. # if current interval time exceeds time interval, save the message counts if appropriate, or stop if no more packets
  130. if greater_than(cur_int_time, max_int_time) or idx_high >= len(packets):
  131. # get all message counts for communications that took place in the current intervall
  132. nez_msg_counts = get_nez_msg_counts(msg_counts)
  133. # if we have enough ids as specified by the caller, mark as possible interval
  134. if count_ids_in_msg_counts(nez_msg_counts) >= number_ids:
  135. # possible_intervals.append((nez_msg_counts, packets[idx_low]["Time"], packets[idx_high-1]["Time"]))
  136. possible_intervals.append((nez_msg_counts, idx_low, idx_high - 1))
  137. if idx_high >= len(packets):
  138. break
  139. # let idx_low "catch up" so that the current interval time fits into the interval time specified by the caller
  140. while greater_than(cur_int_time, max_int_time):
  141. change_msg_counts(msg_counts, idx_low, add=False)
  142. idx_low += 1
  143. cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
  144. # consume the new packet at idx_high and process its information
  145. change_msg_counts(msg_counts, idx_high)
  146. idx_high += 1
  147. # now find the interval in which as many ids as specified communicate the most in the given time interval
  148. summed_intervals = []
  149. sum_intervals_idxs = []
  150. cur_highest_sum = 0
  151. # for every interval compute the sum of msg_counts of the first most communicative ids and eventually find
  152. # the interval(s) with most communication and its ids
  153. for j, interval in enumerate(possible_intervals):
  154. msg_counts = interval[0].items()
  155. sorted_msg_counts = sorted(msg_counts, key=lambda x: x[1], reverse=True)
  156. picked_ids, msg_sum = get_msg_count_first_ids(sorted_msg_counts)
  157. if msg_sum == cur_highest_sum:
  158. summed_intervals.append({"IDs": picked_ids, "MsgSum": msg_sum, "Start": interval[1], "End": interval[2]})
  159. sum_intervals_idxs.append(j)
  160. elif msg_sum > cur_highest_sum:
  161. summed_intervals = []
  162. sum_intervals_idxs = [j]
  163. summed_intervals.append({"IDs": picked_ids, "MsgSum": msg_sum, "Start": interval[1], "End": interval[2]})
  164. cur_highest_sum = msg_sum
  165. for j, interval in enumerate(summed_intervals):
  166. idx = sum_intervals_idxs[j]
  167. msg_counts_picked = possible_intervals[idx][0]
  168. indv_id_counts, id_comms = get_indv_id_counts_and_comms(interval["IDs"], msg_counts_picked)
  169. interval["IDs"] = indv_id_counts
  170. interval["Comms"] = id_comms
  171. return summed_intervals
  172. def det_ext_and_local_ids(self, comm_type: str, prob_init_local: int, prob_rspnd_local: int):
  173. init_ids, respnd_ids, both_ids = self.init_ids, self.respnd_ids, self.both_ids
  174. id_comms = self.id_comms
  175. external_ids = set()
  176. local_ids = set()
  177. def map_init_is_local(id_: int):
  178. for id_comm in id_comms:
  179. ids = id_comm.split("-")
  180. other = ids[0] if id_ == ids[1] else ids[1]
  181. # what if before other was external ...
  182. if other in local_ids or other in external_ids:
  183. continue
  184. if comm_type == "mixed":
  185. other_pos = mixed_respnd_is_local.random()
  186. if other_pos == "local":
  187. local_ids.add(other)
  188. elif other_pos == "external":
  189. external_ids.add(other)
  190. elif comm_type == "external":
  191. if not other in initiators:
  192. external_ids.add(other)
  193. def map_init_is_external(id_: int):
  194. for id_comm in id_comms:
  195. ids = id_comm.split("-")
  196. other = ids[0] if id_ == ids[1] else ids[1]
  197. # what if before other was external ...
  198. if other in local_ids or other in external_ids:
  199. continue
  200. if not other in initiators:
  201. local_ids.add(other)
  202. if comm_type == "local":
  203. local_ids = set(mapped_ids.keys())
  204. else:
  205. init_local_or_external = Lea.fromValFreqsDict({"local": prob_init_local*100, "external": (1-prob_init_local)*100})
  206. mixed_respnd_is_local = Lea.fromValFreqsDict({"local": prob_rspnd_local*100, "external": (1-prob_rspnd_local)*100})
  207. # assign IDs in 'both' local everytime for mixed?
  208. initiators = sorted(list(init_ids) + list(both_ids))
  209. initiators = sorted(initiators, key=lambda id_:self.indv_id_counts[id_], reverse=True)
  210. for id_ in initiators:
  211. pos = init_local_or_external.random()
  212. if pos == "local":
  213. if id_ in external_ids:
  214. map_init_is_external(id_)
  215. else:
  216. local_ids.add(id_)
  217. map_init_is_local(id_)
  218. elif pos == "external":
  219. if id_ in local_ids:
  220. map_init_is_local(id_)
  221. else:
  222. external_ids.add(id_)
  223. map_init_is_external(id_)
  224. self.local_ids, self.external_ids = local_ids, external_ids
  225. return local_ids, external_ids
  226. def det_id_roles_and_msgs(self, mtypes: dict):
  227. """
  228. Determine the role of every mapped ID. The role can be initiator, responder or both.
  229. :param packets: the mapped section of abstract packets
  230. :param all_ids: all IDs that were mapped/chosen
  231. :return: a dict that for every ID contains its role
  232. """
  233. init_ids, respnd_ids, both_ids = set(), set(), set()
  234. msgs, msg_id = [], 0
  235. prev_reqs = {}
  236. all_ids = self.ids
  237. packets = self.packets
  238. def process_initiator(id_: str):
  239. if id_ in both_ids:
  240. pass
  241. elif not id_ in respnd_ids:
  242. init_ids.add(id_)
  243. elif id_ in respnd_ids:
  244. respnd_ids.remove(id_)
  245. both_ids.add(id_)
  246. def process_responder(id_: str):
  247. if id_ in both_ids:
  248. pass
  249. elif not id_ in init_ids:
  250. respnd_ids.add(id_)
  251. elif id_ in init_ids:
  252. init_ids.remove(id_)
  253. both_ids.add(id_)
  254. for packet in packets:
  255. id_src, id_dst, msg_type, time = packet["Src"], packet["Dst"], int(packet["Type"]), float(packet["Time"])
  256. if (not id_src in all_ids) or (not id_dst in all_ids):
  257. continue
  258. msg_type = mtypes[msg_type]
  259. if msg_type in {MessageType.SALITY_HELLO, MessageType.SALITY_NL_REQUEST}:
  260. process_initiator(id_src)
  261. process_responder(id_dst)
  262. msg_str = "{0}-{1}".format(id_src, id_dst)
  263. msg = Message(msg_id, id_src, id_dst, msg_type, time)
  264. msgs.append(msg)
  265. prev_reqs[msg_str] = msg_id
  266. elif msg_type in {MessageType.SALITY_HELLO_REPLY, MessageType.SALITY_NL_REPLY}:
  267. process_initiator(id_dst)
  268. process_responder(id_src)
  269. msg_str = "{0}-{1}".format(id_dst, id_src)
  270. refer_idx = prev_reqs[msg_str]
  271. msgs[refer_idx].refer_msg_id = msg_id
  272. # print(msgs[refer_idx])
  273. msg = Message(msg_id, id_src, id_dst, msg_type, time, refer_idx)
  274. msgs.append(msg)
  275. del(prev_reqs[msg_str])
  276. if not msg_type == MessageType.TIMEOUT:
  277. msg_id += 1
  278. self.init_ids, self.respnd_ids, self.both_ids = init_ids, respnd_ids, both_ids
  279. self.messages = msgs
  280. return init_ids, respnd_ids, both_ids, msgs