Procházet zdrojové kódy

- Improves using statistics in the PortscanAttack
- Removes (RST,ACK) answer from target host if port is not open
- Fixes bug: Temporary attack pcap is not deleted from /tmp directory
- Adds methods for retrieving statistics by Statistics class

Patrick Jattke před 7 roky
rodič
revize
2e9a925eb8

+ 7 - 3
code/Attack/BaseAttack.py

@@ -124,8 +124,8 @@ class BaseAttack(metaclass=ABCMeta):
                 ports_output.append(port_entry)
                 ports_output.append(port_entry)
             elif '-' in port_entry or '..' in port_entry:
             elif '-' in port_entry or '..' in port_entry:
                 # port_entry describes a port range
                 # port_entry describes a port range
-                # allowed format: '12-123', '12..123', '12...123'
-                match = re.match('^([0-9]{1,4})(?:-|\.{2,3})([0-9]{1,4})$', port_entry)
+                # allowed format: '1-49151', '1..49151', '1...49151'
+                match = re.match('^([0-9]{1,5})(?:-|\.{2,3})([0-9]{1,5})$', port_entry)
                 # check validity of port range
                 # check validity of port range
                 # and create list of ports derived from given start and end port
                 # and create list of ports derived from given start and end port
                 (port_start, port_end) = int(match.group(1)), int(match.group(2))
                 (port_start, port_end) = int(match.group(1)), int(match.group(2))
@@ -135,7 +135,11 @@ class BaseAttack(metaclass=ABCMeta):
                     ports_list = [i for i in range(port_start, port_end + 1)]
                     ports_list = [i for i in range(port_start, port_end + 1)]
                 # append ports at ports_output list
                 # append ports at ports_output list
                 ports_output += ports_list
                 ports_output += ports_list
-        return True, ports_output
+
+        if len(ports_output) == 1:
+            return True, ports_output[0]
+        else:
+            return True, ports_output
 
 
     @staticmethod
     @staticmethod
     def _is_timestamp(timestamp: str):
     def _is_timestamp(timestamp: str):

+ 34 - 26
code/Attack/PortscanAttack.py

@@ -42,25 +42,30 @@ class PortscanAttack(BaseAttack.BaseAttack):
 
 
         # PARAMETERS: initialize with default values
         # PARAMETERS: initialize with default values
         # (values are overwritten if user specifies them)
         # (values are overwritten if user specifies them)
-        most_used_ipAddress = self.statistics.process_db_query("most_used(ipAddress)")
-        if isinstance(most_used_ipAddress, list):
-            most_used_ipAddress = most_used_ipAddress[0]
-        self.add_param_value(Param.IP_SOURCE, most_used_ipAddress)
+        most_used_ip_address = self.statistics.get_most_used_ip_address()
+        if isinstance(most_used_ip_address, list):
+            most_used_ip_address = most_used_ip_address[0]
+
+        self.add_param_value(Param.IP_SOURCE, most_used_ip_address)
         self.add_param_value(Param.IP_SOURCE_RANDOMIZE, 'False')
         self.add_param_value(Param.IP_SOURCE_RANDOMIZE, 'False')
-        self.add_param_value(Param.IP_DESTINATION, '192.168.178.13')
+        self.add_param_value(Param.MAC_SOURCE, self.statistics.get_mac_address(most_used_ip_address))
+
+        random_ip_address = self.statistics.get_random_ip_address()
+        self.add_param_value(Param.IP_DESTINATION, random_ip_address)
+        self.add_param_value(Param.MAC_DESTINATION, self.statistics.get_mac_address(random_ip_address))
+
         self.add_param_value(Param.PORT_DESTINATION, '0-1023,1720,1900,8080')
         self.add_param_value(Param.PORT_DESTINATION, '0-1023,1720,1900,8080')
-        self.add_param_value(Param.PORT_SOURCE, '8542')
         self.add_param_value(Param.PORT_OPEN, '8080,9232,9233')
         self.add_param_value(Param.PORT_OPEN, '8080,9232,9233')
-        self.add_param_value(Param.PORT_SOURCE_RANDOM, 'False')
         self.add_param_value(Param.PORT_DEST_SHUFFLE, 'False')
         self.add_param_value(Param.PORT_DEST_SHUFFLE, 'False')
         self.add_param_value(Param.PORT_ORDER_DESC, 'False')
         self.add_param_value(Param.PORT_ORDER_DESC, 'False')
-        macAddress = self.statistics.process_db_query('macAddress(ipAddress=' + most_used_ipAddress + ")")
-        self.add_param_value(Param.MAC_SOURCE, macAddress)
-        self.add_param_value(Param.MAC_DESTINATION, 'A0:1A:28:0B:62:F4')
+
+        self.add_param_value(Param.PORT_SOURCE, '8542')
+        self.add_param_value(Param.PORT_SOURCE_RANDOM, 'False')
+
         self.add_param_value(Param.PACKETS_PER_SECOND,
         self.add_param_value(Param.PACKETS_PER_SECOND,
-                             (self.statistics.get_pps_sent(most_used_ipAddress) +
-                              self.statistics.get_pps_received(most_used_ipAddress)) / 2)
-        self.add_param_value(Param.INJECT_AT_TIMESTAMP, '1410733342')  # Sun, 14 Sep 2014 22:22:22 GMT
+                             (self.statistics.get_pps_sent(most_used_ip_address) +
+                              self.statistics.get_pps_received(most_used_ip_address)) / 2)
+        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(0, self.statistics.get_packet_count()))
 
 
     def get_packets(self):
     def get_packets(self):
         def update_timestamp(timestamp, pps, maxdelay):
         def update_timestamp(timestamp, pps, maxdelay):
@@ -88,9 +93,6 @@ class PortscanAttack(BaseAttack.BaseAttack):
         # TTL_samples = numpy.random.choice(keys, size=len(dest_ports), replace=True, dport=values)
         # TTL_samples = numpy.random.choice(keys, size=len(dest_ports), replace=True, dport=values)
         ttl_value = self.statistics.process_db_query("most_used(ttlValue)")
         ttl_value = self.statistics.process_db_query("most_used(ttlValue)")
 
 
-        # MSS (Maximum Segment Size) for Ethernet. Allowed values [536,1500]
-        mss = ('MSS', int(self.statistics.process_db_query('avg(mss)')))
-
         # Timestamp
         # Timestamp
         timestamp_next_pkt = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
         timestamp_next_pkt = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
         self.attack_start_utime = timestamp_next_pkt  # store start time of attack
         self.attack_start_utime = timestamp_next_pkt  # store start time of attack
@@ -105,6 +107,9 @@ class PortscanAttack(BaseAttack.BaseAttack):
         mac_source = self.get_param_value(Param.MAC_SOURCE)
         mac_source = self.get_param_value(Param.MAC_SOURCE)
         mac_destination = self.get_param_value(Param.MAC_DESTINATION)
         mac_destination = self.get_param_value(Param.MAC_DESTINATION)
 
 
+        # MSS (Maximum Segment Size) for Ethernet. Allowed values [536,1500]
+        mss = self.statistics.get_mss(ip_destination)
+
         for dport in dest_ports:
         for dport in dest_ports:
             # Parameters changing each iteration
             # Parameters changing each iteration
             if self.get_param_value(Param.IP_SOURCE_RANDOMIZE) and isinstance(ip_source, list):
             if self.get_param_value(Param.IP_SOURCE_RANDOMIZE) and isinstance(ip_source, list):
@@ -125,10 +130,13 @@ class PortscanAttack(BaseAttack.BaseAttack):
             reply_ether = Ether(src=mac_destination, dst=mac_source)
             reply_ether = Ether(src=mac_destination, dst=mac_source)
             reply_ip = IP(src=ip_destination, dst=ip_source, flags='DF')
             reply_ip = IP(src=ip_destination, dst=ip_source, flags='DF')
 
 
-            if str(dport) in self.get_param_value(Param.PORT_OPEN):  # destination port is OPEN
+            if dport in self.get_param_value(Param.PORT_OPEN):  # destination port is OPEN
                 # target answers
                 # target answers
-                reply_tcp = TCP(sport=dport, dport=sport, seq=0, ack=1, flags='SA', window=29200,
-                                options=[mss])
+                if mss is None:
+                    reply_tcp = TCP(sport=dport, dport=sport, seq=0, ack=1, flags='SA', window=29200)
+                else:
+                    reply_tcp = TCP(sport=dport, dport=sport, seq=0, ack=1, flags='SA', window=29200,
+                                    options=[('MSS', mss)])
                 # reply_tcp.time = time_sec_start + random.uniform(0.00005, 0.00013)
                 # reply_tcp.time = time_sec_start + random.uniform(0.00005, 0.00013)
                 reply = (reply_ether / reply_ip / reply_tcp)
                 reply = (reply_ether / reply_ip / reply_tcp)
                 timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, maxdelay)
                 timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, maxdelay)
@@ -144,13 +152,13 @@ class PortscanAttack(BaseAttack.BaseAttack):
                 reply.time = timestamp_next_pkt
                 reply.time = timestamp_next_pkt
                 packets.append(reply)
                 packets.append(reply)
 
 
-            else:  # destination port is NOT OPEN
-                reply_tcp = TCP(sport=dport, dport=sport, flags='RA', seq=1, ack=1, window=0)
-                # reply_tcp.time = time_sec_start + random.uniform(0.00005, 0.00013)
-                reply = (reply_ether / reply_ip / reply_tcp)
-                timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, maxdelay)
-                reply.time = timestamp_next_pkt
-                packets.append(reply)
+                # else:  # destination port is NOT OPEN -> no reply is sent by target
+                #     reply_tcp = TCP(sport=dport, dport=sport, flags='RA', seq=1, ack=1, window=0)
+                #     # reply_tcp.time = time_sec_start + random.uniform(0.00005, 0.00013)
+                #     reply = (reply_ether / reply_ip / reply_tcp)
+                #     timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, maxdelay)
+                #     reply.time = timestamp_next_pkt
+                #     packets.append(reply)
 
 
         # store end time of attack
         # store end time of attack
         self.attack_end_utime = reply.time
         self.attack_end_utime = reply.time

+ 4 - 0
code/ID2TLib/AttackController.py

@@ -1,4 +1,5 @@
 import importlib
 import importlib
+import os
 import tempfile
 import tempfile
 
 
 from scapy.utils import PcapWriter
 from scapy.utils import PcapWriter
@@ -105,6 +106,9 @@ class AttackController:
         # Merge attack with existing pcap
         # Merge attack with existing pcap
         pcap_dest_path = self.pcap_file.merge_attack(temp_attack_pcap_path)
         pcap_dest_path = self.pcap_file.merge_attack(temp_attack_pcap_path)
 
 
+        # Delete temporary attack pcap
+        os.remove(temp_attack_pcap_path)
+
         # Store label into LabelManager
         # Store label into LabelManager
         l = Label(attack, self.get_attack_start_utime(),
         l = Label(attack, self.get_attack_start_utime(),
                   self.get_attack_end_utime(), attack_note)
                   self.get_attack_end_utime(), attack_note)

+ 39 - 0
code/ID2TLib/Statistics.py

@@ -207,6 +207,45 @@ class Statistics:
         """
         """
         return self.file_info['packetCount']
         return self.file_info['packetCount']
 
 
+    def get_most_used_ip_address(self):
+        """
+        :return: The IP address/addresses with the highest sum of packets sent and received
+        """
+        return self.process_db_query("most_used(ipAddress)")
+
+    def get_random_ip_address(self, count: int = 1):
+        """
+        :param count: The number of IP addreses to return
+        :return: A randomly chosen IP address from the dataset or iff param count is greater than one, a list of randomly
+         chosen IP addresses
+        """
+        if count == 1:
+            return self.process_db_query("random(all(ipAddress))")
+        else:
+            ip_address_list = []
+            for i in range(0, count):
+                ip_address_list.append(self.process_db_query("random(all(ipAddress))"))
+            return ip_address_list
+
+    def get_mac_address(self, ipAddress: str):
+        """
+        :return: The MAC address used in the dataset for the given IP address.
+        """
+        return self.process_db_query('macAddress(ipAddress=' + ipAddress + ")")
+
+    def get_mss(self, ipAddress: str):
+        """
+
+        :param ipAddress: The IP address whose used MSS should be determined
+        :return: The TCP MSS value used by the IP address, or if the IP addresses never specified a MSS,
+        then None is returned
+        """
+        mss_value = self.process_db_query('SELECT mss from tcp_mss WHERE ipAddress="' + ipAddress + '"')
+        if isinstance(mss_value, int):
+            return mss_value
+        else:
+            return None
+
     def get_statistics_database(self):
     def get_statistics_database(self):
         """
         """
         :return: A reference to the statistics database object
         :return: A reference to the statistics database object

+ 5 - 4
code/ID2TLib/StatsDatabase.py

@@ -116,7 +116,7 @@ class StatsDatabase:
 
 
     def get_field_types(self, *table_names):
     def get_field_types(self, *table_names):
         """
         """
-        Creates a dictionary whose keys are the fields of the given table(s) and whose values are the appropriates field
+        Creates a dictionary whose keys are the fields of the given table(s) and whose values are the appropriate field
         types, like TEXT for strings and REAL for float numbers.
         types, like TEXT for strings and REAL for float numbers.
 
 
         :param table_names: The name of table(s)
         :param table_names: The name of table(s)
@@ -213,7 +213,7 @@ class StatsDatabase:
                         isinstance(last_result, list) or isinstance(last_result, tuple)):
                         isinstance(last_result, list) or isinstance(last_result, tuple)):
                 extractor = q[0]
                 extractor = q[0]
                 if extractor == 'random':
                 if extractor == 'random':
-                    index = randint(a=0, b=len(last_result))
+                    index = randint(a=0, b=len(last_result) - 1)
                     last_result = last_result[index]
                     last_result = last_result[index]
                 elif extractor == 'first':
                 elif extractor == 'first':
                     last_result = last_result[0]
                     last_result = last_result[0]
@@ -277,7 +277,9 @@ class StatsDatabase:
             return
             return
 
 
         # If result is tuple/list with single element, extract value from list
         # If result is tuple/list with single element, extract value from list
-        requires_extraction = (isinstance(result, list) or isinstance(result, tuple)) and len(result) == 1 and len(result[0]) == 1
+        requires_extraction = (isinstance(result, list) or isinstance(result, tuple)) and len(result) == 1 and \
+                              (not isinstance(result[0], tuple) or len(result[0]) == 1)
+
         while requires_extraction:
         while requires_extraction:
             if isinstance(result, list) or isinstance(result, tuple):
             if isinstance(result, list) or isinstance(result, tuple):
                 result = result[0]
                 result = result[0]
@@ -312,7 +314,6 @@ class StatsDatabase:
         # Print number of results according to type of result
         # Print number of results according to type of result
         if isinstance(result, list):
         if isinstance(result, list):
             print("Query returned " + str(len(result)) + " records:\n")
             print("Query returned " + str(len(result)) + " records:\n")
-
         else:
         else:
             print("Query returned 1 record:\n")
             print("Query returned 1 record:\n")