Browse Source

refactor imports in Attack/MS17ScanAttack.py

refactor whitspaces in Attack/MS17ScanAttack.py
deep refactoring Attack/MS17ScanAttack.py
Jens Keim 6 years ago
parent
commit
408e6fab95
1 changed files with 50 additions and 52 deletions
  1. 50 52
      code/Attack/MS17ScanAttack.py

+ 50 - 52
code/Attack/MS17ScanAttack.py

@@ -1,19 +1,17 @@
 import logging
-from random import randint, uniform
+import random as rnd
 
-from lea import Lea
-from scapy.layers.inet import Ether
-from scapy.utils import RawPcapReader
+import lea
+import scapy.layers.inet as inet
+import scapy.utils
 
-from Attack import BaseAttack
-from Attack.AttackParameters import Parameter as Param
-from Attack.AttackParameters import ParameterTypes
-from ID2TLib.SMBLib import smb_port
+import Attack.AttackParameters as atkParam
+import Attack.BaseAttack as BaseAttack
+import ID2TLib.SMBLib as SMBLib
 import ID2TLib.Utility as Util
 
 logging.getLogger("scapy.runtime").setLevel(logging.ERROR)
 
-
 # noinspection PyPep8
 
 
@@ -27,23 +25,25 @@ class MS17ScanAttack(BaseAttack.BaseAttack):
     def __init__(self):
         """
         Creates a new instance of the EternalBlue Exploit.
-
         """
         # Initialize attack
         super(MS17ScanAttack, self).__init__("MS17ScanAttack", "Injects a MS17 scan'",
-                                                 "Scanning/Probing")
+                                             "Scanning/Probing")
+
+        self.pkt_num = 0
+        self.path_attack_pcap = None
 
         # Define allowed parameters and their type
         self.supported_params.update({
-            Param.MAC_SOURCE: ParameterTypes.TYPE_MAC_ADDRESS,
-            Param.IP_SOURCE: ParameterTypes.TYPE_IP_ADDRESS,
-            Param.PORT_SOURCE: ParameterTypes.TYPE_PORT,
-            Param.MAC_DESTINATION: ParameterTypes.TYPE_MAC_ADDRESS,
-            Param.IP_DESTINATION: ParameterTypes.TYPE_IP_ADDRESS,
-            Param.PORT_DESTINATION: ParameterTypes.TYPE_PORT,
-            Param.INJECT_AT_TIMESTAMP: ParameterTypes.TYPE_FLOAT,
-            Param.INJECT_AFTER_PACKET: ParameterTypes.TYPE_PACKET_POSITION,
-            Param.PACKETS_PER_SECOND: ParameterTypes.TYPE_FLOAT
+            atkParam.Parameter.MAC_SOURCE: atkParam.ParameterTypes.TYPE_MAC_ADDRESS,
+            atkParam.Parameter.IP_SOURCE: atkParam.ParameterTypes.TYPE_IP_ADDRESS,
+            atkParam.Parameter.PORT_SOURCE: atkParam.ParameterTypes.TYPE_PORT,
+            atkParam.Parameter.MAC_DESTINATION: atkParam.ParameterTypes.TYPE_MAC_ADDRESS,
+            atkParam.Parameter.IP_DESTINATION: atkParam.ParameterTypes.TYPE_IP_ADDRESS,
+            atkParam.Parameter.PORT_DESTINATION: atkParam.ParameterTypes.TYPE_PORT,
+            atkParam.Parameter.INJECT_AT_TIMESTAMP: atkParam.ParameterTypes.TYPE_FLOAT,
+            atkParam.Parameter.INJECT_AFTER_PACKET: atkParam.ParameterTypes.TYPE_PACKET_POSITION,
+            atkParam.Parameter.PACKETS_PER_SECOND: atkParam.ParameterTypes.TYPE_FLOAT
         })
 
     def init_params(self):
@@ -59,58 +59,56 @@ class MS17ScanAttack(BaseAttack.BaseAttack):
         random_ip_address = self.statistics.get_random_ip_address()
         while random_ip_address == most_used_ip_address:
             random_ip_address = self.statistics.get_random_ip_address()
-        self.add_param_value(Param.IP_SOURCE, random_ip_address)
-        self.add_param_value(Param.MAC_SOURCE, self.statistics.get_mac_address(random_ip_address))
-        self.add_param_value(Param.PORT_SOURCE, randint(self.minDefaultPort, self.maxDefaultPort))
+        self.add_param_value(atkParam.Parameter.IP_SOURCE, random_ip_address)
+        self.add_param_value(atkParam.Parameter.MAC_SOURCE, self.statistics.get_mac_address(random_ip_address))
+        self.add_param_value(atkParam.Parameter.PORT_SOURCE, rnd.randint(self.minDefaultPort, self.maxDefaultPort))
 
         # Victim configuration
-        self.add_param_value(Param.IP_DESTINATION, most_used_ip_address)
+        self.add_param_value(atkParam.Parameter.IP_DESTINATION, most_used_ip_address)
         destination_mac = self.statistics.get_mac_address(most_used_ip_address)
         if isinstance(destination_mac, list) and len(destination_mac) == 0:
             destination_mac = self.generate_random_mac_address()
-        self.add_param_value(Param.MAC_DESTINATION, destination_mac)
-        self.add_param_value(Param.PORT_DESTINATION, smb_port)
+        self.add_param_value(atkParam.Parameter.MAC_DESTINATION, destination_mac)
+        self.add_param_value(atkParam.Parameter.PORT_DESTINATION, SMBLib.smb_port)
 
         # Attack configuration
-        self.add_param_value(Param.PACKETS_PER_SECOND,
+        self.add_param_value(atkParam.Parameter.PACKETS_PER_SECOND,
                              (self.statistics.get_pps_sent(most_used_ip_address) +
                               self.statistics.get_pps_received(most_used_ip_address)) / 2)
-        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(0, self.statistics.get_packet_count()))
+        self.add_param_value(atkParam.Parameter.INJECT_AFTER_PACKET, rnd.randint(0, self.statistics.get_packet_count()))
 
     def generate_attack_packets(self):
 
         # Timestamp
-        timestamp_next_pkt = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
-        pps = self.get_param_value(Param.PACKETS_PER_SECOND)
+        timestamp_next_pkt = self.get_param_value(atkParam.Parameter.INJECT_AT_TIMESTAMP)
+        pps = self.get_param_value(atkParam.Parameter.PACKETS_PER_SECOND)
 
         # calculate complement packet rates of BG traffic per interval
         complement_interval_pps = self.statistics.calculate_complement_packet_rates(pps)
 
         # Initialize parameters
         self.packets = []
-        mac_source = self.get_param_value(Param.MAC_SOURCE)
-        ip_source = self.get_param_value(Param.IP_SOURCE)
-        port_source = self.get_param_value(Param.PORT_SOURCE)
-        mac_destination = self.get_param_value(Param.MAC_DESTINATION)
-        ip_destination = self.get_param_value(Param.IP_DESTINATION)
-        port_destination = self.get_param_value(Param.PORT_DESTINATION)
+        mac_source = self.get_param_value(atkParam.Parameter.MAC_SOURCE)
+        ip_source = self.get_param_value(atkParam.Parameter.IP_SOURCE)
+        port_source = self.get_param_value(atkParam.Parameter.PORT_SOURCE)
+        mac_destination = self.get_param_value(atkParam.Parameter.MAC_DESTINATION)
+        ip_destination = self.get_param_value(atkParam.Parameter.IP_DESTINATION)
+        port_destination = self.get_param_value(atkParam.Parameter.PORT_DESTINATION)
 
         # Check ip.src == ip.dst
         self.ip_src_dst_equal_check(ip_source, ip_destination)
 
-        self.path_attack_pcap = None
-
         # Set TTL based on TTL distribution of IP address
         source_ttl_dist = self.statistics.get_ttl_distribution(ip_source)
         if len(source_ttl_dist) > 0:
-            source_ttl_prob_dict = Lea.fromValFreqsDict(source_ttl_dist)
+            source_ttl_prob_dict = lea.Lea.fromValFreqsDict(source_ttl_dist)
             source_ttl_value = source_ttl_prob_dict.random()
         else:
             source_ttl_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(ttlValue)"))
 
         destination_ttl_dist = self.statistics.get_ttl_distribution(ip_destination)
         if len(destination_ttl_dist) > 0:
-            destination_ttl_prob_dict = Lea.fromValFreqsDict(destination_ttl_dist)
+            destination_ttl_prob_dict = lea.Lea.fromValFreqsDict(destination_ttl_dist)
             destination_ttl_value = destination_ttl_prob_dict.random()
         else:
             destination_ttl_value = Util.handle_most_used_outputs(
@@ -119,17 +117,17 @@ class MS17ScanAttack(BaseAttack.BaseAttack):
         # Set Window Size based on Window Size distribution of IP address
         source_win_dist = self.statistics.get_win_distribution(ip_source)
         if len(source_win_dist) > 0:
-            source_win_prob_dict = Lea.fromValFreqsDict(source_win_dist)
+            source_win_prob_dict = lea.Lea.fromValFreqsDict(source_win_dist)
         else:
             source_win_dist = self.statistics.get_win_distribution(self.statistics.get_most_used_ip_address())
-            source_win_prob_dict = Lea.fromValFreqsDict(source_win_dist)
+            source_win_prob_dict = lea.Lea.fromValFreqsDict(source_win_dist)
 
         destination_win_dist = self.statistics.get_win_distribution(ip_destination)
         if len(destination_win_dist) > 0:
-            destination_win_prob_dict = Lea.fromValFreqsDict(destination_win_dist)
+            destination_win_prob_dict = lea.Lea.fromValFreqsDict(destination_win_dist)
         else:
             destination_win_dist = self.statistics.get_win_distribution(self.statistics.get_most_used_ip_address())
-            destination_win_prob_dict = Lea.fromValFreqsDict(destination_win_dist)
+            destination_win_prob_dict = lea.Lea.fromValFreqsDict(destination_win_dist)
 
         # Set MSS (Maximum Segment Size) based on MSS distribution of IP address
         mss_value = Util.handle_most_used_outputs(self.statistics.process_db_query("most_used(mssValue)"))
@@ -139,20 +137,20 @@ class MS17ScanAttack(BaseAttack.BaseAttack):
         # Scan (MS17)
         # Read Win7_eternalblue_scan pcap file
         orig_ip_dst = None
-        exploit_raw_packets = RawPcapReader(self.template_scan_pcap_path)
+        exploit_raw_packets = scapy.utils.RawPcapReader(self.template_scan_pcap_path)
         inter_arrival_times = self.get_inter_arrival_time(exploit_raw_packets)
         exploit_raw_packets.close()
-        exploit_raw_packets = RawPcapReader(self.template_scan_pcap_path)
+        exploit_raw_packets = scapy.utils.RawPcapReader(self.template_scan_pcap_path)
 
         source_origin_wins, destination_origin_wins = {}, {}
 
         for self.pkt_num, pkt in enumerate(exploit_raw_packets):
-            eth_frame = Ether(pkt[0])
+            eth_frame = inet.Ether(pkt[0])
             ip_pkt = eth_frame.payload
             tcp_pkt = ip_pkt.payload
 
             if self.pkt_num == 0:
-                if tcp_pkt.getfieldval("dport") == smb_port:
+                if tcp_pkt.getfieldval("dport") == SMBLib.smb_port:
                     orig_ip_dst = ip_pkt.getfieldval("dst")  # victim IP
 
             # Request
@@ -167,13 +165,13 @@ class MS17ScanAttack(BaseAttack.BaseAttack):
                 # TCP
                 tcp_pkt.setfieldval("sport", port_source)
                 tcp_pkt.setfieldval("dport", port_destination)
-                ## Window Size (mapping)
+                # Window Size (mapping)
                 source_origin_win = tcp_pkt.getfieldval("window")
                 if source_origin_win not in source_origin_wins:
                     source_origin_wins[source_origin_win] = source_win_prob_dict.random()
                 new_win = source_origin_wins[source_origin_win]
                 tcp_pkt.setfieldval("window", new_win)
-                ## MSS
+                # MSS
                 tcp_options = tcp_pkt.getfieldval("options")
                 if tcp_options:
                     if tcp_options[0][0] == "MSS":
@@ -198,13 +196,13 @@ class MS17ScanAttack(BaseAttack.BaseAttack):
                 # TCP
                 tcp_pkt.setfieldval("dport", port_source)
                 tcp_pkt.setfieldval("sport", port_destination)
-                ## Window Size
+                # Window Size
                 destination_origin_win = tcp_pkt.getfieldval("window")
                 if destination_origin_win not in destination_origin_wins:
                     destination_origin_wins[destination_origin_win] = destination_win_prob_dict.random()
                 new_win = destination_origin_wins[destination_origin_win]
                 tcp_pkt.setfieldval("window", new_win)
-                ## MSS
+                # MSS
                 tcp_options = tcp_pkt.getfieldval("options")
                 if tcp_options:
                     if tcp_options[0][0] == "MSS":