Browse Source

add ip range handling for ip param

Jens Keim 6 years ago
parent
commit
11788daab4
3 changed files with 39 additions and 13 deletions
  1. 31 13
      code/Attack/BaseAttack.py
  2. 4 0
      code/Attack/PortscanAttack.py
  3. 4 0
      code/Attack/SQLiAttack.py

+ 31 - 13
code/Attack/BaseAttack.py

@@ -15,7 +15,7 @@ from scapy.utils import PcapWriter
 from Attack import AttackParameters
 from Attack.AttackParameters import Parameter
 from Attack.AttackParameters import ParameterTypes
-from ID2TLib.Utility import handle_most_used_outputs
+import ID2TLib.Utility as Util
 from lea import Lea
 import ID2TLib.libpcapreader as pr
 
@@ -115,26 +115,44 @@ class BaseAttack(metaclass=ABCMeta):
         Verifies that the given string or list of IP addresses (strings) is a valid IPv4/IPv6 address.
         Accepts comma-separated lists of IP addresses, like "192.169.178.1, 192.168.178.2"
 
-        :param ip_address: The IP address(es) as list of strings or comma-separated string.
+        :param ip_address: The IP address(es) as list of strings, comma-separated or dash-separated string.
         :return: True if all IP addresses are valid, otherwise False. And a list of IP addresses as string.
         """
+        def append_ips(ip_address_input):
+            """
+            Recursive appending function to handle lists and ranges of IP addresses.
+
+            :param ip_address_input: The IP address(es) as list of strings, comma-separated or dash-separated string.
+            :return: List of all given IP addresses.
+            """
+            ip_list = []
+            is_valid = True
+            for ip in ip_address_input:
+                if '-' in ip:
+                    ip_range = ip.split('-')
+                    ip_range = Util.get_ip_range(ip_range[0], ip_range[1])
+                    is_valid, ips = append_ips(ip_range)
+                    ip_list.extend(ips)
+                else:
+                    try:
+                        ipaddress.ip_address(ip)
+                        ip_list.append(ip)
+                    except ValueError:
+                        return False, ip_list
+            return is_valid, ip_list
+
         ip_address_output = []
 
         # a comma-separated list of IP addresses must be splitted first
         if isinstance(ip_address, str):
             ip_address = ip_address.split(',')
 
-        for ip in ip_address:
-            try:
-                ipaddress.ip_address(ip)
-                ip_address_output.append(ip)
-            except ValueError:
-                return False, ip_address_output
+        result, ip_address_output = append_ips(ip_address)
 
         if len(ip_address_output) == 1:
-            return True, ip_address_output[0]
+            return result, ip_address_output[0]
         else:
-            return True, ip_address_output
+            return result, ip_address_output
 
     @staticmethod
     def _is_port(ports_input: str):
@@ -585,7 +603,7 @@ class BaseAttack(metaclass=ABCMeta):
             mss_prob_dict = Lea.fromValFreqsDict(mss_dist)
             mss_value = mss_prob_dict.random()
         else:
-            mss_value = handle_most_used_outputs(self.statistics.process_db_query("most_used(mssValue)"))
+            mss_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(mssValue)"))
 
         # Set TTL based on TTL distribution of IP address
         ttl_dist = self.statistics.get_ttl_distribution(ip_address)
@@ -593,7 +611,7 @@ class BaseAttack(metaclass=ABCMeta):
             ttl_prob_dict = Lea.fromValFreqsDict(ttl_dist)
             ttl_value = ttl_prob_dict.random()
         else:
-            ttl_value = handle_most_used_outputs(self.statistics.process_db_query("most_used(ttlValue)"))
+            ttl_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(ttlValue)"))
 
         # Set Window Size based on Window Size distribution of IP address
         win_dist = self.statistics.get_win_distribution(ip_address)
@@ -601,7 +619,7 @@ class BaseAttack(metaclass=ABCMeta):
             win_prob_dict = Lea.fromValFreqsDict(win_dist)
             win_value = win_prob_dict.random()
         else:
-            win_value = handle_most_used_outputs(self.statistics.process_db_query("most_used(winSize)"))
+            win_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(winSize)"))
 
         return mss_value, ttl_value, win_value
 

+ 4 - 0
code/Attack/PortscanAttack.py

@@ -136,7 +136,11 @@ class PortscanAttack(BaseAttack.BaseAttack):
         # Initialize parameters
         self.packets = []
         ip_source = self.get_param_value(Param.IP_SOURCE)
+        if isinstance(ip_source, list):
+            ip_source = ip_source[0]
         ip_destination = self.get_param_value(Param.IP_DESTINATION)
+        if isinstance(ip_destination, list):
+            ip_destination = ip_destination[0]
 
         # Check ip.src == ip.dst
         self.ip_src_dst_equal_check(ip_source, ip_destination)

+ 4 - 0
code/Attack/SQLiAttack.py

@@ -89,8 +89,12 @@ class SQLiAttack(BaseAttack.BaseAttack):
         self.packets = []
         mac_source = self.get_param_value(Param.MAC_SOURCE)
         ip_source = self.get_param_value(Param.IP_SOURCE)
+        if isinstance(ip_source, list):
+            ip_source = ip_source[0]
         mac_destination = self.get_param_value(Param.MAC_DESTINATION)
         ip_destination = self.get_param_value(Param.IP_DESTINATION)
+        if isinstance(ip_destination, list):
+            ip_destination = ip_destination[0]
         port_destination = self.get_param_value(Param.PORT_DESTINATION)
 
         target_host = self.get_param_value(Param.TARGET_HOST)