Browse Source

Improve documentation and split interval mapping function in two (for NAT and no NAT).

dustin.born 7 years ago
parent
commit
684ce05166

+ 20 - 36
code/Attack/MembersMgmtCommAttack.py

@@ -1,4 +1,13 @@
 from enum import Enum
 from enum import Enum
+from random import randint, randrange, choice, uniform
+from collections import deque
+from scipy.stats import gamma
+from lea import Lea
+from datetime import datetime
+
+from Attack import BaseAttack
+from Attack.AttackParameters import Parameter as Param
+from Attack.AttackParameters import ParameterTypes
 
 
 class MessageType(Enum):
 class MessageType(Enum):
     """
     """
@@ -11,6 +20,12 @@ class MessageType(Enum):
     SALITY_HELLO = 103
     SALITY_HELLO = 103
     SALITY_HELLO_REPLY = 104
     SALITY_HELLO_REPLY = 104
 
 
+    def is_request(mtype):
+        return mtype in {MessageType.SALITY_HELLO, MessageType.SALITY_NL_REQUEST}
+
+    def is_response(mtype):
+        return mtype in {MessageType.SALITY_HELLO_REPLY, MessageType.SALITY_NL_REPLY}
+
 class Message():
 class Message():
     """
     """
     Defines a compact message type that contains all necessary information.
     Defines a compact message type that contains all necessary information.
@@ -37,15 +52,6 @@ class Message():
         str_ = "{0}. at {1}: {2}-->{3}, {4}, refer:{5}".format(self.msg_id, self.time, self.src, self.dst, self.type, self.refer_msg_id)
         str_ = "{0}. at {1}: {2}-->{3}, {4}, refer:{5}".format(self.msg_id, self.time, self.src, self.dst, self.type, self.refer_msg_id)
         return str_
         return str_
 
 
-from random import randint, randrange, choice, uniform
-from collections import deque
-from scipy.stats import gamma
-from lea import Lea
-from datetime import datetime
-
-from Attack import BaseAttack
-from Attack.AttackParameters import Parameter as Param
-from Attack.AttackParameters import ParameterTypes
 
 
 from ID2TLib import FileUtils, PaddingGenerator
 from ID2TLib import FileUtils, PaddingGenerator
 from ID2TLib.PacketGenerator import PacketGenerator
 from ID2TLib.PacketGenerator import PacketGenerator
@@ -115,15 +121,12 @@ class MembersMgmtCommAttack(BaseAttack.BaseAttack):
         # PARAMETERS: initialize with default values
         # PARAMETERS: initialize with default values
         # (values are overwritten if user specifies them)
         # (values are overwritten if user specifies them)
         self.add_param_value(Param.INJECT_AFTER_PACKET, randint(1, int(self.statistics.get_packet_count()/5)))
         self.add_param_value(Param.INJECT_AFTER_PACKET, randint(1, int(self.statistics.get_packet_count()/5)))
-        #self.add_param_value(Param.INJECT_AFTER_PACKET, 1)
 
 
         self.add_param_value(Param.PACKETS_PER_SECOND, 0)
         self.add_param_value(Param.PACKETS_PER_SECOND, 0)
         self.add_param_value(Param.FILE_XML, self.DEFAULT_XML_PATH)
         self.add_param_value(Param.FILE_XML, self.DEFAULT_XML_PATH)
 
 
         # Alternatively new attack parameter?
         # Alternatively new attack parameter?
         duration = int(float(self._get_capture_duration()))
         duration = int(float(self._get_capture_duration()))
-        # if duration == 0:
-        #     duration = 1
         self.add_param_value(Param.ATTACK_DURATION, duration)
         self.add_param_value(Param.ATTACK_DURATION, duration)
         self.add_param_value(Param.NUMBER_INITIATOR_BOTS, 1)
         self.add_param_value(Param.NUMBER_INITIATOR_BOTS, 1)
         # NAT on by default
         # NAT on by default
@@ -263,25 +266,6 @@ class MembersMgmtCommAttack(BaseAttack.BaseAttack):
             else:
             else:
                 return 0
                 return 0
 
 
-        def assign_realistic_ttls2(bot_configs):
-            # Gamma distribution parameters derived from MAWI 13.8G dataset
-            ids = sorted(bot_configs.keys())
-            alpha, loc, beta = (2.3261710235, -0.188306914406, 44.4853123884)
-            gd = gamma.rvs(alpha, loc=loc, scale=beta, size=len(ids))
-
-            for pos, bot in enumerate(ids):
-                # print(bot)
-                is_invalid = True
-                pos_max = len(gd)
-                while is_invalid:
-                    ttl = int(round(gd[pos]))
-                    if 0 < ttl < 256:  # validity check
-                        is_invalid = False
-                    else:
-                        pos = index_increment(pos, pos_max)
-                bot_configs[bot]["TTL"] = ttl
-
-
         def assign_realistic_ttls(bot_configs):
         def assign_realistic_ttls(bot_configs):
             '''
             '''
             Assigns a realisitic ttl to each bot from @param: bot_configs. Uses statistics and distribution to be able
             Assigns a realisitic ttl to each bot from @param: bot_configs. Uses statistics and distribution to be able
@@ -325,6 +309,7 @@ class MembersMgmtCommAttack(BaseAttack.BaseAttack):
             unique_offset = uniform(-0.1*general_offset, 0.1*general_offset)
             unique_offset = uniform(-0.1*general_offset, 0.1*general_offset)
             return timestamp + minDelay + general_offset + unique_offset
             return timestamp + minDelay + general_offset + unique_offset
 
 
+
         # parse input CSV or XML
         # parse input CSV or XML
         filepath_xml = self.get_param_value(Param.FILE_XML)
         filepath_xml = self.get_param_value(Param.FILE_XML)
         filepath_csv = self.get_param_value(Param.FILE_CSV)
         filepath_csv = self.get_param_value(Param.FILE_CSV)
@@ -351,14 +336,14 @@ class MembersMgmtCommAttack(BaseAttack.BaseAttack):
         mapped_ids, packet_start_idx, packet_end_idx = comm_interval["IDs"], comm_interval["Start"], comm_interval["End"]
         mapped_ids, packet_start_idx, packet_end_idx = comm_interval["IDs"], comm_interval["Start"], comm_interval["End"]
         while len(mapped_ids) > number_init_bots:
         while len(mapped_ids) > number_init_bots:
             rm_idx = randrange(0, len(mapped_ids))
             rm_idx = randrange(0, len(mapped_ids))
-            del(mapped_ids[sorted(mapped_ids)[rm_idx]])
+            del mapped_ids[rm_idx]
 
 
         # assign the communication processor this mapping for further processing
         # assign the communication processor this mapping for further processing
         comm_proc.set_mapping(abstract_packets[packet_start_idx:packet_end_idx+1], mapped_ids)
         comm_proc.set_mapping(abstract_packets[packet_start_idx:packet_end_idx+1], mapped_ids)
         # print start and end time of mapped interval
         # print start and end time of mapped interval
-        #print(abstract_packets[packet_start_idx]["Time"])
-        #print(abstract_packets[packet_end_idx]["Time"])
-        #print(mapped_ids.keys())
+        # print(abstract_packets[packet_start_idx]["Time"])
+        # print(abstract_packets[packet_end_idx]["Time"])
+        # print(mapped_ids)
 
 
         # determine number of reused local and external IPs
         # determine number of reused local and external IPs
         reuse_percent_total = self.get_param_value(Param.IP_REUSE_TOTAL)
         reuse_percent_total = self.get_param_value(Param.IP_REUSE_TOTAL)
@@ -396,7 +381,6 @@ class MembersMgmtCommAttack(BaseAttack.BaseAttack):
             add_ids_to_config(sorted(external_ids), existing_external_ips, new_external_ips, bot_configs, idtype="external", router_mac=router_mac)
             add_ids_to_config(sorted(external_ids), existing_external_ips, new_external_ips, bot_configs, idtype="external", router_mac=router_mac)
 
 
         #### Set realistic timestamps for messages ####
         #### Set realistic timestamps for messages ####
-
         most_used_ip_address = self.statistics.get_most_used_ip_address()
         most_used_ip_address = self.statistics.get_most_used_ip_address()
         minDelay = self.get_reply_delay(most_used_ip_address)[0]
         minDelay = self.get_reply_delay(most_used_ip_address)[0]
         next_timestamp = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
         next_timestamp = self.get_param_value(Param.INJECT_AT_TIMESTAMP)

+ 174 - 98
code/ID2TLib/CommunicationProcessor.py

@@ -5,6 +5,12 @@ from Attack.MembersMgmtCommAttack import Message
 # needed because of machine inprecision. E.g A time difference of 0.1s is stored as >0.1s
 # needed because of machine inprecision. E.g A time difference of 0.1s is stored as >0.1s
 EPS_TOLERANCE = 1e-13  # works for a difference of 0.1, no less
 EPS_TOLERANCE = 1e-13  # works for a difference of 0.1, no less
 
 
+def greater_than(a: float, b: float):
+    """
+    A greater-than operator designed to handle slight machine imprecision up to EPS_TOLERANCE.
+    :return: True if a > b, otherwise False
+    """
+    return b - a < -EPS_TOLERANCE
 
 
 class CommunicationProcessor():
 class CommunicationProcessor():
     """
     """
@@ -24,146 +30,216 @@ class CommunicationProcessor():
         :param mapped_ids: the chosen IDs
         :param mapped_ids: the chosen IDs
         """
         """
         self.packets = packets
         self.packets = packets
-        self.local_init_ids = set(mapped_ids.keys())
+        self.local_init_ids = set(mapped_ids)
 
 
     def find_interval_most_comm(self, number_ids: int, max_int_time: float):
     def find_interval_most_comm(self, number_ids: int, max_int_time: float):
+        if self.nat:
+            return self._find_interval_most_comm_nat(number_ids, max_int_time)
+        else:
+            return self._find_interval_most_comm_nonat(number_ids, max_int_time)
+
+
+    def _find_interval_most_comm_nonat(self, number_ids: int, max_int_time: float):
         """
         """
-        Finds a time interval of the given seconds where the given number of IDs commuicate the most.
-        If NAT is active, the most communication is restricted to the most communication by the given number of initiating IDs.
-        If NAT is inactive, the intervall the most overall communication, that has at least the given number of initiating IDs in it, is chosen.
-        
-        :param number_ids: The number of IDs that are to be considered
-        :param max_int_time: A short description of the attack.
-        :return: A triple consisting of the IDs, as well as start and end idx with respect to the given packets. 
+        Finds the time interval(s) of the given length (in seconds) with the most overall communication
+        (i.e. requests and responses) that contain at least number_ids communication initiators.
+        :param number_ids: The number of initiator IDs that have to exist in the interval(s)
+        :param max_int_time: The maximum time period of the interval
+        :return: A list of dicts, where each dict contains the initiator IDs ("IDs"), the start index ("Start")
+                 and the end index ("End") of the respective interval. The indices are with respect to self.packets
         """
         """
+
+        # setup initial variables
         packets = self.packets
         packets = self.packets
         mtypes = self.mtypes
         mtypes = self.mtypes
+        idx_low, idx_high = 0, 0  # the indices spanning the interval
+        comm_sum = 0  # the communication sum of the current interval
+        cur_highest_sum = 0  # the highest communication sum seen so far
+        ids = []  # the initiator IDs seen in the current interval in order of appearance
+        possible_intervals = []  # all intervals that have cur_highest_sum of communication and contain enough IDs
+
+        # Iterate over all packets from start to finish and process the info of each packet.
+        # Similar to a Sliding Window approach.
+        while True:
+            if idx_high < len(packets):
+                cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
+     
+            # if current interval time exceeds maximum time period, process information of the current interval
+            if greater_than(cur_int_time, max_int_time) or idx_high >= len(packets):
+                interval_ids = set(ids)
+                # if the interval contains enough initiator IDs, add it to possible_intervals
+                if len(interval_ids) >= number_ids:
+                    interval = {"IDs": sorted(interval_ids), "Start": idx_low, "End": idx_high-1}
+                    # reset possible intervals if new maximum of communication is found
+                    if comm_sum > cur_highest_sum:
+                        possible_intervals = [interval]
+                        cur_highest_sum = comm_sum
+                    # append otherwise
+                    elif comm_sum == cur_highest_sum:
+                        possible_intervals.append(interval)
+
+                # stop if all packets have been processed
+                if idx_high >= len(packets):
+                    break
 
 
-        def get_nez_comm_counts(comm_counts: dict):
-            """
-            Filters out all msg_counts that have 0 as value
-            """
-            nez_comm_counts = dict()
-            for id_ in comm_counts.keys():
-                count = comm_counts[id_]
-                if count > 0:
-                    nez_comm_counts[id_] = count
-            return nez_comm_counts
-
-        def greater_than(a: float, b: float):
+            # let idx_low "catch up" so that the current interval time fits into the maximum time period again
+            while greater_than(cur_int_time, max_int_time):
+                cur_packet = packets[idx_low]
+                # if the message was not a timeout, delete the first appearance of this packet's
+                # initiator ID from the initiator list and update comm_sum
+                if mtypes[int(cur_packet["Type"])] != MessageType.TIMEOUT:
+                    comm_sum -= 1
+                    del ids[0]
+
+                idx_low += 1
+                cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
+
+            # consume the new packet at idx_high and process its information
+            cur_packet = packets[idx_high]
+            cur_mtype = mtypes[int(cur_packet["Type"])]
+            # if message is request, add src to initiator list
+            if MessageType.is_request(cur_mtype):
+                ids.append(cur_packet["Src"])
+                comm_sum += 1
+            # if message is response, add dst to initiator list
+            elif MessageType.is_response(cur_mtype):
+                ids.append(cur_packet["Dst"])
+                comm_sum += 1
+
+            idx_high += 1
+
+        return possible_intervals
+
+
+    def _find_interval_most_comm_nat(self, number_ids: int, max_int_time: float):
+        """
+        Finds the time interval(s) of the given length (in seconds) with the most communication
+        (i.e. requests and responses) carried out by the number_ids most communicative initiator IDs of the interval.
+        :param number_ids: The number of initiator IDs that have to exist in the interval(s)
+        :param max_int_time: The maximum time period of the interval
+        :return: A list of dicts, where each dict contains the initiator IDs ("IDs"), the start index ("Start")
+                 and the end index ("End") of the respective interval. The indices are with respect to self.packets
+        """
+
+        def get_nez_comm_amounts():
             """
             """
-            A greater than operator desgined to handle slight machine inprecision up to EPS_TOLERANCE.
-            :return: True if a > b, otherwise False
+            Filters out all comm_amounts that have 0 as value.
+
+            :return: a dict with initiator IDs as keys and their non-zero communication amount as value
             """
             """
-            return b - a < -EPS_TOLERANCE
 
 
-        def change_comm_counts(comm_counts: dict, idx: int, add=True):
+            nez_comm_amounts = dict()
+            # Iterate over comm_amounts dict and add every entry
+            # with non-zero comm_amount to new dict.
+            for id_ in comm_amounts:
+                amount = comm_amounts[id_]
+                if amount > 0:
+                    nez_comm_amounts[id_] = amount
+            return nez_comm_amounts
+
+        def change_comm_amounts(packet: dict, add:bool=True):
             """
             """
-            Changes the communication count, stored in comm_counts, of the initiating ID with respect to the
-            packet specified by the given index. If add is True, 1 is added to the value, otherwise 1 is subtracted.
+            Changes the communication amount, stored in comm_amounts, of the initiating ID with respect to the
+            packet specified by the given index.
+
+            :param packet: the packet to be processed, containing src and dst ID
+            :param add: If add is True, 1 is added to the communication amount of the IDs, otherwise 1 is subtracted
             """
             """
+
             change = 1 if add else -1
             change = 1 if add else -1
-            mtype = mtypes[int(packets[idx]["Type"])]
-            id_src, id_dst = packets[idx]["Src"], packets[idx]["Dst"]
-            if mtype in {MessageType.SALITY_HELLO, MessageType.SALITY_NL_REQUEST}:
-                if id_src in comm_counts:
-                    comm_counts[id_src] += change
+            mtype = mtypes[int(packet["Type"])]
+            id_src, id_dst = packet["Src"], packet["Dst"]
+            # if message is request, src is initiator
+            if MessageType.is_request(mtype):
+                # if src exists in comm_amounts, add 1 to its amount
+                if id_src in comm_amounts:
+                    comm_amounts[id_src] += change
+                # else if op is add, add the ID with comm value 1 to comm_amounts
                 elif change > 0:
                 elif change > 0:
-                    comm_counts[id_src] = 1
-            elif mtype in {MessageType.SALITY_HELLO_REPLY, MessageType.SALITY_NL_REPLY}:
-                if id_dst in comm_counts:
-                    comm_counts[id_dst] += change
+                    comm_amounts[id_src] = 1
+            # if message is response, dst is initiator
+            elif MessageType.is_response(mtype):
+                # if dst exists in comm_amounts, add 1 to its amount
+                if id_dst in comm_amounts:
+                    comm_amounts[id_dst] += change
+                # else if op is add, add the ID with comm value 1 to comm_amounts
                 elif change > 0:
                 elif change > 0:
-                    comm_counts[id_dst] = 1
+                    comm_amounts[id_dst] = 1
 
 
-        def get_comm_count_first_ids(comm_counts: list):
+        def get_comm_amount_first_ids():
             """
             """
-            Finds the IDs that communicate among themselves the most with respect to the given message counts.
-            :param msg_counts: a sorted list of message counts where each entry is a tuple of key and value
-            :return: The picked IDs and their total message count as a tuple
+            Finds the number_ids IDs that communicate the most with respect to nez_comm_amounts
+            :return: The picked IDs as a list and their summed message amount as a tuple like (IDs, sum).
             """
             """
-            # if order of most messages is important, use an additional list
-            picked_ids = {}
-            total_comm_count = 0
 
 
-            # iterate over every message count
-            for i, comm in enumerate(comm_counts):
+            picked_ids = []  # the IDs that have been picked
+            summed_comm_amount = 0  # the summed communication amount of all picked IDs
+            # sort the comm amounts to easily access the IDs with the most communication
+            sorted_comm_amounts = sorted(nez_comm_amounts.items(), key=lambda x: x[1], reverse=True)
+
+            # iterate over the sorted communication amounts
+            for id_, amount in sorted_comm_amounts:
                 count_picked_ids = len(picked_ids)
                 count_picked_ids = len(picked_ids)
 
 
                 # if enough IDs have been found, stop
                 # if enough IDs have been found, stop
                 if count_picked_ids >= number_ids:
                 if count_picked_ids >= number_ids:
                     break
                     break
 
 
-                picked_ids[comm[0]] = comm[1]
-                total_comm_count += comm[1]
-
-            return picked_ids, total_comm_count
-
+                # else pick this ID
+                picked_ids.append(id_)
+                summed_comm_amount += amount
 
 
-        # first find all possible intervals that contain enough IDs that initiate communication
-        idx_low, idx_high = 0, 0
-        comm_counts = {}
-        possible_intervals = []
-        general_comm_sum, cur_highest_sum = 0, 0
+            return picked_ids, summed_comm_amount
 
 
-        # Iterate over all packets from start to finish and process the info of each packet
-        # If time of packet within time interval, update the message count for this communication
-        # If time of packet exceeds time interval, substract from the message count for this communication
-        # Similar to a Sliding Window approach
+        # setup initial variables
+        packets = self.packets
+        mtypes = self.mtypes
+        idx_low, idx_high = 0, 0  # the indices spanning the interval
+        cur_highest_sum = 0  # the highest communication sum seen so far
+        # a dict containing information about what initiator ID has communicated how much
+    comm_amounts = {}  # maps initiator ID -> its communication amount
+        possible_intervals = []  # all intervals that have cur_highest_sum of communication and contain enough IDs
+
+        # Iterate over all packets from start to finish and process the info of each packet.
+        # Similar to a Sliding Window approach.
         while True:
         while True:
             if idx_high < len(packets):
             if idx_high < len(packets):
                 cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
                 cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
      
      
-            # if current interval time exceeds time interval, save the message counts if appropriate, or stop if no more packets
+            # if current interval time exceeds maximum time period, process information of the current interval
             if greater_than(cur_int_time, max_int_time) or idx_high >= len(packets):
             if greater_than(cur_int_time, max_int_time) or idx_high >= len(packets):
-                # get all message counts for communications that took place in the current intervall
-                nez_comm_counts = get_nez_comm_counts(comm_counts)
-                # if we have enough IDs as specified by the caller, mark as possible interval
-                if len(nez_comm_counts) >= number_ids:
-                    if self.nat:
-                        possible_intervals.append((nez_comm_counts, idx_low, idx_high-1))
-                    elif general_comm_sum >= cur_highest_sum:
-                        cur_highest_sum = general_comm_sum
-                        possible_intervals.append({"IDs": nez_comm_counts, "CommSum": general_comm_sum, "Start": idx_low, "End": idx_high-1})
-                        general_comm_sum = 0
-
+                # filter out all IDs with a zero amount of communication for the current interval
+                nez_comm_amounts = get_nez_comm_amounts()
+                # if the interval contains enough initiator IDs, add it to possible_intervals
+                if len(nez_comm_amounts) >= number_ids:
+                    # pick the most communicative IDs and store their sum of communication
+                    picked_ids, comm_sum = get_comm_amount_first_ids()
+                    interval = {"IDs": picked_ids, "Start": idx_low, "End": idx_high-1}
+
+                    # reset possible intervals if new maximum of communication is found
+                    if comm_sum > cur_highest_sum:
+                        possible_intervals = [interval]
+                        cur_highest_sum = comm_sum
+                    # append otherwise
+                    elif comm_sum == cur_highest_sum:
+                        possible_intervals.append(interval)
+
+                # stop if all packets have been processed
                 if idx_high >= len(packets):
                 if idx_high >= len(packets):
                     break
                     break
 
 
-            # let idx_low "catch up" so that the current interval time fits into the interval time specified by the caller
+            # let idx_low "catch up" so that the current interval time fits into the maximum time period again
             while greater_than(cur_int_time, max_int_time):
             while greater_than(cur_int_time, max_int_time):
-                change_comm_counts(comm_counts, idx_low, add=False)
+                # adjust communication amounts to discard the earliest packet of the current interval
+                change_comm_amounts(packets[idx_low], add=False)
                 idx_low += 1
                 idx_low += 1
                 cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
                 cur_int_time = float(packets[idx_high]["Time"]) - float(packets[idx_low]["Time"])
 
 
             # consume the new packet at idx_high and process its information
             # consume the new packet at idx_high and process its information
-            change_comm_counts(comm_counts, idx_high)
+            change_comm_amounts(packets[idx_high])
             idx_high += 1
             idx_high += 1
-            general_comm_sum += 1
 
 
-        if self.nat:
-            # now find the interval in which as many IDs as specified communicate the most in the given time interval
-            summed_intervals = []
-            sum_intervals_idxs = []
-            cur_highest_sum = 0
-
-            # for every interval compute the sum of id_counts of the first most communicative IDs and eventually find
-            # the interval(s) with most communication and its IDs
-            # on the side also store the communication count of the individual IDs
-            for j, interval in enumerate(possible_intervals):
-                comm_counts = interval[0].items()
-                sorted_comm_counts = sorted(comm_counts, key=lambda x: x[1], reverse=True)
-                picked_ids, comm_sum = get_comm_count_first_ids(sorted_comm_counts)
-
-                if comm_sum == cur_highest_sum:
-                    summed_intervals.append({"IDs": picked_ids, "CommSum": comm_sum, "Start": interval[1], "End": interval[2]})
-                elif comm_sum > cur_highest_sum:
-                    summed_intervals = []
-                    summed_intervals.append({"IDs": picked_ids, "CommSum": comm_sum, "Start": interval[1], "End": interval[2]})
-                    cur_highest_sum = comm_sum
-            return summed_intervals
-        else:
-            return possible_intervals
+        return possible_intervals
 
 
 
 
     def det_id_roles_and_msgs(self):
     def det_id_roles_and_msgs(self):

+ 4 - 0
code/ID2TLib/Controller.py

@@ -84,6 +84,8 @@ class Controller:
         # merge single attack pcap with all attacks into base pcap
         # merge single attack pcap with all attacks into base pcap
         print("Merging base pcap with single attack pcap...", end=" ")
         print("Merging base pcap with single attack pcap...", end=" ")
         sys.stdout.flush()  # force python to print text immediately
         sys.stdout.flush()  # force python to print text immediately
+
+        # cp merged PCAP to output path
         self.pcap_dest_path = self.pcap_file.merge_attack(attacks_pcap_path)
         self.pcap_dest_path = self.pcap_file.merge_attack(attacks_pcap_path)
         if self.pcap_out_path:
         if self.pcap_out_path:
             if not self.pcap_out_path.endswith(".pcap"):
             if not self.pcap_out_path.endswith(".pcap"):
@@ -101,6 +103,8 @@ class Controller:
 
 
         # write label file with attacks
         # write label file with attacks
         self.label_manager.write_label_file(self.pcap_dest_path)
         self.label_manager.write_label_file(self.pcap_dest_path)
+
+        # if MembersMgmtCommAttack created an xml file, move it into input pcap directory
         if created_xml:
         if created_xml:
             pcap_dir = os.path.splitext(self.pcap_dest_path)[0]
             pcap_dir = os.path.splitext(self.pcap_dest_path)[0]
             if "/" in pcap_dir:
             if "/" in pcap_dir:

+ 11 - 30
code/ID2TLib/PcapAddressOperations.py

@@ -21,25 +21,29 @@ class PcapAddressOperations():
     def get_probable_router_mac(self):
     def get_probable_router_mac(self):
         """
         """
         Returns the most probable router MAC address based on the most used MAC address in the statistics.
         Returns the most probable router MAC address based on the most used MAC address in the statistics.
+        :return: the MAC address
         """
         """
         self.probable_router_mac, count = self.statistics.process_db_query("most_used(macAddress)", print_results=False)[0]
         self.probable_router_mac, count = self.statistics.process_db_query("most_used(macAddress)", print_results=False)[0]
         return self.probable_router_mac     # and count as a measure of certainty?
         return self.probable_router_mac     # and count as a measure of certainty?
 
 
     def pcap_contains_priv_ips(self):
     def pcap_contains_priv_ips(self):
         """
         """
-        Returns True if the provided traffic contains private IPs, otherwise False.
+        Returns whether the provided traffic contains private IPs.
+        :return: True if the provided traffic contains private IPs, otherwise False
         """
         """
         return self.contains_priv_ips
         return self.contains_priv_ips
 
 
     def get_local_address_range(self):
     def get_local_address_range(self):
         """
         """
         Returns a tuple with the start and end of the observed local IP range.
         Returns a tuple with the start and end of the observed local IP range.
+        :return: The IP range as tuple
         """
         """
         return str(self.min_local_ip), str(self.max_local_ip)
         return str(self.min_local_ip), str(self.max_local_ip)
 
 
     def get_count_rem_local_ips(self):
     def get_count_rem_local_ips(self):
         """
         """
         Returns the number of local IPs in the pcap file that have not aldready been returned by get_existing_local_ips.
         Returns the number of local IPs in the pcap file that have not aldready been returned by get_existing_local_ips.
+        :return: the number of local IPs not yet returned by get_existing_local_ips
         """
         """
         return len(self.remaining_local_ips)
         return len(self.remaining_local_ips)
 
 
@@ -63,14 +67,8 @@ class PcapAddressOperations():
             retr_local_ips.append(str(random_local_ip))
             retr_local_ips.append(str(random_local_ip))
             local_ips.remove(random_local_ip)
             local_ips.remove(random_local_ip)
 
 
-        # if count == 1:
-        #     return retr_local_ips[0]
-
         return retr_local_ips
         return retr_local_ips
 
 
-    # also use IPs below minimum observed IP?
-    # offset for later, start at x after minimum? e.g. start at 192.168.0.100
-    # exclude the last IP of an IP segment because its broadcast?
     def get_new_local_ips(self, count: int=1):
     def get_new_local_ips(self, count: int=1):
         """
         """
         Returns in the pcap not existent local IPs that are in proximity of the observed local IPs. IPs can be returned
         Returns in the pcap not existent local IPs that are in proximity of the observed local IPs. IPs can be returned
@@ -84,15 +82,9 @@ class PcapAddressOperations():
 
 
         unused_local_ips = self.unused_local_ips
         unused_local_ips = self.unused_local_ips
         uncertain_local_ips = self.uncertain_local_ips
         uncertain_local_ips = self.uncertain_local_ips
-
-        # warning reasonable?
-        if count > len(unused_local_ips):
-            print("Warning: there are no {0} unused certain local IPs in the .pcap file.\n \
-                Returning {1} certain and {2} uncertain local IPs.".format(count, len(unused_local_ips), count-len(unused_local_ips)))
-
         count_certain = min(count, len(unused_local_ips))
         count_certain = min(count, len(unused_local_ips))
-    
         retr_local_ips = []
         retr_local_ips = []
+
         for _ in range(0, count_certain):
         for _ in range(0, count_certain):
             random_local_ip = choice(sorted(unused_local_ips))
             random_local_ip = choice(sorted(unused_local_ips))
             retr_local_ips.append(str(random_local_ip))
             retr_local_ips.append(str(random_local_ip))
@@ -128,9 +120,6 @@ class PcapAddressOperations():
                 random_local_ip = choice(sorted(uncertain_local_ips))
                 random_local_ip = choice(sorted(uncertain_local_ips))
                 retr_local_ips.append(str(random_local_ip))
                 retr_local_ips.append(str(random_local_ip))
                 uncertain_local_ips.remove(random_local_ip)
                 uncertain_local_ips.remove(random_local_ip)
-
-        # if count == 1:
-        #     return retr_local_ips[0]
             
             
         return retr_local_ips
         return retr_local_ips
 
 
@@ -142,17 +131,11 @@ class PcapAddressOperations():
         :return: the chosen external IPs
         :return: the chosen external IPs
         """
         """
 
 
-        # reasonable to include this?
         if not (len(self.external_ips) > 0):
         if not (len(self.external_ips) > 0):
             print("Warning: .pcap does not contain any external ips.")
             print("Warning: .pcap does not contain any external ips.")
             return []
             return []
 
 
-        if count > len(self.remaining_external_ips):
-            print("Warning: There are no more %d external IPs in the .pcap file.\n" % count +
-                "Returning all %d existing external IPs." % len(self.remaining_external_ips))
-
         total = min(len(self.remaining_external_ips), count)
         total = min(len(self.remaining_external_ips), count)
-
         retr_external_ips = []
         retr_external_ips = []
         external_ips = self.remaining_external_ips
         external_ips = self.remaining_external_ips
 
 
@@ -161,9 +144,6 @@ class PcapAddressOperations():
             retr_external_ips.append(str(random_external_ip))
             retr_external_ips.append(str(random_external_ip))
             external_ips.remove(random_external_ip)
             external_ips.remove(random_external_ip)
 
 
-        # if count == 1:
-        #     return retr_external_ips[0]
-
         return retr_external_ips
         return retr_external_ips
 
 
     def _init_ipaddress_ops(self):
     def _init_ipaddress_ops(self):
@@ -173,7 +153,7 @@ class PcapAddressOperations():
 
 
         # retrieve local and external IPs
         # retrieve local and external IPs
         all_ips_str = set(self.statistics.process_db_query("all(ipAddress)", print_results=False))
         all_ips_str = set(self.statistics.process_db_query("all(ipAddress)", print_results=False))
-        external_ips_str = set(self.statistics.process_db_query("ipAddress(macAddress=%s)" % self.get_probable_router_mac(), print_results=False))
+        external_ips_str = set(self.statistics.process_db_query("ipAddress(macAddress=%s)" % self.get_probable_router_mac(), print_results=False))  # including router
         local_ips_str = all_ips_str - external_ips_str
         local_ips_str = all_ips_str - external_ips_str
         external_ips = set()
         external_ips = set()
         local_ips = set()
         local_ips = set()
@@ -187,6 +167,7 @@ class PcapAddressOperations():
                 if ip.is_private() and not self.contains_priv_ips:
                 if ip.is_private() and not self.contains_priv_ips:
                     self.contains_priv_ips = True
                     self.contains_priv_ips = True
                     self.priv_ip_segment = ip.get_private_segment()
                     self.priv_ip_segment = ip.get_private_segment()
+                # exclude local broadcast address and other special addresses
                 if (not str(ip) == "255.255.255.255") and (not ip.is_localhost()) and (not ip.is_multicast()) and (not ip.is_reserved()) and (not ip.is_zero_conf()):
                 if (not str(ip) == "255.255.255.255") and (not ip.is_localhost()) and (not ip.is_multicast()) and (not ip.is_reserved()) and (not ip.is_zero_conf()):
                     local_ips.add(ip)
                     local_ips.add(ip)
 
 
@@ -194,11 +175,11 @@ class PcapAddressOperations():
         for ip in external_ips_str:
         for ip in external_ips_str:
             if is_ipv4(ip):
             if is_ipv4(ip):
                 ip = IPAddress.parse(ip)
                 ip = IPAddress.parse(ip)
-                # if router MAC can definitely be mapped to local/private IP, add it to local_ips
+                # if router MAC can definitely be mapped to local/private IP, add it to local_ips (because at first it is stored in external_ips, see above)
+                # this depends on whether the local network is identified by a private IP address range or not.
                 if ip.is_private():
                 if ip.is_private():
                     local_ips.add(ip)
                     local_ips.add(ip)
-                # new function in IPv4 to shorten this?
-                # exclude local broadcast address
+                # exclude local broadcast address and other special addresses
                 elif (not str(ip) == "255.255.255.255") and (not ip.is_localhost()) and (not ip.is_multicast()) and (not ip.is_reserved()) and (not ip.is_zero_conf()):
                 elif (not str(ip) == "255.255.255.255") and (not ip.is_localhost()) and (not ip.is_multicast()) and (not ip.is_reserved()) and (not ip.is_zero_conf()):
                     external_ips.add(ip)
                     external_ips.add(ip)