import logging
from contextlib import closing
from random import randint, uniform

from lea import Lea
from scapy.utils import RawPcapReader
from scapy.layers.inet import Ether

from Attack import BaseAttack
from Attack.AttackParameters import Parameter as Param
from Attack.AttackParameters import ParameterTypes
from ID2TLib.Utility import update_timestamp, get_interval_pps
from ID2TLib.SMBLib import smb_port

logging.getLogger("scapy.runtime").setLevel(logging.ERROR)
# noinspection PyPep8


class EternalBlueExploit(BaseAttack.BaseAttack):
    """
    Attack that replays an EternalBlue scan followed by the EternalBlue exploit.

    The packets are taken from two template pcaps and rewritten so that they fit the
    attacker/victim configured through the attack parameters.
    """
    template_scan_pcap_path = "resources/Win7_eternalblue_scan.pcap"
    template_attack_pcap_path = "resources/Win7_eternalblue_exploit.pcap"
    # Empirical values from Metasploit experiments
    minDefaultPort = 30000
    maxDefaultPort = 50000
    # Destination port of the last template conversation (opened by the victim towards the attacker)
    last_conn_dst_port = 4444

    def __init__(self):
        """
        Creates a new instance of the EternalBlue Exploit.

        """
        # Initialize attack
        super(EternalBlueExploit, self).__init__("EternalBlue Exploit", "Injects an EternalBlue exploit'",
                                        "Privilege elevation")

        # Define allowed parameters and their type
        self.supported_params = {
            Param.MAC_SOURCE: ParameterTypes.TYPE_MAC_ADDRESS,
            Param.IP_SOURCE: ParameterTypes.TYPE_IP_ADDRESS,
            Param.PORT_SOURCE: ParameterTypes.TYPE_PORT,
            Param.MAC_DESTINATION: ParameterTypes.TYPE_MAC_ADDRESS,
            Param.IP_DESTINATION: ParameterTypes.TYPE_IP_ADDRESS,
            Param.PORT_DESTINATION: ParameterTypes.TYPE_PORT,
            Param.INJECT_AT_TIMESTAMP: ParameterTypes.TYPE_FLOAT,
            Param.INJECT_AFTER_PACKET: ParameterTypes.TYPE_PACKET_POSITION,
            Param.PACKETS_PER_SECOND: ParameterTypes.TYPE_FLOAT
        }

    def init_params(self):
        """
        Initialize the parameters of this attack with default values.

        The defaults are derived from the statistics object of this attack (self.statistics);
        values are overwritten if the user specifies them.
        """
        # Attacker configuration
        most_used_ip_address = self.statistics.get_most_used_ip_address()
        if isinstance(most_used_ip_address, list):
            most_used_ip_address = most_used_ip_address[0]
        random_ip_address = self.statistics.get_random_ip_address()
        self.add_param_value(Param.IP_SOURCE, random_ip_address)
        self.add_param_value(Param.MAC_SOURCE, self._mac_address_or_random(random_ip_address))
        self.add_param_value(Param.PORT_SOURCE, randint(self.minDefaultPort, self.maxDefaultPort))

        # Victim configuration
        self.add_param_value(Param.IP_DESTINATION, most_used_ip_address)
        self.add_param_value(Param.MAC_DESTINATION, self._mac_address_or_random(most_used_ip_address))
        self.add_param_value(Param.PORT_DESTINATION, smb_port)

        # Attack configuration: mean of the victim's sent and received packet rates
        self.add_param_value(Param.PACKETS_PER_SECOND,
                             (self.statistics.get_pps_sent(most_used_ip_address) +
                              self.statistics.get_pps_received(most_used_ip_address)) / 2)
        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(0, self.statistics.get_packet_count()))

    def _mac_address_or_random(self, ip_address):
        """
        Look up the MAC address of an IP address, falling back to a random MAC address.

        :param ip_address: The IP address to look up in the statistics.
        :return: The MAC address from the statistics, or a generated one if the lookup
                 returned an empty list.
        """
        mac_address = self.statistics.get_mac_address(ip_address)
        if isinstance(mac_address, list) and len(mac_address) == 0:
            mac_address = self.generate_random_mac_address()
        return mac_address

    def _sample_ttl(self, ip_address):
        """
        Pick a TTL value for an IP address.

        :param ip_address: The IP address whose TTL distribution is queried.
        :return: A value drawn from the TTL distribution of the IP address, or the result of
                 the "most_used(ttlValue)" query if that distribution is empty.
        """
        ttl_dist = self.statistics.get_ttl_distribution(ip_address)
        if len(ttl_dist) > 0:
            return Lea.fromValFreqsDict(ttl_dist).random()
        return self.statistics.process_db_query("most_used(ttlValue)")

    def _window_size_distribution(self, ip_address):
        """
        Build the window size distribution used to pick window sizes for an IP address.

        :param ip_address: The IP address whose window size distribution is queried.
        :return: Lea distribution built from the window sizes of the IP address, or from those
                 of the most used IP address if the former is empty.
        """
        win_dist = self.statistics.get_win_distribution(ip_address)
        if len(win_dist) == 0:
            fallback_ip_address = self.statistics.get_most_used_ip_address()
            # Same normalization as in init_params: the query may return several addresses
            if isinstance(fallback_ip_address, list):
                fallback_ip_address = fallback_ip_address[0]
            win_dist = self.statistics.get_win_distribution(fallback_ip_address)
        return Lea.fromValFreqsDict(win_dist)

    @staticmethod
    def _rewrite_headers(eth_frame, sender, receiver, mss_value, sport=None, dport=None):
        """
        Rewrite the Ethernet/IP/TCP headers of a template packet.

        :param eth_frame: The dissected template packet (Ether layer).
        :param sender: Endpoint dict ("mac", "ip", "ttl", "win_dist", "win_map") of the packet's sender.
        :param receiver: Endpoint dict of the packet's receiver.
        :param mss_value: Value written into a leading MSS TCP option, if the packet has one.
        :param sport: New TCP source port; None keeps the port of the template.
        :param dport: New TCP destination port; None keeps the port of the template.
        :return: The rewritten packet.
        """
        ip_pkt = eth_frame.payload
        tcp_pkt = ip_pkt.payload

        # Ether
        eth_frame.setfieldval("src", sender["mac"])
        eth_frame.setfieldval("dst", receiver["mac"])
        # IP
        ip_pkt.setfieldval("src", sender["ip"])
        ip_pkt.setfieldval("dst", receiver["ip"])
        ip_pkt.setfieldval("ttl", sender["ttl"])
        # TCP
        if sport is not None:
            tcp_pkt.setfieldval("sport", sport)
        if dport is not None:
            tcp_pkt.setfieldval("dport", dport)

        # Window size: every distinct template window size is mapped to one sampled value,
        # so equal template values stay equal after the rewrite
        origin_win = tcp_pkt.getfieldval("window")
        if origin_win not in sender["win_map"]:
            sender["win_map"][origin_win] = sender["win_dist"].random()
        tcp_pkt.setfieldval("window", sender["win_map"][origin_win])

        # MSS
        tcp_options = tcp_pkt.getfieldval("options")
        if tcp_options and tcp_options[0][0] == "MSS":
            tcp_options[0] = ("MSS", mss_value)
            tcp_pkt.setfieldval("options", tcp_options)

        # NOTE(review): eth_frame already carries ip_pkt/tcp_pkt as its payload; stacking them
        # again with '/' is kept from the original code but may append copies of the inner
        # layers - confirm whether this is intended.
        return eth_frame / ip_pkt / tcp_pkt

    @staticmethod
    def _stamp_packet(new_pkt, timestamp, is_request, complement_interval_pps, inter_arrival_time):
        """
        Assign the timestamp of a packet and compute the timestamp of the following packet.

        A request is stamped with the current timestamp before the timestamp is advanced;
        a reply is stamped with the advanced timestamp.

        :param new_pkt: The packet to stamp.
        :param timestamp: The current timestamp.
        :param is_request: True if the packet is a request, False if it is a reply.
        :param complement_interval_pps: The complement packet rates of the background traffic.
        :param inter_arrival_time: Inter-arrival time of this packet in the template pcap.
        :return: The timestamp to use for the next packet.
        """
        if is_request:
            new_pkt.time = timestamp
        pps = max(get_interval_pps(complement_interval_pps, timestamp), 10)
        timestamp = update_timestamp(timestamp, pps) + inter_arrival_time
        if not is_request:
            new_pkt.time = timestamp
        return timestamp

    def generate_attack_pcap(self):
        """
        Create the attack packets from the scan and exploit template pcaps and write them
        to an attack pcap.

        Every template packet gets new MAC/IP addresses, TTL, TCP ports, window size, MSS
        option and timestamp. The timestamps of the earliest and latest generated packet are
        stored in self.attack_start_utime and self.attack_end_utime.

        :return: Tuple (number of generated packets, path of the attack pcap). The path is
                 None if no packet was generated.
        """
        # Timestamp
        timestamp_next_pkt = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
        pps = self.get_param_value(Param.PACKETS_PER_SECOND)

        # calculate complement packet rates of BG traffic per interval
        complement_interval_pps = self.statistics.calculate_complement_packet_rates(pps)

        # Initialize parameters
        packets = []
        mac_source = self.get_param_value(Param.MAC_SOURCE)
        ip_source = self.get_param_value(Param.IP_SOURCE)
        port_source = self.get_param_value(Param.PORT_SOURCE)
        mac_destination = self.get_param_value(Param.MAC_DESTINATION)
        ip_destination = self.get_param_value(Param.IP_DESTINATION)
        port_destination = self.get_param_value(Param.PORT_DESTINATION)

        # Check ip.src == ip.dst
        self.ip_src_dst_equal_check(ip_source, ip_destination)

        path_attack_pcap = None

        # Set TTL based on TTL distribution of IP address
        source_ttl_value = self._sample_ttl(ip_source)
        destination_ttl_value = self._sample_ttl(ip_destination)

        # Set Window Size based on Window Size distribution of IP address
        source_win_prob_dict = self._window_size_distribution(ip_source)
        destination_win_prob_dict = self._window_size_distribution(ip_destination)

        # Set MSS (Maximum Segment Size) to the most used MSS value
        mss_value = self.statistics.process_db_query("most_used(mssValue)")
        if not mss_value:
            mss_value = 1465

        # The window size mappings ("win_map") are shared by the scan and the exploit packets
        attacker = {"mac": mac_source, "ip": ip_source, "ttl": source_ttl_value,
                    "win_dist": source_win_prob_dict, "win_map": {}}
        victim = {"mac": mac_destination, "ip": ip_destination, "ttl": destination_ttl_value,
                  "win_dist": destination_win_prob_dict, "win_map": {}}

        # Scan (MS17) for EternalBlue
        # Read Win7_eternalblue_scan pcap file (once for the timing, once for the packets)
        orig_ip_dst = None
        with closing(RawPcapReader(self.template_scan_pcap_path)) as scan_reader:
            inter_arrival_times = self.get_inter_arrival_time(scan_reader)

        with closing(RawPcapReader(self.template_scan_pcap_path)) as scan_reader:
            for pkt_num, pkt in enumerate(scan_reader):
                eth_frame = Ether(pkt[0])
                ip_pkt = eth_frame.payload
                tcp_pkt = ip_pkt.payload

                if pkt_num == 0 and tcp_pkt.getfieldval("dport") == smb_port:
                    orig_ip_dst = ip_pkt.getfieldval("dst")  # victim IP

                # Request: packet addressed to the victim of the template; everything else is a reply
                is_request = ip_pkt.getfieldval("dst") == orig_ip_dst
                if is_request:
                    new_pkt = self._rewrite_headers(eth_frame, attacker, victim, mss_value,
                                                    sport=port_source, dport=port_destination)
                else:
                    new_pkt = self._rewrite_headers(eth_frame, victim, attacker, mss_value,
                                                    sport=port_destination, dport=port_source)

                timestamp_next_pkt = self._stamp_packet(new_pkt, timestamp_next_pkt, is_request,
                                                        complement_interval_pps,
                                                        inter_arrival_times[pkt_num])
                packets.append(new_pkt)

        # Inject EternalBlue exploit packets
        # Read Win7_eternalblue_exploit pcap file
        port_source = randint(self.minDefaultPort, self.maxDefaultPort)  # experiments show this range of ports
        # conversations = {(ip.src, ip.dst, port.src, port.dst): packets}
        with closing(RawPcapReader(self.template_attack_pcap_path)) as exploit_reader:
            conversations, orderList_conversations = self.packetsToConvs(exploit_reader)

        last_conv_index = len(orderList_conversations) - 1
        conv_start_timestamp = timestamp_next_pkt
        for conv_index, conv in enumerate(orderList_conversations):
            # the distance between the starts of the conversations
            conv_start_timestamp = conv_start_timestamp + uniform(0.001, 0.01)
            timestamp_next_pkt = conv_start_timestamp

            conv_pkts = conversations[conv]
            inter_arrival_times = self.get_inter_arrival_time(conv_pkts)
            is_last_conv = conv_index == last_conv_index

            if conv_index == last_conv_index - 1:
                # The second-to-last conversation starts after the last packet generated so far
                timestamp_next_pkt = packets[-1].time + uniform(0.001, 0.01)

            if is_last_conv:
                # Last conversation where the victim starts a connection with the attacker
                timestamp_next_pkt = packets[-1].time + uniform(0.001, 0.01)
                port_source = randint(self.minDefaultPort, self.maxDefaultPort)
            else:
                port_source += 2

            for pkt_num, pkt in enumerate(conv_pkts):
                eth_frame = Ether(pkt[0])
                ip_pkt = eth_frame.payload
                tcp_pkt = ip_pkt.payload

                if is_last_conv:
                    # Request: victim -> attacker, recognized by the fixed destination port.
                    # The fixed port (4444) of the template is kept; only the other port is replaced.
                    is_request = tcp_pkt.getfieldval("dport") == self.last_conn_dst_port
                    if is_request:
                        new_pkt = self._rewrite_headers(eth_frame, victim, attacker, mss_value,
                                                        sport=port_source)
                    else:
                        new_pkt = self._rewrite_headers(eth_frame, attacker, victim, mss_value,
                                                        dport=port_source)
                else:
                    if pkt_num == 0 and tcp_pkt.getfieldval("dport") == smb_port:
                        orig_ip_dst = ip_pkt.getfieldval("dst")

                    # Request: packet addressed to the victim of the template
                    is_request = ip_pkt.getfieldval("dst") == orig_ip_dst
                    if is_request:
                        new_pkt = self._rewrite_headers(eth_frame, attacker, victim, mss_value,
                                                        sport=port_source, dport=port_destination)
                    else:
                        new_pkt = self._rewrite_headers(eth_frame, victim, attacker, mss_value,
                                                        sport=port_destination, dport=port_source)

                timestamp_next_pkt = self._stamp_packet(new_pkt, timestamp_next_pkt, is_request,
                                                        complement_interval_pps,
                                                        inter_arrival_times[pkt_num])
                packets.append(new_pkt)

        if len(packets) > 0:
            # Conversations overlap in time, so sort before reading the first/last timestamp
            packets = sorted(packets, key=lambda pkt: pkt.time)
            # Store timestamps of first and last packet (for attack label)
            self.attack_start_utime = packets[0].time
            self.attack_end_utime = packets[-1].time
            path_attack_pcap = self.write_attack_pcap(packets, True, path_attack_pcap)

        # return the total number of generated packets and the path of the attack pcap
        return len(packets), path_attack_pcap