123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- from lea import Lea
- from Attack.MembersMgmtCommAttack import MessageType
- from Attack.MembersMgmtCommAttack import Message
# Absolute tolerance used when comparing floating-point packet timestamps, so that rounding noise
# does not make an interval look longer than the allowed maximum.
EPS_TOLERANCE = 1e-13


class CommunicationProcessor:
    """
    Processes abstract packets (dict-like objects read with the keys "Src", "Dst", "Time" and "Type")
    to find intervals of intense communication, determine the roles of the communicating IDs and map
    those IDs to local or external locations.
    """

    def __init__(self, packets):
        """
        :param packets: the abstract packets to process
        """
        self.packets = packets

    def set_mapping(self, packets, mapped_ids, id_comms):
        """
        Store a selection of IDs so that the det_* methods can work on it.

        :param packets: the (section of) packets the selection refers to
        :param mapped_ids: dict of selected ID -> individual message count; this matches the shape of
                           the "IDs" entry returned by find_interval_with_most_comm
        :param id_comms: set of "<id>-<id>" strings describing the communications among the selected IDs
        """
        self.packets = packets
        self.ids = mapped_ids.keys()
        self.id_comms = id_comms
        self.indv_id_counts = mapped_ids

    def find_interval_with_most_comm(self, number_ids: int, max_int_time: float):
        """
        Find the time interval(s) of at most max_int_time seconds in which number_ids IDs communicate
        among themselves the most.

        A sliding window over self.packets records every window that cannot be extended by the next
        packet without exceeding max_int_time (plus the final window) and that involves at least
        number_ids IDs. In each recorded window the IDs are picked greedily from the most frequent
        communication pairs; the windows whose picked IDs have the highest message sum are returned.

        :param number_ids: the number of IDs that are to be picked
        :param max_int_time: the maximum duration of an interval in seconds
        :return: a list (empty if no window qualifies) with one dict per best window, holding
                 "IDs" (picked ID -> its message count inside the window), "MsgSum" (message count of
                 the communications used to pick the IDs), "Start"/"End" (inclusive indices into
                 self.packets) and "Comms" (set of "<id>-<id>" strings among the picked IDs)
        """
        packets = self.packets

        def get_nez_msg_counts(msg_counts: dict):
            """Return a copy of msg_counts without the entries whose count is not positive."""
            return {msg: count for msg, count in msg_counts.items() if count > 0}

        def greater_than(a: float, b: float):
            """
            A greater-than comparison that tolerates machine imprecision up to EPS_TOLERANCE.

            :return: True if a exceeds b by more than EPS_TOLERANCE, otherwise False
            """
            return b - a < -EPS_TOLERANCE

        def change_msg_counts(msg_counts: dict, idx: int, add=True):
            """
            Add 1 to (or, if add is False, subtract 1 from) the count of the communication of the
            packet at idx. Both directions of a pair share one entry, keyed by the direction that was
            counted first (entries are never removed, only decremented).
            """
            change = 1 if add else -1
            id_src, id_dst = packets[idx]["Src"], packets[idx]["Dst"]
            src_to_dst = "{0}-{1}".format(id_src, id_dst)
            dst_to_src = "{0}-{1}".format(id_dst, id_src)
            if src_to_dst in msg_counts:
                msg_counts[src_to_dst] += change
            elif dst_to_src in msg_counts:
                msg_counts[dst_to_src] += change
            elif add:
                msg_counts[src_to_dst] = 1

        def count_ids_in_msg_counts(msg_counts: dict):
            """Count the distinct IDs occurring in the keys of msg_counts."""
            ids = set()
            for msg in msg_counts:
                src, dst = msg.split("-")
                ids.add(dst)
                ids.add(src)
            return len(ids)

        def get_msg_count_first_ids(msg_counts: list):
            """
            Greedily pick at most number_ids IDs that communicate among themselves the most.

            :param msg_counts: list of (msg, count) tuples sorted by count in descending order
            :return: the picked IDs and the total count of the communications used to pick them
            """
            picked_ids = set()
            total_msg_count = 0

            for i, (msg, count) in enumerate(msg_counts):
                if len(picked_ids) >= number_ids:
                    break

                pair = msg.split("-")
                new_ids = set(pair) - picked_ids
                # take the whole pair if its new IDs do not exceed the desired number of IDs
                if len(picked_ids) + len(new_ids) <= number_ids:
                    picked_ids.update(pair)
                    total_msg_count += count
                else:
                    # only one more ID fits: take the most frequent remaining communication that
                    # connects an already picked ID with exactly one new ID
                    for other_msg, other_count in msg_counts[i:]:
                        id_one, id_two = other_msg.split("-")
                        if (id_one in picked_ids) != (id_two in picked_ids):
                            picked_ids.update((id_one, id_two))
                            total_msg_count += other_count
                            break
                    break
            return picked_ids, total_msg_count

        def get_indv_id_counts_and_comms(picked_ids: set, msg_counts: dict):
            """
            Sum up the message count of every picked ID over all communications that involve only
            picked IDs, and collect those communications.

            :return: a dict picked ID -> message count and a set of "<id>-<id>" communication strings
            """
            indv_id_counts = {}
            id_comms = set()
            for msg, count in msg_counts.items():
                ids = msg.split("-")
                if ids[0] in picked_ids and ids[1] in picked_ids:
                    msg_other_dir = "{}-{}".format(ids[1], ids[0])
                    if msg not in id_comms and msg_other_dir not in id_comms:
                        id_comms.add(msg)
                    for id_ in ids:
                        indv_id_counts[id_] = indv_id_counts.get(id_, 0) + count
            return indv_id_counts, id_comms

        # without packets there is no interval (and no timestamp to compare)
        if not packets:
            return []

        idx_low, idx_high = 0, 0
        msg_counts = {}
        possible_intervals = []
        cur_int_time = 0.0

        while True:
            at_end = idx_high >= len(packets)
            if not at_end:
                cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])

            # the window [idx_low, idx_high - 1] cannot grow any further: keep it if enough IDs take part
            if at_end or greater_than(cur_int_time, max_int_time):
                nez_msg_counts = get_nez_msg_counts(msg_counts)
                if count_ids_in_msg_counts(nez_msg_counts) >= number_ids:
                    possible_intervals.append((nez_msg_counts, idx_low, idx_high - 1))
                if at_end:
                    break

            # shrink the window from the left until the packet at idx_high fits into it
            while greater_than(cur_int_time, max_int_time):
                change_msg_counts(msg_counts, idx_low, add=False)
                idx_low += 1
                cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])

            change_msg_counts(msg_counts, idx_high)
            idx_high += 1

        # rank the windows by the message sum of their greedily picked IDs; keep all ties
        summed_intervals = []
        sum_intervals_idxs = []
        cur_highest_sum = 0
        for j, (interval_counts, start, end) in enumerate(possible_intervals):
            sorted_msg_counts = sorted(interval_counts.items(), key=lambda x: x[1], reverse=True)
            picked_ids, msg_sum = get_msg_count_first_ids(sorted_msg_counts)
            if msg_sum < cur_highest_sum:
                continue
            if msg_sum > cur_highest_sum:
                summed_intervals, sum_intervals_idxs = [], []
                cur_highest_sum = msg_sum
            summed_intervals.append({"IDs": picked_ids, "MsgSum": msg_sum, "Start": start, "End": end})
            sum_intervals_idxs.append(j)

        # replace the picked ID sets by individual message counts and attach the communications
        for interval, idx in zip(summed_intervals, sum_intervals_idxs):
            msg_counts_picked = possible_intervals[idx][0]
            indv_id_counts, id_comms = get_indv_id_counts_and_comms(interval["IDs"], msg_counts_picked)
            interval["IDs"] = indv_id_counts
            interval["Comms"] = id_comms
        return summed_intervals

    def det_ext_and_local_ids(self, comm_type: str, prob_init_local: float, prob_rspnd_local: float):
        """
        Map the selected IDs to a location, local or external, according to the communication type.

        Reads self.ids, self.id_comms and self.indv_id_counts (set by set_mapping) as well as
        self.init_ids and self.both_ids (set by det_id_roles_and_msgs).

        :param comm_type: "local" maps every selected ID as local; otherwise initiators are mapped
                          randomly and their partners are handled as described in the nested helpers
                          ("mixed" and "external" get special treatment)
        :param prob_init_local: probability (0 to 1) that an initiator is mapped as local
        :param prob_rspnd_local: probability (0 to 1) that a partner of a local initiator is mapped as
                                 local when comm_type is "mixed"
        :return: the set of local IDs and the set of external IDs (also stored as self.local_ids and
                 self.external_ids)
        """
        init_ids, both_ids = self.init_ids, self.both_ids
        id_comms = self.id_comms
        external_ids = set()
        local_ids = set()

        def communication_partners(id_: str):
            """Yield the partner of id_ for every communication in id_comms that involves id_."""
            for id_comm in id_comms:
                ids = id_comm.split("-")
                if id_ == ids[0]:
                    yield ids[1]
                elif id_ == ids[1]:
                    yield ids[0]

        def map_init_is_local(id_: str):
            """Handle the partners of an initiator that is mapped as local."""
            for other in communication_partners(id_):
                # partners that already have a location keep it
                if other in local_ids or other in external_ids:
                    continue
                if comm_type == "mixed":
                    # in mixed communication a partner may be local or external
                    other_pos = mixed_respnd_is_local.random()
                    if other_pos == "local":
                        local_ids.add(other)
                    elif other_pos == "external":
                        external_ids.add(other)
                elif comm_type == "external":
                    # the partner of a local initiator must be external to keep the communication
                    # external; initiators are skipped since the main loop assigns their location
                    if other not in initiators:
                        external_ids.add(other)

        def map_init_is_external(id_: str):
            """Handle the partners of an initiator that is mapped as external."""
            for other in communication_partners(id_):
                if other in local_ids or other in external_ids:
                    continue
                if other not in initiators:
                    local_ids.add(other)

        if comm_type == "local":
            local_ids = set(self.ids)
        else:
            init_local_or_external = Lea.fromValFreqsDict(
                {"local": prob_init_local * 100, "external": (1 - prob_init_local) * 100})
            mixed_respnd_is_local = Lea.fromValFreqsDict(
                {"local": prob_rspnd_local * 100, "external": (1 - prob_rspnd_local) * 100})

            # sort by ID first so that equal counts keep a deterministic order (the sort is stable),
            # then by message count so the most active initiators are mapped first
            initiators = sorted(list(init_ids) + list(both_ids))
            initiators = sorted(initiators, key=lambda id_: self.indv_id_counts[id_], reverse=True)

            for id_ in initiators:
                pos = init_local_or_external.random()
                if pos == "local":
                    if id_ in external_ids:
                        # already mapped as external via an earlier partner: keep it, map its partners
                        map_init_is_external(id_)
                    else:
                        local_ids.add(id_)
                        map_init_is_local(id_)
                elif pos == "external":
                    if id_ in local_ids:
                        # already mapped as local via an earlier partner: keep it, map its partners
                        map_init_is_local(id_)
                    else:
                        external_ids.add(id_)
                        map_init_is_external(id_)

        self.local_ids, self.external_ids = local_ids, external_ids
        return local_ids, external_ids

    def det_id_roles_and_msgs(self, mtypes: dict):
        """
        Determine the role of every mapped ID and convert the packets exchanged between mapped IDs
        into Message objects.

        An ID that sends requests (or receives replies) acts as initiator, an ID that receives requests
        (or sends replies) acts as responder; an ID that does both ends up in both_ids only.

        :param mtypes: dict mapping the integer "Type" value of a packet to its MessageType
        :return: the sets of initiator, responder and both IDs and the list of messages; all four are
                 also stored on the instance (init_ids, respnd_ids, both_ids, messages)
        """
        init_ids, respnd_ids, both_ids = set(), set(), set()
        msgs = []
        prev_reqs = {}
        all_ids = self.ids
        packets = self.packets

        def process_initiator(id_: str):
            """Record that id_ acted as initiator."""
            if id_ in both_ids:
                return
            if id_ in respnd_ids:
                respnd_ids.remove(id_)
                both_ids.add(id_)
            else:
                init_ids.add(id_)

        def process_responder(id_: str):
            """Record that id_ acted as responder."""
            if id_ in both_ids:
                return
            if id_ in init_ids:
                init_ids.remove(id_)
                both_ids.add(id_)
            else:
                respnd_ids.add(id_)

        for packet in packets:
            id_src, id_dst, msg_type, time = packet["Src"], packet["Dst"], int(packet["Type"]), float(packet["Time"])
            if id_src not in all_ids or id_dst not in all_ids:
                continue
            msg_type = mtypes[msg_type]
            # a message's ID is its index in msgs, so that replies can refer back to their request
            msg_id = len(msgs)
            if msg_type in {MessageType.SALITY_HELLO, MessageType.SALITY_NL_REQUEST}:
                process_initiator(id_src)
                process_responder(id_dst)
                msgs.append(Message(msg_id, id_src, id_dst, msg_type, time))
                # remember the latest request per direction until its reply arrives
                prev_reqs["{0}-{1}".format(id_src, id_dst)] = msg_id
            elif msg_type in {MessageType.SALITY_HELLO_REPLY, MessageType.SALITY_NL_REPLY}:
                process_initiator(id_dst)
                process_responder(id_src)
                refer_idx = prev_reqs.pop("{0}-{1}".format(id_dst, id_src), None)
                if refer_idx is None:
                    # the matching request is not among these packets (presumably it precedes the
                    # mapped section); a reply cannot reference it, so it is skipped
                    continue
                msgs[refer_idx].refer_msg_id = msg_id
                msgs.append(Message(msg_id, id_src, id_dst, msg_type, time, refer_idx))

        self.init_ids, self.respnd_ids, self.both_ids = init_ids, respnd_ids, both_ids
        self.messages = msgs
        return init_ids, respnd_ids, both_ids, msgs
|