Kaynağa Gözat

refactor imports in Attack/BaseAttack.py

refactor whitspaces in Attack/BaseAttack.py
deep refactoring Attack/BaseAttack.py
Jens Keim 6 yıl önce
ebeveyn
işleme
3a0648e8cc
2 değiştirilmiş dosya ile 149 ekleme ve 127 silme
  1. 144 122
      code/Attack/BaseAttack.py
  2. 5 5
      code/Test/test_BaseAttack.py

+ 144 - 122
code/Attack/BaseAttack.py

@@ -1,26 +1,26 @@
-import socket
-import sys
+import abc
 import ipaddress
 import ipaddress
 import os
 import os
 import random
 import random
 import re
 import re
+import socket
+import sys
 import tempfile
 import tempfile
 import time
 import time
-import numpy as np
 
 
-from abc import abstractmethod, ABCMeta
-from scapy.layers.inet import Ether
-from scapy.utils import PcapWriter
+# TODO: double check this import
+# does it complain because libpcapreader is not a .py?
+import ID2TLib.libpcapreader as pr
+import lea
+import numpy as np
+import scapy.layers.inet as inet
+import scapy.utils
 
 
-from Attack import AttackParameters
-from Attack.AttackParameters import Parameter
-from Attack.AttackParameters import ParameterTypes
+import Attack.AttackParameters as atkParam
 import ID2TLib.Utility as Util
 import ID2TLib.Utility as Util
-from lea import Lea
-import ID2TLib.libpcapreader as pr
 
 
 
 
-class BaseAttack(metaclass=ABCMeta):
+class BaseAttack(metaclass=abc.ABCMeta):
     """
     """
     Abstract base class for all attack classes. Provides basic functionalities, like parameter validation.
     Abstract base class for all attack classes. Provides basic functionalities, like parameter validation.
     """
     """
@@ -29,7 +29,6 @@ class BaseAttack(metaclass=ABCMeta):
         """
         """
         To be called within the individual attack class to initialize the required parameters.
         To be called within the individual attack class to initialize the required parameters.
 
 
-        :param statistics: A reference to the Statistics class.
         :param name: The name of the attack class.
         :param name: The name of the attack class.
         :param description: A short description of the attack.
         :param description: A short description of the attack.
         :param attack_type: The type the attack belongs to, like probing/scanning, malware.
         :param attack_type: The type the attack belongs to, like probing/scanning, malware.
@@ -60,7 +59,7 @@ class BaseAttack(metaclass=ABCMeta):
         """
         """
         self.statistics = statistics
         self.statistics = statistics
 
 
-    @abstractmethod
+    @abc.abstractmethod
     def init_params(self):
     def init_params(self):
         """
         """
         Initialize all required parameters taking into account user supplied values. If no value is supplied,
         Initialize all required parameters taking into account user supplied values. If no value is supplied,
@@ -69,14 +68,14 @@ class BaseAttack(metaclass=ABCMeta):
         """
         """
         pass
         pass
 
 
-    @abstractmethod
+    @abc.abstractmethod
     def generate_attack_packets(self):
     def generate_attack_packets(self):
         """
         """
         Creates the attack packets.
         Creates the attack packets.
         """
         """
         pass
         pass
 
 
-    @abstractmethod
+    @abc.abstractmethod
     def generate_attack_pcap(self):
     def generate_attack_pcap(self):
         """
         """
         Creates a pcap containing the attack packets.
         Creates a pcap containing the attack packets.
@@ -93,7 +92,8 @@ class BaseAttack(metaclass=ABCMeta):
     @staticmethod
     @staticmethod
     def _is_mac_address(mac_address: str):
     def _is_mac_address(mac_address: str):
         """
         """
-        Verifies if the given string is a valid MAC address. Accepts the formats 00:80:41:ae:fd:7e and 00-80-41-ae-fd-7e.
+        Verifies if the given string is a valid MAC address.
+        Accepts the formats 00:80:41:ae:fd:7e and 00-80-41-ae-fd-7e.
 
 
         :param mac_address: The MAC address as string.
         :param mac_address: The MAC address as string.
         :return: True if the MAC address is valid, otherwise False.
         :return: True if the MAC address is valid, otherwise False.
@@ -118,6 +118,7 @@ class BaseAttack(metaclass=ABCMeta):
         :param ip_address: The IP address(es) as list of strings, comma-separated or dash-separated string.
         :param ip_address: The IP address(es) as list of strings, comma-separated or dash-separated string.
         :return: True if all IP addresses are valid, otherwise False. And a list of IP addresses as string.
         :return: True if all IP addresses are valid, otherwise False. And a list of IP addresses as string.
         """
         """
+
         def append_ips(ip_address_input):
         def append_ips(ip_address_input):
             """
             """
             Recursive appending function to handle lists and ranges of IP addresses.
             Recursive appending function to handle lists and ranges of IP addresses.
@@ -141,9 +142,7 @@ class BaseAttack(metaclass=ABCMeta):
                         return False, ip_list
                         return False, ip_list
             return is_valid, ip_list
             return is_valid, ip_list
 
 
-        ip_address_output = []
-
-        # a comma-separated list of IP addresses must be splitted first
+        # a comma-separated list of IP addresses must be split first
         if isinstance(ip_address, str):
         if isinstance(ip_address, str):
             ip_address = ip_address.split(',')
             ip_address = ip_address.split(',')
 
 
@@ -191,6 +190,7 @@ class BaseAttack(metaclass=ABCMeta):
                 if _is_invalid_port(port_entry):
                 if _is_invalid_port(port_entry):
                     return False
                     return False
                 ports_output.append(port_entry)
                 ports_output.append(port_entry)
+            # TODO: validate last condition
             elif isinstance(port_entry, str) and port_entry.isdigit():
             elif isinstance(port_entry, str) and port_entry.isdigit():
                 # port_entry describes a single port
                 # port_entry describes a single port
                 port_entry = int(port_entry)
                 port_entry = int(port_entry)
@@ -200,7 +200,7 @@ class BaseAttack(metaclass=ABCMeta):
             elif '-' in port_entry or '..' in port_entry:
             elif '-' in port_entry or '..' in port_entry:
                 # port_entry describes a port range
                 # port_entry describes a port range
                 # allowed format: '1-49151', '1..49151', '1...49151'
                 # allowed format: '1-49151', '1..49151', '1...49151'
-                match = re.match(r'^([0-9]{1,5})(?:-|\.{2,3})([0-9]{1,5})$', port_entry)
+                match = re.match(r'^([0-9]{1,5})(?:-|\.{2,3})([0-9]{1,5})$', str(port_entry))
                 # check validity of port range
                 # check validity of port range
                 # and create list of ports derived from given start and end port
                 # and create list of ports derived from given start and end port
                 (port_start, port_end) = int(match.group(1)), int(match.group(2))
                 (port_start, port_end) = int(match.group(1)), int(match.group(2))
@@ -247,6 +247,7 @@ class BaseAttack(metaclass=ABCMeta):
         # Raises ValueError if value is anything else.
         # Raises ValueError if value is anything else.
         try:
         try:
             import distutils.core
             import distutils.core
+            import distutils.util
             value = distutils.util.strtobool(value.lower())
             value = distutils.util.strtobool(value.lower())
             is_bool = True
             is_bool = True
         except ValueError:
         except ValueError:
@@ -272,18 +273,18 @@ class BaseAttack(metaclass=ABCMeta):
         """
         """
         Verifies that the given string is a valid URI.
         Verifies that the given string is a valid URI.
 
 
-        :param uri: The URI as string.
+        :param val: The URI as string.
         :return: True if URI is valid, otherwise False.
         :return: True if URI is valid, otherwise False.
         """
         """
         domain = re.match(r'^(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$', val)
         domain = re.match(r'^(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$', val)
-        return (domain is not None)
-
+        return domain is not None
 
 
     #########################################
     #########################################
     # HELPER METHODS
     # HELPER METHODS
     #########################################
     #########################################
 
 
-    def set_seed(self, seed: int):
+    @staticmethod
+    def set_seed(seed: int):
         """
         """
         :param seed: The random seed to be set.
         :param seed: The random seed to be set.
         """
         """
@@ -291,12 +292,21 @@ class BaseAttack(metaclass=ABCMeta):
             random.seed(seed)
             random.seed(seed)
 
 
     def set_start_time(self):
     def set_start_time(self):
+        """
+        Set the current time as global starting time.
+        """
         self.start_time = time.time()
         self.start_time = time.time()
 
 
     def set_finish_time(self):
     def set_finish_time(self):
+        """
+        Set the current time as global finishing time.
+        """
         self.finish_time = time.time()
         self.finish_time = time.time()
 
 
     def get_packet_generation_time(self):
     def get_packet_generation_time(self):
+        """
+        :return difference between starting and finishing time.
+        """
         return self.finish_time - self.start_time
         return self.finish_time - self.start_time
 
 
     def add_param_value(self, param, value):
     def add_param_value(self, param, value):
@@ -304,9 +314,8 @@ class BaseAttack(metaclass=ABCMeta):
         Adds the pair param : value to the dictionary of attack parameters. Prints and error message and skips the
         Adds the pair param : value to the dictionary of attack parameters. Prints and error message and skips the
         parameter if the validation fails.
         parameter if the validation fails.
 
 
-        :param stats: Statistics used to calculate user queries or default values.
         :param param: Name of the parameter that we wish to modify.
         :param param: Name of the parameter that we wish to modify.
-        :param value: The value we wish to assign to the specifried parameter.
+        :param value: The value we wish to assign to the specified parameter.
         :return: None.
         :return: None.
         """
         """
         # This function call is valid only if there is a statistics object available.
         # This function call is valid only if there is a statistics object available.
@@ -319,12 +328,12 @@ class BaseAttack(metaclass=ABCMeta):
 
 
         # get AttackParameters instance associated with param
         # get AttackParameters instance associated with param
         # for default values assigned in attack classes, like Parameter.PORT_OPEN
         # for default values assigned in attack classes, like Parameter.PORT_OPEN
-        if isinstance(param, AttackParameters.Parameter):
+        if isinstance(param, atkParam.Parameter):
             param_name = param
             param_name = param
         # for values given by user input, like port.open
         # for values given by user input, like port.open
         else:
         else:
             # Get Enum key of given string identifier
             # Get Enum key of given string identifier
-            param_name = AttackParameters.Parameter(param)
+            param_name = atkParam.Parameter(param)
 
 
         # Get parameter type of attack's required_params
         # Get parameter type of attack's required_params
         param_type = self.supported_params.get(param_name)
         param_type = self.supported_params.get(param_name)
@@ -339,43 +348,43 @@ class BaseAttack(metaclass=ABCMeta):
             if value is not None and value is not "":
             if value is not None and value is not "":
                 is_valid = True
                 is_valid = True
             else:
             else:
-                print('Error in given parameter value: ' + value + '. Data could not be retrieved.')
+                print('Error in given parameter value: ' + str(value) + '. Data could not be retrieved.')
 
 
         # Validate parameter depending on parameter's type
         # Validate parameter depending on parameter's type
-        elif param_type == ParameterTypes.TYPE_IP_ADDRESS:
+        elif param_type == atkParam.ParameterTypes.TYPE_IP_ADDRESS:
             is_valid, value = self._is_ip_address(value)
             is_valid, value = self._is_ip_address(value)
-        elif param_type == ParameterTypes.TYPE_PORT:
+        elif param_type == atkParam.ParameterTypes.TYPE_PORT:
             is_valid, value = self._is_port(value)
             is_valid, value = self._is_port(value)
-        elif param_type == ParameterTypes.TYPE_MAC_ADDRESS:
+        elif param_type == atkParam.ParameterTypes.TYPE_MAC_ADDRESS:
             is_valid = self._is_mac_address(value)
             is_valid = self._is_mac_address(value)
-        elif param_type == ParameterTypes.TYPE_INTEGER_POSITIVE:
+        elif param_type == atkParam.ParameterTypes.TYPE_INTEGER_POSITIVE:
             if isinstance(value, int) and int(value) >= 0:
             if isinstance(value, int) and int(value) >= 0:
                 is_valid = True
                 is_valid = True
             elif isinstance(value, str) and value.isdigit() and int(value) >= 0:
             elif isinstance(value, str) and value.isdigit() and int(value) >= 0:
                 is_valid = True
                 is_valid = True
                 value = int(value)
                 value = int(value)
-        elif param_type == ParameterTypes.TYPE_STRING:
+        elif param_type == atkParam.ParameterTypes.TYPE_STRING:
             if isinstance(value, str):
             if isinstance(value, str):
                 is_valid = True
                 is_valid = True
-        elif param_type == ParameterTypes.TYPE_FLOAT:
+        elif param_type == atkParam.ParameterTypes.TYPE_FLOAT:
             is_valid, value = self._is_float(value)
             is_valid, value = self._is_float(value)
             # this is required to avoid that the timestamp's microseconds of the first attack packet is '000000'
             # this is required to avoid that the timestamp's microseconds of the first attack packet is '000000'
             # but microseconds are only chosen randomly if the given parameter does not already specify it
             # but microseconds are only chosen randomly if the given parameter does not already specify it
             # e.g. inject.at-timestamp=123456.987654 -> is not changed
             # e.g. inject.at-timestamp=123456.987654 -> is not changed
             # e.g. inject.at-timestamp=123456 -> is changed to: 123456.[random digits]
             # e.g. inject.at-timestamp=123456 -> is changed to: 123456.[random digits]
-            if param_name == Parameter.INJECT_AT_TIMESTAMP and is_valid and ((value - int(value)) == 0):
+            if param_name == atkParam.Parameter.INJECT_AT_TIMESTAMP and is_valid and ((value - int(value)) == 0):
                 value = value + random.uniform(0, 0.999999)
                 value = value + random.uniform(0, 0.999999)
-        elif param_type == ParameterTypes.TYPE_TIMESTAMP:
+        elif param_type == atkParam.ParameterTypes.TYPE_TIMESTAMP:
             is_valid = self._is_timestamp(value)
             is_valid = self._is_timestamp(value)
-        elif param_type == ParameterTypes.TYPE_BOOLEAN:
+        elif param_type == atkParam.ParameterTypes.TYPE_BOOLEAN:
             is_valid, value = self._is_boolean(value)
             is_valid, value = self._is_boolean(value)
-        elif param_type == ParameterTypes.TYPE_PACKET_POSITION:
+        elif param_type == atkParam.ParameterTypes.TYPE_PACKET_POSITION:
             ts = pr.pcap_processor(self.statistics.pcap_filepath, "False").get_timestamp_mu_sec(int(value))
             ts = pr.pcap_processor(self.statistics.pcap_filepath, "False").get_timestamp_mu_sec(int(value))
             if 0 <= int(value) <= self.statistics.get_packet_count() and ts >= 0:
             if 0 <= int(value) <= self.statistics.get_packet_count() and ts >= 0:
                 is_valid = True
                 is_valid = True
-                param_name = Parameter.INJECT_AT_TIMESTAMP
+                param_name = atkParam.Parameter.INJECT_AT_TIMESTAMP
                 value = (ts / 1000000)  # convert microseconds from getTimestampMuSec into seconds
                 value = (ts / 1000000)  # convert microseconds from getTimestampMuSec into seconds
-        elif param_type == ParameterTypes.TYPE_DOMAIN:
+        elif param_type == atkParam.ParameterTypes.TYPE_DOMAIN:
             is_valid = self._is_domain(value)
             is_valid = self._is_domain(value)
 
 
         # add value iff validation was successful
         # add value iff validation was successful
@@ -385,7 +394,7 @@ class BaseAttack(metaclass=ABCMeta):
             print("ERROR: Parameter " + str(param) + " or parameter value " + str(value) +
             print("ERROR: Parameter " + str(param) + " or parameter value " + str(value) +
                   " not valid. Skipping parameter.")
                   " not valid. Skipping parameter.")
 
 
-    def get_param_value(self, param: Parameter):
+    def get_param_value(self, param: atkParam.Parameter):
         """
         """
         Returns the parameter value for a given parameter.
         Returns the parameter value for a given parameter.
 
 
@@ -400,8 +409,8 @@ class BaseAttack(metaclass=ABCMeta):
         However, this should not happen as all attack should define default parameter values.
         However, this should not happen as all attack should define default parameter values.
         """
         """
         # parameters which do not require default values
         # parameters which do not require default values
-        non_obligatory_params = [Parameter.INJECT_AFTER_PACKET, Parameter.NUMBER_ATTACKERS]
-        for param, type in self.supported_params.items():
+        non_obligatory_params = [atkParam.Parameter.INJECT_AFTER_PACKET, atkParam.Parameter.NUMBER_ATTACKERS]
+        for param, param_type in self.supported_params.items():
             # checks whether all params have assigned values, INJECT_AFTER_PACKET must not be considered because the
             # checks whether all params have assigned values, INJECT_AFTER_PACKET must not be considered because the
             # timestamp derived from it is set to Parameter.INJECT_AT_TIMESTAMP
             # timestamp derived from it is set to Parameter.INJECT_AT_TIMESTAMP
             if param not in self.params.keys() and param not in non_obligatory_params:
             if param not in self.params.keys() and param not in non_obligatory_params:
@@ -414,6 +423,7 @@ class BaseAttack(metaclass=ABCMeta):
     def write_attack_pcap(self, packets: list, append_flag: bool = False, destination_path: str = None):
     def write_attack_pcap(self, packets: list, append_flag: bool = False, destination_path: str = None):
         """
         """
         Writes the attack's packets into a PCAP file with a temporary filename.
         Writes the attack's packets into a PCAP file with a temporary filename.
+
         :return: The path of the written PCAP file.
         :return: The path of the written PCAP file.
         """
         """
         # Only check params initially when attack generation starts
         # Only check params initially when attack generation starts
@@ -429,7 +439,7 @@ class BaseAttack(metaclass=ABCMeta):
             destination = temp_file.name
             destination = temp_file.name
 
 
         # Write packets into pcap file
         # Write packets into pcap file
-        pktdump = PcapWriter(destination, append=append_flag)
+        pktdump = scapy.utils.PcapWriter(destination, append=append_flag)
         pktdump.write(packets)
         pktdump.write(packets)
 
 
         # Store pcap path and close file objects
         # Store pcap path and close file objects
@@ -440,6 +450,7 @@ class BaseAttack(metaclass=ABCMeta):
     def get_reply_delay(self, ip_dst):
     def get_reply_delay(self, ip_dst):
         """
         """
            Gets the minimum and the maximum reply delay for all the connections of a specific IP.
            Gets the minimum and the maximum reply delay for all the connections of a specific IP.
+
            :param ip_dst: The IP to reterive its reply delay.
            :param ip_dst: The IP to reterive its reply delay.
            :return minDelay: minimum delay
            :return minDelay: minimum delay
            :return maxDelay: maximum delay
            :return maxDelay: maximum delay
@@ -448,30 +459,32 @@ class BaseAttack(metaclass=ABCMeta):
         result = self.statistics.process_db_query(
         result = self.statistics.process_db_query(
             "SELECT AVG(minDelay), AVG(maxDelay) FROM conv_statistics WHERE ipAddressB='" + ip_dst + "';")
             "SELECT AVG(minDelay), AVG(maxDelay) FROM conv_statistics WHERE ipAddressB='" + ip_dst + "';")
         if result[0][0] and result[0][1]:
         if result[0][0] and result[0][1]:
-            minDelay = result[0][0]
-            maxDelay = result[0][1]
+            min_delay = result[0][0]
+            max_delay = result[0][1]
         else:
         else:
-            allMinDelays = self.statistics.process_db_query("SELECT minDelay FROM conv_statistics LIMIT 500;")
-            minDelay = np.median(allMinDelays)
-            allMaxDelays = self.statistics.process_db_query("SELECT maxDelay FROM conv_statistics LIMIT 500;")
-            maxDelay = np.median(allMaxDelays)
-        minDelay = int(minDelay) * 10 ** -6  # convert from micro to seconds
-        maxDelay = int(maxDelay) * 10 ** -6
-        return minDelay, maxDelay
-
-    def packets_to_convs(self,exploit_raw_packets):
-        """
-           Classifies a bunch of packets to conversations groups. A conversation is a set of packets go between host A (IP,port)
-           to host B (IP,port)
+            all_min_delays = self.statistics.process_db_query("SELECT minDelay FROM conv_statistics LIMIT 500;")
+            min_delay = np.median(all_min_delays)
+            all_max_delays = self.statistics.process_db_query("SELECT maxDelay FROM conv_statistics LIMIT 500;")
+            max_delay = np.median(all_max_delays)
+        min_delay = int(min_delay) * 10 ** -6  # convert from micro to seconds
+        max_delay = int(max_delay) * 10 ** -6
+        return min_delay, max_delay
+
+    @staticmethod
+    def packets_to_convs(exploit_raw_packets):
+        """
+           Classifies a bunch of packets to conversations groups. A conversation is a set of packets go between host A
+           (IP,port) to host B (IP,port)
+
            :param exploit_raw_packets: A set of packets contains several conversations.
            :param exploit_raw_packets: A set of packets contains several conversations.
            :return conversations: A set of arrays, each array contains the packet of specifc conversation
            :return conversations: A set of arrays, each array contains the packet of specifc conversation
-           :return orderList_conversations: An array contains the conversations ids (IP_A,port_A, IP_b,port_B) in the order
-           they appeared in the original packets.
+           :return orderList_conversations: An array contains the conversations ids (IP_A,port_A, IP_b,port_B) in the
+           order they appeared in the original packets.
            """
            """
         conversations = {}
         conversations = {}
-        orderList_conversations = []
+        order_list_conversations = []
         for pkt_num, pkt in enumerate(exploit_raw_packets):
         for pkt_num, pkt in enumerate(exploit_raw_packets):
-            eth_frame = Ether(pkt[0])
+            eth_frame = inet.Ether(pkt[0])
 
 
             ip_pkt = eth_frame.payload
             ip_pkt = eth_frame.payload
             ip_dst = ip_pkt.getfieldval("dst")
             ip_dst = ip_pkt.getfieldval("dst")
@@ -484,22 +497,23 @@ class BaseAttack(metaclass=ABCMeta):
             conv_req = (ip_src, port_src, ip_dst, port_dst)
             conv_req = (ip_src, port_src, ip_dst, port_dst)
             conv_rep = (ip_dst, port_dst, ip_src, port_src)
             conv_rep = (ip_dst, port_dst, ip_src, port_src)
             if conv_req not in conversations and conv_rep not in conversations:
             if conv_req not in conversations and conv_rep not in conversations:
-                pktList = [pkt]
-                conversations[conv_req] = pktList
+                pkt_list = [pkt]
+                conversations[conv_req] = pkt_list
                 # Order list of conv
                 # Order list of conv
-                orderList_conversations.append(conv_req)
+                order_list_conversations.append(conv_req)
             else:
             else:
                 if conv_req in conversations:
                 if conv_req in conversations:
-                    pktList = conversations[conv_req]
-                    pktList.append(pkt)
-                    conversations[conv_req] = pktList
+                    pkt_list = conversations[conv_req]
+                    pkt_list.append(pkt)
+                    conversations[conv_req] = pkt_list
                 else:
                 else:
-                    pktList = conversations[conv_rep]
-                    pktList.append(pkt)
-                    conversations[conv_rep] = pktList
-        return (conversations, orderList_conversations)
+                    pkt_list = conversations[conv_rep]
+                    pkt_list.append(pkt)
+                    conversations[conv_rep] = pkt_list
+        return conversations, order_list_conversations
 
 
-    def is_valid_ip_address(self,addr):
+    @staticmethod
+    def is_valid_ip_address(addr):
         """
         """
         Checks if the IP address family is supported.
         Checks if the IP address family is supported.
 
 
@@ -512,7 +526,8 @@ class BaseAttack(metaclass=ABCMeta):
         except socket.error:
         except socket.error:
             return False
             return False
 
 
-    def ip_src_dst_equal_check(self, ip_source, ip_destination):
+    @staticmethod
+    def ip_src_dst_equal_check(ip_source, ip_destination):
         """
         """
         Checks if the source IP and destination IP are equal.
         Checks if the source IP and destination IP are equal.
 
 
@@ -530,51 +545,53 @@ class BaseAttack(metaclass=ABCMeta):
             print("\nERROR: Invalid IP addresses; source IP is the same as destination IP: " + ip_destination + ".")
             print("\nERROR: Invalid IP addresses; source IP is the same as destination IP: " + ip_destination + ".")
             sys.exit(0)
             sys.exit(0)
 
 
-    def get_inter_arrival_time(self, packets, distribution:bool=False):
+    @staticmethod
+    def get_inter_arrival_time(packets, distribution: bool = False):
         """
         """
         Gets the inter-arrival times array and its distribution of a set of packets.
         Gets the inter-arrival times array and its distribution of a set of packets.
 
 
         :param packets: the packets to extract their inter-arrival time.
         :param packets: the packets to extract their inter-arrival time.
+        :param distribution: build distribution dictionary or not
         :return inter_arrival_times: array of the inter-arrival times
         :return inter_arrival_times: array of the inter-arrival times
         :return dict: the inter-arrival time distribution as a histogram {inter-arrival time:frequency}
         :return dict: the inter-arrival time distribution as a histogram {inter-arrival time:frequency}
         """
         """
         inter_arrival_times = []
         inter_arrival_times = []
-        prvsPktTime = 0
+        prvs_pkt_time = 0
         for index, pkt in enumerate(packets):
         for index, pkt in enumerate(packets):
-            timestamp = pkt[2][0] + pkt[2][1]/10**6
+            timestamp = pkt[2][0] + pkt[2][1] / 10 ** 6
 
 
             if index == 0:
             if index == 0:
-                prvsPktTime = timestamp
+                prvs_pkt_time = timestamp
                 inter_arrival_times.append(0)
                 inter_arrival_times.append(0)
             else:
             else:
-                inter_arrival_times.append(timestamp - prvsPktTime)
-                prvsPktTime = timestamp
+                inter_arrival_times.append(timestamp - prvs_pkt_time)
+                prvs_pkt_time = timestamp
 
 
         if distribution:
         if distribution:
             # Build a distribution dictionary
             # Build a distribution dictionary
-            import numpy as np
-            freq,values = np.histogram(inter_arrival_times,bins=20)
-            dict = {}
-            for i,val in enumerate(values):
+            freq, values = np.histogram(inter_arrival_times, bins=20)
+            dist_dict = {}
+            for i, val in enumerate(values):
                 if i < len(freq):
                 if i < len(freq):
-                    dict[str(val)] = freq[i]
-            return inter_arrival_times, dict
+                    dist_dict[str(val)] = freq[i]
+            return inter_arrival_times, dist_dict
         else:
         else:
             return inter_arrival_times
             return inter_arrival_times
 
 
-    def clean_white_spaces(self, str):
+    @staticmethod
+    def clean_white_spaces(str_param):
         """
         """
         Delete extra backslash from white spaces. This function is used to process the payload of packets.
         Delete extra backslash from white spaces. This function is used to process the payload of packets.
 
 
-        :param str: the payload to be processed.
+        :param str_param: the payload to be processed.
         """
         """
-        str = str.replace("\\n", "\n")
-        str = str.replace("\\r", "\r")
-        str = str.replace("\\t", "\t")
-        str = str.replace("\\\'", "\'")
-        return str
+        str_param = str_param.replace("\\n", "\n")
+        str_param = str_param.replace("\\r", "\r")
+        str_param = str_param.replace("\\t", "\t")
+        str_param = str_param.replace("\\\'", "\'")
+        return str_param
 
 
-    def modify_http_header(self,str_tcp_seg, orig_target_uri, target_uri, orig_ip_dst, target_host):
+    def modify_http_header(self, str_tcp_seg, orig_target_uri, target_uri, orig_ip_dst, target_host):
         """
         """
         Substitute the URI and HOST in a HTTP header with new values.
         Substitute the URI and HOST in a HTTP header with new values.
 
 
@@ -600,7 +617,7 @@ class BaseAttack(metaclass=ABCMeta):
         # Set MSS (Maximum Segment Size) based on MSS distribution of IP address
         # Set MSS (Maximum Segment Size) based on MSS distribution of IP address
         mss_dist = self.statistics.get_mss_distribution(ip_address)
         mss_dist = self.statistics.get_mss_distribution(ip_address)
         if len(mss_dist) > 0:
         if len(mss_dist) > 0:
-            mss_prob_dict = Lea.fromValFreqsDict(mss_dist)
+            mss_prob_dict = lea.Lea.fromValFreqsDict(mss_dist)
             mss_value = mss_prob_dict.random()
             mss_value = mss_prob_dict.random()
         else:
         else:
             mss_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(mssValue)"))
             mss_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(mssValue)"))
@@ -608,7 +625,7 @@ class BaseAttack(metaclass=ABCMeta):
         # Set TTL based on TTL distribution of IP address
         # Set TTL based on TTL distribution of IP address
         ttl_dist = self.statistics.get_ttl_distribution(ip_address)
         ttl_dist = self.statistics.get_ttl_distribution(ip_address)
         if len(ttl_dist) > 0:
         if len(ttl_dist) > 0:
-            ttl_prob_dict = Lea.fromValFreqsDict(ttl_dist)
+            ttl_prob_dict = lea.Lea.fromValFreqsDict(ttl_dist)
             ttl_value = ttl_prob_dict.random()
             ttl_value = ttl_prob_dict.random()
         else:
         else:
             ttl_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(ttlValue)"))
             ttl_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(ttlValue)"))
@@ -616,56 +633,58 @@ class BaseAttack(metaclass=ABCMeta):
         # Set Window Size based on Window Size distribution of IP address
         # Set Window Size based on Window Size distribution of IP address
         win_dist = self.statistics.get_win_distribution(ip_address)
         win_dist = self.statistics.get_win_distribution(ip_address)
         if len(win_dist) > 0:
         if len(win_dist) > 0:
-            win_prob_dict = Lea.fromValFreqsDict(win_dist)
+            win_prob_dict = lea.Lea.fromValFreqsDict(win_dist)
             win_value = win_prob_dict.random()
             win_value = win_prob_dict.random()
         else:
         else:
             win_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(winSize)"))
             win_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(winSize)"))
 
 
         return mss_value, ttl_value, win_value
         return mss_value, ttl_value, win_value
 
 
-
     #########################################
     #########################################
     # RANDOM IP/MAC ADDRESS GENERATORS
     # RANDOM IP/MAC ADDRESS GENERATORS
     #########################################
     #########################################
 
 
     @staticmethod
     @staticmethod
-    def generate_random_ipv4_address(ipClass, n: int = 1):
+    def generate_random_ipv4_address(ip_class, n: int = 1):
+        # TODO: document ip_class
         """
         """
         Generates n random IPv4 addresses.
         Generates n random IPv4 addresses.
+
+        :param ip_class:
         :param n: The number of IP addresses to be generated
         :param n: The number of IP addresses to be generated
         :return: A single IP address, or if n>1, a list of IP addresses
         :return: A single IP address, or if n>1, a list of IP addresses
         """
         """
 
 
-        def is_invalid(ipAddress: ipaddress.IPv4Address):
-            return ipAddress.is_multicast or ipAddress.is_unspecified or ipAddress.is_loopback or \
-                   ipAddress.is_link_local or ipAddress.is_reserved or ipAddress.is_private
+        def is_invalid(ip_address_param: ipaddress.IPv4Address):
+            return ip_address_param.is_multicast or ip_address_param.is_unspecified or ip_address_param.is_loopback or \
+                   ip_address_param.is_link_local or ip_address_param.is_reserved or ip_address_param.is_private
 
 
         # Generate a random IP from specific class
         # Generate a random IP from specific class
-        def generate_address(ipClass):
-            if ipClass == "Unknown":
+        def generate_address(ip_class_param):
+            if ip_class_param == "Unknown":
                 return ipaddress.IPv4Address(random.randint(0, 2 ** 32 - 1))
                 return ipaddress.IPv4Address(random.randint(0, 2 ** 32 - 1))
             else:
             else:
                 # For DDoS attack, we do not generate private IPs
                 # For DDoS attack, we do not generate private IPs
-                if "private" in ipClass:
-                    ipClass = ipClass[0] # convert A-private to A
-                ipClassesByte1 = {"A": {1,126}, "B": {128,191}, "C":{192, 223}, "D":{224, 239}, "E":{240, 254}}
-                temp = list(ipClassesByte1[ipClass])
-                minB1 = temp[0]
-                maxB1 = temp[1]
-                b1 = random.randint(minB1, maxB1)
+                if "private" in ip_class_param:
+                    ip_class_param = ip_class_param[0]  # convert A-private to A
+                ip_classes_byte1 = {"A": {1, 126}, "B": {128, 191}, "C": {192, 223}, "D": {224, 239}, "E": {240, 254}}
+                temp = list(ip_classes_byte1[ip_class_param])
+                min_b1 = temp[0]
+                max_b1 = temp[1]
+                b1 = random.randint(min_b1, max_b1)
                 b2 = random.randint(1, 255)
                 b2 = random.randint(1, 255)
                 b3 = random.randint(1, 255)
                 b3 = random.randint(1, 255)
                 b4 = random.randint(1, 255)
                 b4 = random.randint(1, 255)
 
 
-                ipAddress = ipaddress.IPv4Address(str(b1) +"."+ str(b2) + "." + str(b3) + "." + str(b4))
+                ip_address = ipaddress.IPv4Address(str(b1) + "." + str(b2) + "." + str(b3) + "." + str(b4))
 
 
-            return ipAddress
+            return ip_address
 
 
         ip_addresses = []
         ip_addresses = []
         for i in range(0, n):
         for i in range(0, n):
-            address = generate_address(ipClass)
+            address = generate_address(ip_class)
             while is_invalid(address):
             while is_invalid(address):
-                address = generate_address(ipClass)
+                address = generate_address(ip_class)
             ip_addresses.append(str(address))
             ip_addresses.append(str(address))
 
 
         if n == 1:
         if n == 1:
@@ -677,13 +696,14 @@ class BaseAttack(metaclass=ABCMeta):
     def generate_random_ipv6_address(n: int = 1):
     def generate_random_ipv6_address(n: int = 1):
         """
         """
         Generates n random IPv6 addresses.
         Generates n random IPv6 addresses.
+
         :param n: The number of IP addresses to be generated
         :param n: The number of IP addresses to be generated
         :return: A single IP address, or if n>1, a list of IP addresses
         :return: A single IP address, or if n>1, a list of IP addresses
         """
         """
 
 
-        def is_invalid(ipAddress: ipaddress.IPv6Address):
-            return ipAddress.is_multicast or ipAddress.is_unspecified or ipAddress.is_loopback or \
-                   ipAddress.is_link_local or ipAddress.is_private or ipAddress.is_reserved
+        def is_invalid(ip_address: ipaddress.IPv6Address):
+            return ip_address.is_multicast or ip_address.is_unspecified or ip_address.is_loopback or \
+                   ip_address.is_link_local or ip_address.is_private or ip_address.is_reserved
 
 
         def generate_address():
         def generate_address():
             return ipaddress.IPv6Address(random.randint(0, 2 ** 128 - 1))
             return ipaddress.IPv6Address(random.randint(0, 2 ** 128 - 1))
@@ -704,17 +724,19 @@ class BaseAttack(metaclass=ABCMeta):
     def generate_random_mac_address(n: int = 1):
     def generate_random_mac_address(n: int = 1):
         """
         """
         Generates n random MAC addresses.
         Generates n random MAC addresses.
+
         :param n: The number of MAC addresses to be generated.
         :param n: The number of MAC addresses to be generated.
         :return: A single MAC addres, or if n>1, a list of MAC addresses
         :return: A single MAC addres, or if n>1, a list of MAC addresses
         """
         """
 
 
-        def is_invalid(address: str):
-            first_octet = int(address[0:2], 16)
+        def is_invalid(address_param: str):
+            first_octet = int(address_param[0:2], 16)
             is_multicast_address = bool(first_octet & 0b01)
             is_multicast_address = bool(first_octet & 0b01)
             is_locally_administered = bool(first_octet & 0b10)
             is_locally_administered = bool(first_octet & 0b10)
             return is_multicast_address or is_locally_administered
             return is_multicast_address or is_locally_administered
 
 
         def generate_address():
         def generate_address():
+            # FIXME: cleanup
             mac = [random.randint(0x00, 0xff) for i in range(0, 6)]
             mac = [random.randint(0x00, 0xff) for i in range(0, 6)]
             return ':'.join(map(lambda x: "%02x" % x, mac))
             return ':'.join(map(lambda x: "%02x" % x, mac))
 
 

+ 5 - 5
code/Test/test_BaseAttack.py

@@ -148,20 +148,20 @@ class TestBaseAttack(unittest.TestCase):
         self.assertFalse(BA.BaseAttack._is_domain("this is not a valid domain, I guess, maybe, let's find out."))
         self.assertFalse(BA.BaseAttack._is_domain("this is not a valid domain, I guess, maybe, let's find out."))
 
 
     def test_is_valid_ipaddress_valid(self):
     def test_is_valid_ipaddress_valid(self):
-        self.assertTrue(BA.BaseAttack.is_valid_ip_address(BA, "192.168.178.42"))
+        self.assertTrue(BA.BaseAttack.is_valid_ip_address("192.168.178.42"))
 
 
     def test_is_valid_ipaddress_invalid(self):
     def test_is_valid_ipaddress_invalid(self):
-        self.assertFalse(BA.BaseAttack.is_valid_ip_address(BA, "192.168.1789.42"))
+        self.assertFalse(BA.BaseAttack.is_valid_ip_address("192.168.1789.42"))
 
 
     def test_ip_src_dst_equal_check_equal(self):
     def test_ip_src_dst_equal_check_equal(self):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
-            BA.BaseAttack.ip_src_dst_equal_check(BA, "192.168.178.42", "192.168.178.42")
+            BA.BaseAttack.ip_src_dst_equal_check("192.168.178.42", "192.168.178.42")
 
 
     def test_ip_src_dst_equal_check_unequal(self):
     def test_ip_src_dst_equal_check_unequal(self):
-        BA.BaseAttack.ip_src_dst_equal_check(BA, "192.168.178.42", "192.168.178.43")
+        BA.BaseAttack.ip_src_dst_equal_check("192.168.178.42", "192.168.178.43")
 
 
     def test_clean_whitespaces(self):
     def test_clean_whitespaces(self):
-        self.assertEqual("a\nb\rc\td\'e", BA.BaseAttack.clean_white_spaces(BA, "a\\nb\\rc\\td\\\'e"))
+        self.assertEqual("a\nb\rc\td\'e", BA.BaseAttack.clean_white_spaces("a\\nb\\rc\\td\\\'e"))
 
 
     def test_generate_random_ipv4_address(self):
     def test_generate_random_ipv4_address(self):
         ip_list = BA.BaseAttack.generate_random_ipv4_address("Unknown", 10)
         ip_list = BA.BaseAttack.generate_random_ipv4_address("Unknown", 10)