Browse Source

refactor imports in Attack/BaseAttack.py

refactor whitspaces in Attack/BaseAttack.py
deep refactoring Attack/BaseAttack.py
Jens Keim 6 years ago
parent
commit
3a0648e8cc
2 changed files with 149 additions and 127 deletions
  1. 144 122
      code/Attack/BaseAttack.py
  2. 5 5
      code/Test/test_BaseAttack.py

+ 144 - 122
code/Attack/BaseAttack.py

@@ -1,26 +1,26 @@
-import socket
-import sys
+import abc
 import ipaddress
 import os
 import random
 import re
+import socket
+import sys
 import tempfile
 import time
-import numpy as np
 
-from abc import abstractmethod, ABCMeta
-from scapy.layers.inet import Ether
-from scapy.utils import PcapWriter
+# TODO: double check this import
+# does it complain because libpcapreader is not a .py?
+import ID2TLib.libpcapreader as pr
+import lea
+import numpy as np
+import scapy.layers.inet as inet
+import scapy.utils
 
-from Attack import AttackParameters
-from Attack.AttackParameters import Parameter
-from Attack.AttackParameters import ParameterTypes
+import Attack.AttackParameters as atkParam
 import ID2TLib.Utility as Util
-from lea import Lea
-import ID2TLib.libpcapreader as pr
 
 
-class BaseAttack(metaclass=ABCMeta):
+class BaseAttack(metaclass=abc.ABCMeta):
     """
     Abstract base class for all attack classes. Provides basic functionalities, like parameter validation.
     """
@@ -29,7 +29,6 @@ class BaseAttack(metaclass=ABCMeta):
         """
         To be called within the individual attack class to initialize the required parameters.
 
-        :param statistics: A reference to the Statistics class.
         :param name: The name of the attack class.
         :param description: A short description of the attack.
         :param attack_type: The type the attack belongs to, like probing/scanning, malware.
@@ -60,7 +59,7 @@ class BaseAttack(metaclass=ABCMeta):
         """
         self.statistics = statistics
 
-    @abstractmethod
+    @abc.abstractmethod
     def init_params(self):
         """
         Initialize all required parameters taking into account user supplied values. If no value is supplied,
@@ -69,14 +68,14 @@ class BaseAttack(metaclass=ABCMeta):
         """
         pass
 
-    @abstractmethod
+    @abc.abstractmethod
     def generate_attack_packets(self):
         """
         Creates the attack packets.
         """
         pass
 
-    @abstractmethod
+    @abc.abstractmethod
     def generate_attack_pcap(self):
         """
         Creates a pcap containing the attack packets.
@@ -93,7 +92,8 @@ class BaseAttack(metaclass=ABCMeta):
     @staticmethod
     def _is_mac_address(mac_address: str):
         """
-        Verifies if the given string is a valid MAC address. Accepts the formats 00:80:41:ae:fd:7e and 00-80-41-ae-fd-7e.
+        Verifies if the given string is a valid MAC address.
+        Accepts the formats 00:80:41:ae:fd:7e and 00-80-41-ae-fd-7e.
 
         :param mac_address: The MAC address as string.
         :return: True if the MAC address is valid, otherwise False.
@@ -118,6 +118,7 @@ class BaseAttack(metaclass=ABCMeta):
         :param ip_address: The IP address(es) as list of strings, comma-separated or dash-separated string.
         :return: True if all IP addresses are valid, otherwise False. And a list of IP addresses as string.
         """
+
         def append_ips(ip_address_input):
             """
             Recursive appending function to handle lists and ranges of IP addresses.
@@ -141,9 +142,7 @@ class BaseAttack(metaclass=ABCMeta):
                         return False, ip_list
             return is_valid, ip_list
 
-        ip_address_output = []
-
-        # a comma-separated list of IP addresses must be splitted first
+        # a comma-separated list of IP addresses must be split first
         if isinstance(ip_address, str):
             ip_address = ip_address.split(',')
 
@@ -191,6 +190,7 @@ class BaseAttack(metaclass=ABCMeta):
                 if _is_invalid_port(port_entry):
                     return False
                 ports_output.append(port_entry)
+            # TODO: validate last condition
             elif isinstance(port_entry, str) and port_entry.isdigit():
                 # port_entry describes a single port
                 port_entry = int(port_entry)
@@ -200,7 +200,7 @@ class BaseAttack(metaclass=ABCMeta):
             elif '-' in port_entry or '..' in port_entry:
                 # port_entry describes a port range
                 # allowed format: '1-49151', '1..49151', '1...49151'
-                match = re.match(r'^([0-9]{1,5})(?:-|\.{2,3})([0-9]{1,5})$', port_entry)
+                match = re.match(r'^([0-9]{1,5})(?:-|\.{2,3})([0-9]{1,5})$', str(port_entry))
                 # check validity of port range
                 # and create list of ports derived from given start and end port
                 (port_start, port_end) = int(match.group(1)), int(match.group(2))
@@ -247,6 +247,7 @@ class BaseAttack(metaclass=ABCMeta):
         # Raises ValueError if value is anything else.
         try:
             import distutils.core
+            import distutils.util
             value = distutils.util.strtobool(value.lower())
             is_bool = True
         except ValueError:
@@ -272,18 +273,18 @@ class BaseAttack(metaclass=ABCMeta):
         """
         Verifies that the given string is a valid URI.
 
-        :param uri: The URI as string.
+        :param val: The URI as string.
         :return: True if URI is valid, otherwise False.
         """
         domain = re.match(r'^(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$', val)
-        return (domain is not None)
-
+        return domain is not None
 
     #########################################
     # HELPER METHODS
     #########################################
 
-    def set_seed(self, seed: int):
+    @staticmethod
+    def set_seed(seed: int):
         """
         :param seed: The random seed to be set.
         """
@@ -291,12 +292,21 @@ class BaseAttack(metaclass=ABCMeta):
             random.seed(seed)
 
     def set_start_time(self):
+        """
+        Set the current time as global starting time.
+        """
         self.start_time = time.time()
 
     def set_finish_time(self):
+        """
+        Set the current time as global finishing time.
+        """
         self.finish_time = time.time()
 
     def get_packet_generation_time(self):
+        """
+        :return difference between starting and finishing time.
+        """
         return self.finish_time - self.start_time
 
     def add_param_value(self, param, value):
@@ -304,9 +314,8 @@ class BaseAttack(metaclass=ABCMeta):
         Adds the pair param : value to the dictionary of attack parameters. Prints and error message and skips the
         parameter if the validation fails.
 
-        :param stats: Statistics used to calculate user queries or default values.
         :param param: Name of the parameter that we wish to modify.
-        :param value: The value we wish to assign to the specifried parameter.
+        :param value: The value we wish to assign to the specified parameter.
         :return: None.
         """
         # This function call is valid only if there is a statistics object available.
@@ -319,12 +328,12 @@ class BaseAttack(metaclass=ABCMeta):
 
         # get AttackParameters instance associated with param
         # for default values assigned in attack classes, like Parameter.PORT_OPEN
-        if isinstance(param, AttackParameters.Parameter):
+        if isinstance(param, atkParam.Parameter):
             param_name = param
         # for values given by user input, like port.open
         else:
             # Get Enum key of given string identifier
-            param_name = AttackParameters.Parameter(param)
+            param_name = atkParam.Parameter(param)
 
         # Get parameter type of attack's required_params
         param_type = self.supported_params.get(param_name)
@@ -339,43 +348,43 @@ class BaseAttack(metaclass=ABCMeta):
             if value is not None and value is not "":
                 is_valid = True
             else:
-                print('Error in given parameter value: ' + value + '. Data could not be retrieved.')
+                print('Error in given parameter value: ' + str(value) + '. Data could not be retrieved.')
 
         # Validate parameter depending on parameter's type
-        elif param_type == ParameterTypes.TYPE_IP_ADDRESS:
+        elif param_type == atkParam.ParameterTypes.TYPE_IP_ADDRESS:
             is_valid, value = self._is_ip_address(value)
-        elif param_type == ParameterTypes.TYPE_PORT:
+        elif param_type == atkParam.ParameterTypes.TYPE_PORT:
             is_valid, value = self._is_port(value)
-        elif param_type == ParameterTypes.TYPE_MAC_ADDRESS:
+        elif param_type == atkParam.ParameterTypes.TYPE_MAC_ADDRESS:
             is_valid = self._is_mac_address(value)
-        elif param_type == ParameterTypes.TYPE_INTEGER_POSITIVE:
+        elif param_type == atkParam.ParameterTypes.TYPE_INTEGER_POSITIVE:
             if isinstance(value, int) and int(value) >= 0:
                 is_valid = True
             elif isinstance(value, str) and value.isdigit() and int(value) >= 0:
                 is_valid = True
                 value = int(value)
-        elif param_type == ParameterTypes.TYPE_STRING:
+        elif param_type == atkParam.ParameterTypes.TYPE_STRING:
             if isinstance(value, str):
                 is_valid = True
-        elif param_type == ParameterTypes.TYPE_FLOAT:
+        elif param_type == atkParam.ParameterTypes.TYPE_FLOAT:
             is_valid, value = self._is_float(value)
             # this is required to avoid that the timestamp's microseconds of the first attack packet is '000000'
             # but microseconds are only chosen randomly if the given parameter does not already specify it
             # e.g. inject.at-timestamp=123456.987654 -> is not changed
             # e.g. inject.at-timestamp=123456 -> is changed to: 123456.[random digits]
-            if param_name == Parameter.INJECT_AT_TIMESTAMP and is_valid and ((value - int(value)) == 0):
+            if param_name == atkParam.Parameter.INJECT_AT_TIMESTAMP and is_valid and ((value - int(value)) == 0):
                 value = value + random.uniform(0, 0.999999)
-        elif param_type == ParameterTypes.TYPE_TIMESTAMP:
+        elif param_type == atkParam.ParameterTypes.TYPE_TIMESTAMP:
             is_valid = self._is_timestamp(value)
-        elif param_type == ParameterTypes.TYPE_BOOLEAN:
+        elif param_type == atkParam.ParameterTypes.TYPE_BOOLEAN:
             is_valid, value = self._is_boolean(value)
-        elif param_type == ParameterTypes.TYPE_PACKET_POSITION:
+        elif param_type == atkParam.ParameterTypes.TYPE_PACKET_POSITION:
             ts = pr.pcap_processor(self.statistics.pcap_filepath, "False").get_timestamp_mu_sec(int(value))
             if 0 <= int(value) <= self.statistics.get_packet_count() and ts >= 0:
                 is_valid = True
-                param_name = Parameter.INJECT_AT_TIMESTAMP
+                param_name = atkParam.Parameter.INJECT_AT_TIMESTAMP
                 value = (ts / 1000000)  # convert microseconds from getTimestampMuSec into seconds
-        elif param_type == ParameterTypes.TYPE_DOMAIN:
+        elif param_type == atkParam.ParameterTypes.TYPE_DOMAIN:
             is_valid = self._is_domain(value)
 
         # add value iff validation was successful
@@ -385,7 +394,7 @@ class BaseAttack(metaclass=ABCMeta):
             print("ERROR: Parameter " + str(param) + " or parameter value " + str(value) +
                   " not valid. Skipping parameter.")
 
-    def get_param_value(self, param: Parameter):
+    def get_param_value(self, param: atkParam.Parameter):
         """
         Returns the parameter value for a given parameter.
 
@@ -400,8 +409,8 @@ class BaseAttack(metaclass=ABCMeta):
         However, this should not happen as all attack should define default parameter values.
         """
         # parameters which do not require default values
-        non_obligatory_params = [Parameter.INJECT_AFTER_PACKET, Parameter.NUMBER_ATTACKERS]
-        for param, type in self.supported_params.items():
+        non_obligatory_params = [atkParam.Parameter.INJECT_AFTER_PACKET, atkParam.Parameter.NUMBER_ATTACKERS]
+        for param, param_type in self.supported_params.items():
             # checks whether all params have assigned values, INJECT_AFTER_PACKET must not be considered because the
             # timestamp derived from it is set to Parameter.INJECT_AT_TIMESTAMP
             if param not in self.params.keys() and param not in non_obligatory_params:
@@ -414,6 +423,7 @@ class BaseAttack(metaclass=ABCMeta):
     def write_attack_pcap(self, packets: list, append_flag: bool = False, destination_path: str = None):
         """
         Writes the attack's packets into a PCAP file with a temporary filename.
+
         :return: The path of the written PCAP file.
         """
         # Only check params initially when attack generation starts
@@ -429,7 +439,7 @@ class BaseAttack(metaclass=ABCMeta):
             destination = temp_file.name
 
         # Write packets into pcap file
-        pktdump = PcapWriter(destination, append=append_flag)
+        pktdump = scapy.utils.PcapWriter(destination, append=append_flag)
         pktdump.write(packets)
 
         # Store pcap path and close file objects
@@ -440,6 +450,7 @@ class BaseAttack(metaclass=ABCMeta):
     def get_reply_delay(self, ip_dst):
         """
            Gets the minimum and the maximum reply delay for all the connections of a specific IP.
+
            :param ip_dst: The IP to reterive its reply delay.
            :return minDelay: minimum delay
            :return maxDelay: maximum delay
@@ -448,30 +459,32 @@ class BaseAttack(metaclass=ABCMeta):
         result = self.statistics.process_db_query(
             "SELECT AVG(minDelay), AVG(maxDelay) FROM conv_statistics WHERE ipAddressB='" + ip_dst + "';")
         if result[0][0] and result[0][1]:
-            minDelay = result[0][0]
-            maxDelay = result[0][1]
+            min_delay = result[0][0]
+            max_delay = result[0][1]
         else:
-            allMinDelays = self.statistics.process_db_query("SELECT minDelay FROM conv_statistics LIMIT 500;")
-            minDelay = np.median(allMinDelays)
-            allMaxDelays = self.statistics.process_db_query("SELECT maxDelay FROM conv_statistics LIMIT 500;")
-            maxDelay = np.median(allMaxDelays)
-        minDelay = int(minDelay) * 10 ** -6  # convert from micro to seconds
-        maxDelay = int(maxDelay) * 10 ** -6
-        return minDelay, maxDelay
-
-    def packets_to_convs(self,exploit_raw_packets):
-        """
-           Classifies a bunch of packets to conversations groups. A conversation is a set of packets go between host A (IP,port)
-           to host B (IP,port)
+            all_min_delays = self.statistics.process_db_query("SELECT minDelay FROM conv_statistics LIMIT 500;")
+            min_delay = np.median(all_min_delays)
+            all_max_delays = self.statistics.process_db_query("SELECT maxDelay FROM conv_statistics LIMIT 500;")
+            max_delay = np.median(all_max_delays)
+        min_delay = int(min_delay) * 10 ** -6  # convert from micro to seconds
+        max_delay = int(max_delay) * 10 ** -6
+        return min_delay, max_delay
+
+    @staticmethod
+    def packets_to_convs(exploit_raw_packets):
+        """
+           Classifies a bunch of packets to conversations groups. A conversation is a set of packets go between host A
+           (IP,port) to host B (IP,port)
+
            :param exploit_raw_packets: A set of packets contains several conversations.
            :return conversations: A set of arrays, each array contains the packet of specifc conversation
-           :return orderList_conversations: An array contains the conversations ids (IP_A,port_A, IP_b,port_B) in the order
-           they appeared in the original packets.
+           :return orderList_conversations: An array contains the conversations ids (IP_A,port_A, IP_b,port_B) in the
+           order they appeared in the original packets.
            """
         conversations = {}
-        orderList_conversations = []
+        order_list_conversations = []
         for pkt_num, pkt in enumerate(exploit_raw_packets):
-            eth_frame = Ether(pkt[0])
+            eth_frame = inet.Ether(pkt[0])
 
             ip_pkt = eth_frame.payload
             ip_dst = ip_pkt.getfieldval("dst")
@@ -484,22 +497,23 @@ class BaseAttack(metaclass=ABCMeta):
             conv_req = (ip_src, port_src, ip_dst, port_dst)
             conv_rep = (ip_dst, port_dst, ip_src, port_src)
             if conv_req not in conversations and conv_rep not in conversations:
-                pktList = [pkt]
-                conversations[conv_req] = pktList
+                pkt_list = [pkt]
+                conversations[conv_req] = pkt_list
                 # Order list of conv
-                orderList_conversations.append(conv_req)
+                order_list_conversations.append(conv_req)
             else:
                 if conv_req in conversations:
-                    pktList = conversations[conv_req]
-                    pktList.append(pkt)
-                    conversations[conv_req] = pktList
+                    pkt_list = conversations[conv_req]
+                    pkt_list.append(pkt)
+                    conversations[conv_req] = pkt_list
                 else:
-                    pktList = conversations[conv_rep]
-                    pktList.append(pkt)
-                    conversations[conv_rep] = pktList
-        return (conversations, orderList_conversations)
+                    pkt_list = conversations[conv_rep]
+                    pkt_list.append(pkt)
+                    conversations[conv_rep] = pkt_list
+        return conversations, order_list_conversations
 
-    def is_valid_ip_address(self,addr):
+    @staticmethod
+    def is_valid_ip_address(addr):
         """
         Checks if the IP address family is supported.
 
@@ -512,7 +526,8 @@ class BaseAttack(metaclass=ABCMeta):
         except socket.error:
             return False
 
-    def ip_src_dst_equal_check(self, ip_source, ip_destination):
+    @staticmethod
+    def ip_src_dst_equal_check(ip_source, ip_destination):
         """
         Checks if the source IP and destination IP are equal.
 
@@ -530,51 +545,53 @@ class BaseAttack(metaclass=ABCMeta):
             print("\nERROR: Invalid IP addresses; source IP is the same as destination IP: " + ip_destination + ".")
             sys.exit(0)
 
-    def get_inter_arrival_time(self, packets, distribution:bool=False):
+    @staticmethod
+    def get_inter_arrival_time(packets, distribution: bool = False):
         """
         Gets the inter-arrival times array and its distribution of a set of packets.
 
         :param packets: the packets to extract their inter-arrival time.
+        :param distribution: build distribution dictionary or not
         :return inter_arrival_times: array of the inter-arrival times
         :return dict: the inter-arrival time distribution as a histogram {inter-arrival time:frequency}
         """
         inter_arrival_times = []
-        prvsPktTime = 0
+        prvs_pkt_time = 0
         for index, pkt in enumerate(packets):
-            timestamp = pkt[2][0] + pkt[2][1]/10**6
+            timestamp = pkt[2][0] + pkt[2][1] / 10 ** 6
 
             if index == 0:
-                prvsPktTime = timestamp
+                prvs_pkt_time = timestamp
                 inter_arrival_times.append(0)
             else:
-                inter_arrival_times.append(timestamp - prvsPktTime)
-                prvsPktTime = timestamp
+                inter_arrival_times.append(timestamp - prvs_pkt_time)
+                prvs_pkt_time = timestamp
 
         if distribution:
             # Build a distribution dictionary
-            import numpy as np
-            freq,values = np.histogram(inter_arrival_times,bins=20)
-            dict = {}
-            for i,val in enumerate(values):
+            freq, values = np.histogram(inter_arrival_times, bins=20)
+            dist_dict = {}
+            for i, val in enumerate(values):
                 if i < len(freq):
-                    dict[str(val)] = freq[i]
-            return inter_arrival_times, dict
+                    dist_dict[str(val)] = freq[i]
+            return inter_arrival_times, dist_dict
         else:
             return inter_arrival_times
 
-    def clean_white_spaces(self, str):
+    @staticmethod
+    def clean_white_spaces(str_param):
         """
         Delete extra backslash from white spaces. This function is used to process the payload of packets.
 
-        :param str: the payload to be processed.
+        :param str_param: the payload to be processed.
         """
-        str = str.replace("\\n", "\n")
-        str = str.replace("\\r", "\r")
-        str = str.replace("\\t", "\t")
-        str = str.replace("\\\'", "\'")
-        return str
+        str_param = str_param.replace("\\n", "\n")
+        str_param = str_param.replace("\\r", "\r")
+        str_param = str_param.replace("\\t", "\t")
+        str_param = str_param.replace("\\\'", "\'")
+        return str_param
 
-    def modify_http_header(self,str_tcp_seg, orig_target_uri, target_uri, orig_ip_dst, target_host):
+    def modify_http_header(self, str_tcp_seg, orig_target_uri, target_uri, orig_ip_dst, target_host):
         """
         Substitute the URI and HOST in a HTTP header with new values.
 
@@ -600,7 +617,7 @@ class BaseAttack(metaclass=ABCMeta):
         # Set MSS (Maximum Segment Size) based on MSS distribution of IP address
         mss_dist = self.statistics.get_mss_distribution(ip_address)
         if len(mss_dist) > 0:
-            mss_prob_dict = Lea.fromValFreqsDict(mss_dist)
+            mss_prob_dict = lea.Lea.fromValFreqsDict(mss_dist)
             mss_value = mss_prob_dict.random()
         else:
             mss_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(mssValue)"))
@@ -608,7 +625,7 @@ class BaseAttack(metaclass=ABCMeta):
         # Set TTL based on TTL distribution of IP address
         ttl_dist = self.statistics.get_ttl_distribution(ip_address)
         if len(ttl_dist) > 0:
-            ttl_prob_dict = Lea.fromValFreqsDict(ttl_dist)
+            ttl_prob_dict = lea.Lea.fromValFreqsDict(ttl_dist)
             ttl_value = ttl_prob_dict.random()
         else:
             ttl_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(ttlValue)"))
@@ -616,56 +633,58 @@ class BaseAttack(metaclass=ABCMeta):
         # Set Window Size based on Window Size distribution of IP address
         win_dist = self.statistics.get_win_distribution(ip_address)
         if len(win_dist) > 0:
-            win_prob_dict = Lea.fromValFreqsDict(win_dist)
+            win_prob_dict = lea.Lea.fromValFreqsDict(win_dist)
             win_value = win_prob_dict.random()
         else:
             win_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(winSize)"))
 
         return mss_value, ttl_value, win_value
 
-
     #########################################
     # RANDOM IP/MAC ADDRESS GENERATORS
     #########################################
 
     @staticmethod
-    def generate_random_ipv4_address(ipClass, n: int = 1):
+    def generate_random_ipv4_address(ip_class, n: int = 1):
+        # TODO: document ip_class
         """
         Generates n random IPv4 addresses.
+
+        :param ip_class:
         :param n: The number of IP addresses to be generated
         :return: A single IP address, or if n>1, a list of IP addresses
         """
 
-        def is_invalid(ipAddress: ipaddress.IPv4Address):
-            return ipAddress.is_multicast or ipAddress.is_unspecified or ipAddress.is_loopback or \
-                   ipAddress.is_link_local or ipAddress.is_reserved or ipAddress.is_private
+        def is_invalid(ip_address_param: ipaddress.IPv4Address):
+            return ip_address_param.is_multicast or ip_address_param.is_unspecified or ip_address_param.is_loopback or \
+                   ip_address_param.is_link_local or ip_address_param.is_reserved or ip_address_param.is_private
 
         # Generate a random IP from specific class
-        def generate_address(ipClass):
-            if ipClass == "Unknown":
+        def generate_address(ip_class_param):
+            if ip_class_param == "Unknown":
                 return ipaddress.IPv4Address(random.randint(0, 2 ** 32 - 1))
             else:
                 # For DDoS attack, we do not generate private IPs
-                if "private" in ipClass:
-                    ipClass = ipClass[0] # convert A-private to A
-                ipClassesByte1 = {"A": {1,126}, "B": {128,191}, "C":{192, 223}, "D":{224, 239}, "E":{240, 254}}
-                temp = list(ipClassesByte1[ipClass])
-                minB1 = temp[0]
-                maxB1 = temp[1]
-                b1 = random.randint(minB1, maxB1)
+                if "private" in ip_class_param:
+                    ip_class_param = ip_class_param[0]  # convert A-private to A
+                ip_classes_byte1 = {"A": {1, 126}, "B": {128, 191}, "C": {192, 223}, "D": {224, 239}, "E": {240, 254}}
+                temp = list(ip_classes_byte1[ip_class_param])
+                min_b1 = temp[0]
+                max_b1 = temp[1]
+                b1 = random.randint(min_b1, max_b1)
                 b2 = random.randint(1, 255)
                 b3 = random.randint(1, 255)
                 b4 = random.randint(1, 255)
 
-                ipAddress = ipaddress.IPv4Address(str(b1) +"."+ str(b2) + "." + str(b3) + "." + str(b4))
+                ip_address = ipaddress.IPv4Address(str(b1) + "." + str(b2) + "." + str(b3) + "." + str(b4))
 
-            return ipAddress
+            return ip_address
 
         ip_addresses = []
         for i in range(0, n):
-            address = generate_address(ipClass)
+            address = generate_address(ip_class)
             while is_invalid(address):
-                address = generate_address(ipClass)
+                address = generate_address(ip_class)
             ip_addresses.append(str(address))
 
         if n == 1:
@@ -677,13 +696,14 @@ class BaseAttack(metaclass=ABCMeta):
     def generate_random_ipv6_address(n: int = 1):
         """
         Generates n random IPv6 addresses.
+
         :param n: The number of IP addresses to be generated
         :return: A single IP address, or if n>1, a list of IP addresses
         """
 
-        def is_invalid(ipAddress: ipaddress.IPv6Address):
-            return ipAddress.is_multicast or ipAddress.is_unspecified or ipAddress.is_loopback or \
-                   ipAddress.is_link_local or ipAddress.is_private or ipAddress.is_reserved
+        def is_invalid(ip_address: ipaddress.IPv6Address):
+            return ip_address.is_multicast or ip_address.is_unspecified or ip_address.is_loopback or \
+                   ip_address.is_link_local or ip_address.is_private or ip_address.is_reserved
 
         def generate_address():
             return ipaddress.IPv6Address(random.randint(0, 2 ** 128 - 1))
@@ -704,17 +724,19 @@ class BaseAttack(metaclass=ABCMeta):
     def generate_random_mac_address(n: int = 1):
         """
         Generates n random MAC addresses.
+
         :param n: The number of MAC addresses to be generated.
         :return: A single MAC addres, or if n>1, a list of MAC addresses
         """
 
-        def is_invalid(address: str):
-            first_octet = int(address[0:2], 16)
+        def is_invalid(address_param: str):
+            first_octet = int(address_param[0:2], 16)
             is_multicast_address = bool(first_octet & 0b01)
             is_locally_administered = bool(first_octet & 0b10)
             return is_multicast_address or is_locally_administered
 
         def generate_address():
+            # FIXME: cleanup
             mac = [random.randint(0x00, 0xff) for i in range(0, 6)]
             return ':'.join(map(lambda x: "%02x" % x, mac))
 

+ 5 - 5
code/Test/test_BaseAttack.py

@@ -148,20 +148,20 @@ class TestBaseAttack(unittest.TestCase):
         self.assertFalse(BA.BaseAttack._is_domain("this is not a valid domain, I guess, maybe, let's find out."))
 
     def test_is_valid_ipaddress_valid(self):
-        self.assertTrue(BA.BaseAttack.is_valid_ip_address(BA, "192.168.178.42"))
+        self.assertTrue(BA.BaseAttack.is_valid_ip_address("192.168.178.42"))
 
     def test_is_valid_ipaddress_invalid(self):
-        self.assertFalse(BA.BaseAttack.is_valid_ip_address(BA, "192.168.1789.42"))
+        self.assertFalse(BA.BaseAttack.is_valid_ip_address("192.168.1789.42"))
 
     def test_ip_src_dst_equal_check_equal(self):
         with self.assertRaises(SystemExit):
-            BA.BaseAttack.ip_src_dst_equal_check(BA, "192.168.178.42", "192.168.178.42")
+            BA.BaseAttack.ip_src_dst_equal_check("192.168.178.42", "192.168.178.42")
 
     def test_ip_src_dst_equal_check_unequal(self):
-        BA.BaseAttack.ip_src_dst_equal_check(BA, "192.168.178.42", "192.168.178.43")
+        BA.BaseAttack.ip_src_dst_equal_check("192.168.178.42", "192.168.178.43")
 
     def test_clean_whitespaces(self):
-        self.assertEqual("a\nb\rc\td\'e", BA.BaseAttack.clean_white_spaces(BA, "a\\nb\\rc\\td\\\'e"))
+        self.assertEqual("a\nb\rc\td\'e", BA.BaseAttack.clean_white_spaces("a\\nb\\rc\\td\\\'e"))
 
     def test_generate_random_ipv4_address(self):
         ip_list = BA.BaseAttack.generate_random_ipv4_address("Unknown", 10)