# CommunicationProcessor.py
  1. from lea import Lea
  2. from Attack.MembersMgmtCommAttack import MessageType
  3. from Attack.MembersMgmtCommAttack import Message
  4. # needed because of machine inprecision. E.g A time difference of 0.1s is stored as >0.1s
  5. EPS_TOLERANCE = 1e-13 # works for a difference of 0.1, no less
  6. ######### TODO: WIKI ADD VALUE RANGES ##########
  7. class CommunicationProcessor():
  8. """
  9. Class to process parsed input CSV/XML data and retrieve a mapping or other information.
  10. """
  11. def __init__(self, packets:list, mtypes:dict, nat:bool):
  12. self.packets = packets
  13. self.mtypes = mtypes
  14. self.nat = nat
  15. def set_mapping(self, packets: list, mapped_ids: dict):
  16. """
  17. Set the selected mapping for this communication processor.
  18. :param packets: all packets contained in the mapped time frame
  19. :param mapped_ids: the chosen IDs
  20. """
  21. self.packets = packets
  22. self.local_init_ids = set(mapped_ids.keys())
  23. def find_interval_most_comm(self, number_ids: int, max_int_time: float):
  24. """
  25. Finds a time interval of the given seconds where the given number of IDs commuicate the most.
  26. If NAT is active, the most communication is restricted to the most communication by the given number of initiating IDs.
  27. If NAT is inactive, the intervall the most overall communication, that has at least the given number of initiating IDs in it, is chosen.
  28. :param number_ids: The number of IDs that are to be considered
  29. :param max_int_time: A short description of the attack.
  30. :return: A triple consisting of the IDs, as well as start and end idx with respect to the given packets.
  31. """
  32. packets = self.packets
  33. mtypes = self.mtypes
  34. def get_nez_comm_counts(comm_counts: dict):
  35. """
  36. Filters out all msg_counts that have 0 as value
  37. """
  38. nez_comm_counts = dict()
  39. for id_ in comm_counts.keys():
  40. count = comm_counts[id_]
  41. if count > 0:
  42. nez_comm_counts[id_] = count
  43. return nez_comm_counts
  44. def greater_than(a: float, b: float):
  45. """
  46. A greater than operator desgined to handle slight machine inprecision up to EPS_TOLERANCE.
  47. :return: True if a > b, otherwise False
  48. """
  49. return b - a < -EPS_TOLERANCE
  50. def change_comm_counts(comm_counts: dict, idx: int, add=True):
  51. """
  52. Changes the communication count, stored in comm_counts, of the initiating ID with respect to the
  53. packet specified by the given index. If add is True, 1 is added to the value, otherwise 1 is subtracted.
  54. """
  55. change = 1 if add else -1
  56. mtype = mtypes[int(packets[idx]["Type"])]
  57. id_src, id_dst = packets[idx]["Src"], packets[idx]["Dst"]
  58. if mtype in {MessageType.SALITY_HELLO, MessageType.SALITY_NL_REQUEST}:
  59. if id_src in comm_counts:
  60. comm_counts[id_src] += change
  61. elif change > 0:
  62. comm_counts[id_src] = 1
  63. elif mtype in {MessageType.SALITY_HELLO_REPLY, MessageType.SALITY_NL_REPLY}:
  64. if id_dst in comm_counts:
  65. comm_counts[id_dst] += change
  66. elif change > 0:
  67. comm_counts[id_dst] = 1
  68. def get_comm_count_first_ids(comm_counts: list):
  69. """
  70. Finds the IDs that communicate among themselves the most with respect to the given message counts.
  71. :param msg_counts: a sorted list of message counts where each entry is a tuple of key and value
  72. :return: The picked IDs and their total message count as a tuple
  73. """
  74. # if order of most messages is important, use an additional list
  75. picked_ids = {}
  76. total_comm_count = 0
  77. # iterate over every message count
  78. for i, comm in enumerate(comm_counts):
  79. count_picked_ids = len(picked_ids)
  80. # if enough IDs have been found, stop
  81. if count_picked_ids >= number_ids:
  82. break
  83. picked_ids[comm[0]] = comm[1]
  84. total_comm_count += comm[1]
  85. return picked_ids, total_comm_count
  86. # first find all possible intervals that contain enough IDs that initiate communication
  87. idx_low, idx_high = 0, 0
  88. comm_counts = {}
  89. possible_intervals = []
  90. general_comm_sum, cur_highest_sum = 0, 0
  91. # Iterate over all packets from start to finish and process the info of each packet
  92. # If time of packet within time interval, update the message count for this communication
  93. # If time of packet exceeds time interval, substract from the message count for this communication
  94. # Similar to a Sliding Window approach
  95. while True:
  96. if idx_high < len(packets):
  97. cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
  98. # if current interval time exceeds time interval, save the message counts if appropriate, or stop if no more packets
  99. if greater_than(cur_int_time, max_int_time) or idx_high >= len(packets):
  100. # get all message counts for communications that took place in the current intervall
  101. nez_comm_counts = get_nez_comm_counts(comm_counts)
  102. # if we have enough IDs as specified by the caller, mark as possible interval
  103. if len(nez_comm_counts) >= number_ids:
  104. if self.nat:
  105. possible_intervals.append((nez_comm_counts, idx_low, idx_high-1))
  106. elif general_comm_sum >= cur_highest_sum:
  107. cur_highest_sum = general_comm_sum
  108. possible_intervals.append({"IDs": nez_comm_counts, "CommSum": general_comm_sum, "Start": idx_low, "End": idx_high-1})
  109. general_comm_sum = 0
  110. if idx_high >= len(packets):
  111. break
  112. # let idx_low "catch up" so that the current interval time fits into the interval time specified by the caller
  113. while greater_than(cur_int_time, max_int_time):
  114. change_comm_counts(comm_counts, idx_low, add=False)
  115. idx_low += 1
  116. cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
  117. # consume the new packet at idx_high and process its information
  118. change_comm_counts(comm_counts, idx_high)
  119. idx_high += 1
  120. general_comm_sum += 1
  121. if self.nat:
  122. # now find the interval in which as many IDs as specified communicate the most in the given time interval
  123. summed_intervals = []
  124. sum_intervals_idxs = []
  125. cur_highest_sum = 0
  126. # for every interval compute the sum of id_counts of the first most communicative IDs and eventually find
  127. # the interval(s) with most communication and its IDs
  128. # on the side also store the communication count of the individual IDs
  129. for j, interval in enumerate(possible_intervals):
  130. comm_counts = interval[0].items()
  131. sorted_comm_counts = sorted(comm_counts, key=lambda x: x[1], reverse=True)
  132. picked_ids, comm_sum = get_comm_count_first_ids(sorted_comm_counts)
  133. if comm_sum == cur_highest_sum:
  134. summed_intervals.append({"IDs": picked_ids, "CommSum": comm_sum, "Start": interval[1], "End": interval[2]})
  135. elif comm_sum > cur_highest_sum:
  136. summed_intervals = []
  137. summed_intervals.append({"IDs": picked_ids, "CommSum": comm_sum, "Start": interval[1], "End": interval[2]})
  138. cur_highest_sum = comm_sum
  139. return summed_intervals
  140. else:
  141. return possible_intervals
  142. def det_id_roles_and_msgs(self):
  143. """
  144. Determine the role of every mapped ID. The role can be initiator, responder or both.
  145. On the side also connect corresponding messages together to quickly find out
  146. which reply belongs to which request and vice versa.
  147. :return: a triple as (initiator IDs, responder IDs, messages)
  148. """
  149. mtypes = self.mtypes
  150. # setup initial variables and their values
  151. respnd_ids = set()
  152. # msgs --> the filtered messages, msg_id --> an increasing ID to give every message an artificial primary key
  153. msgs, msg_id = [], 0
  154. # keep track of previous request to find connections
  155. prev_reqs = {}
  156. local_init_ids = self.local_init_ids
  157. external_init_ids = set()
  158. # process every packet individually
  159. for packet in self.packets:
  160. id_src, id_dst, msg_type, time = packet["Src"], packet["Dst"], int(packet["Type"]), float(packet["Time"])
  161. # if if either one of the IDs is not mapped, continue
  162. if (id_src not in local_init_ids) and (id_dst not in local_init_ids):
  163. continue
  164. # convert message type number to enum type
  165. msg_type = mtypes[msg_type]
  166. # process a request
  167. if msg_type in {MessageType.SALITY_HELLO, MessageType.SALITY_NL_REQUEST}:
  168. if not self.nat and id_dst in local_init_ids and id_src not in local_init_ids:
  169. external_init_ids.add(id_src)
  170. elif id_src not in local_init_ids:
  171. continue
  172. else:
  173. # process ID's role
  174. respnd_ids.add(id_dst)
  175. # convert the abstract message into a message object to handle it better
  176. msg_str = "{0}-{1}".format(id_src, id_dst)
  177. msg = Message(msg_id, id_src, id_dst, msg_type, time)
  178. msgs.append(msg)
  179. prev_reqs[msg_str] = msg_id
  180. msg_id += 1
  181. # process a reply
  182. elif msg_type in {MessageType.SALITY_HELLO_REPLY, MessageType.SALITY_NL_REPLY}:
  183. if not self.nat and id_src in local_init_ids and id_dst not in local_init_ids:
  184. # process ID's role
  185. external_init_ids.add(id_dst)
  186. elif id_dst not in local_init_ids:
  187. continue
  188. else:
  189. # process ID's role
  190. respnd_ids.add(id_src)
  191. # convert the abstract message into a message object to handle it better
  192. msg_str = "{0}-{1}".format(id_dst, id_src)
  193. # find the request message ID for this response and set its reference index
  194. refer_idx = prev_reqs[msg_str]
  195. msgs[refer_idx].refer_msg_id = msg_id
  196. msg = Message(msg_id, id_src, id_dst, msg_type, time, refer_idx)
  197. msgs.append(msg)
  198. # remove the request to this response from storage
  199. del(prev_reqs[msg_str])
  200. msg_id += 1
  201. # store the retrieved information in this object for later use
  202. self.respnd_ids = sorted(respnd_ids)
  203. self.external_init_ids = sorted(external_init_ids)
  204. self.messages = msgs
  205. # return the retrieved information
  206. return self.local_init_ids, self.external_init_ids, self.respnd_ids, self.messages
  207. def det_ext_and_local_ids(self, prob_rspnd_local: int):
  208. """
  209. Map the given IDs to a locality (i.e. local or external} considering the given probabilities.
  210. :param comm_type: the type of communication (i.e. local, external or mixed)
  211. :param prob_rspnd_local: the probabilty that a responder is local
  212. """
  213. external_ids = set()
  214. local_ids = self.local_init_ids.copy()
  215. # set up probabilistic chooser
  216. rspnd_locality = Lea.fromValFreqsDict({"local": prob_rspnd_local*100, "external": (1-prob_rspnd_local)*100})
  217. for id_ in self.external_init_ids:
  218. external_ids.add(id_)
  219. # determine responder localities
  220. for id_ in self.respnd_ids:
  221. if id_ in local_ids or id_ in external_ids:
  222. continue
  223. pos = rspnd_locality.random()
  224. if pos == "local":
  225. local_ids.add(id_)
  226. elif pos == "external":
  227. external_ids.add(id_)
  228. self.local_ids, self.external_ids = local_ids, external_ids
  229. return self.local_ids, self.external_ids