# Created by Aidmar

import logging
import math
from operator import itemgetter
import operator
from random import randint, uniform

from lea import Lea

from Attack import BaseAttack
from Attack.AttackParameters import Parameter as Param
from Attack.AttackParameters import ParameterTypes

logging.getLogger("scapy.runtime").setLevel(logging.ERROR)
# noinspection PyPep8
from scapy.utils import RawPcapReader
from scapy.layers.inet import IP, Ether, TCP, RandShort

class EternalBlueExploit(BaseAttack.BaseAttack):
    """
    EternalBlue attack that replays a scan and an exploit from template pcaps.

    The packets of ``Win7_eternalblue_scan.pcap`` and ``Win7_eternalblue_exploit.pcap`` are re-addressed to the
    configured attacker (source) and victim (destination) endpoints, get fresh TCP ports and new timestamps
    derived from the background traffic, and are then written to the attack pcap.
    """
    # Metasploit default packet rate
    maxDefaultPPS = 100
    minDefaultPPS = 5
    # SMB port
    smb_port = 445
    # Metasploit experiments show this range of source ports
    minDefaultPort = 30000
    maxDefaultPort = 50000
    # Destination port of the last exploit conversation, which the victim opens towards the attacker
    last_conn_dst_port = 4444

    def __init__(self, statistics, pcap_file_path):
        """
        Creates a new instance of the EternalBlue Exploit.

        :param statistics: A reference to the statistics class.
        :param pcap_file_path: Path of the input pcap; accepted for interface compatibility, not used here.
        """
        # Initialize attack
        # NOTE(review): "Resource Exhaustion" looks like an unusual category for a remote exploit - confirm.
        super(EternalBlueExploit, self).__init__(statistics, "EternalBlue Exploit", "Injects an EternalBlue exploit",
                                                 "Resource Exhaustion")

        # Define allowed parameters and their type
        self.supported_params = {
            Param.MAC_SOURCE: ParameterTypes.TYPE_MAC_ADDRESS,
            Param.IP_SOURCE: ParameterTypes.TYPE_IP_ADDRESS,
            Param.MAC_DESTINATION: ParameterTypes.TYPE_MAC_ADDRESS,
            Param.IP_DESTINATION: ParameterTypes.TYPE_IP_ADDRESS,
            Param.INJECT_AT_TIMESTAMP: ParameterTypes.TYPE_FLOAT,
            Param.INJECT_AFTER_PACKET: ParameterTypes.TYPE_PACKET_POSITION,
            Param.PACKETS_PER_SECOND: ParameterTypes.TYPE_FLOAT
        }

        # PARAMETERS: initialize with default values
        # (values are overwritten if user specifies them)
        most_used_ip_address = self.statistics.get_most_used_ip_address()
        # Several equally frequent addresses may be returned as a list; take the first one.
        if isinstance(most_used_ip_address, list):
            most_used_ip_address = most_used_ip_address[0]
        self.add_param_value(Param.IP_SOURCE, most_used_ip_address)
        self.add_param_value(Param.MAC_SOURCE, self.statistics.get_mac_address(most_used_ip_address))
        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(0, self.statistics.get_packet_count()))
        self.add_param_value(Param.PACKETS_PER_SECOND, self.maxDefaultPPS)

        # Victim configuration
        # TODO: confirm that ip.dst runs a Windows OS
        random_ip_address = self.statistics.get_random_ip_address()
        self.add_param_value(Param.IP_DESTINATION, random_ip_address)

        destination_mac = self.statistics.get_mac_address(random_ip_address)
        # An empty list means no MAC is known for this IP; fall back to a random MAC.
        if isinstance(destination_mac, list) and len(destination_mac) == 0:
            destination_mac = self.generate_random_mac_address()
        self.add_param_value(Param.MAC_DESTINATION, destination_mac)

    @staticmethod
    def _get_interval_pps(complement_interval_pps, timestamp):
        """
        Gets the complement packet rate (pps) of the interval a timestamp falls into.

        :param complement_interval_pps: Rows whose element 0 is an interval timestamp and element 1 its packet
                                        rate (presumably sorted by timestamp - verify against the statistics).
        :param timestamp: The timestamp to look up.
        :return: The rate of the first row whose timestamp is >= timestamp, the rate of the last row if timestamp
                 lies beyond all rows, or None if there are no rows.
        """
        for row in complement_interval_pps:
            if timestamp <= row[0]:
                return row[1]
        if len(complement_interval_pps) == 0:
            return None
        # timestamp > capture max timestamp: reuse the rate of the last interval
        return complement_interval_pps[-1][1]

    def _next_request_timestamp(self, timestamp, randomdelay, complement_interval_pps):
        """
        Calculates the timestamp following a request packet.

        The step is drawn uniformly from [1 / pps, maxdelay], where pps is the interval's complement packet rate
        (at least minDefaultPPS) and maxdelay is drawn from randomdelay.

        :return: Timestamp to be used for the next packet.
        """
        maxdelay = randomdelay.random()
        interval_pps = self._get_interval_pps(complement_interval_pps, timestamp)
        pps = self.minDefaultPPS if interval_pps is None else max(interval_pps, self.minDefaultPPS)
        return timestamp + uniform(1 / pps, maxdelay)

    @staticmethod
    def _rewrite_packet(eth_frame, sender, receiver, port_field, port):
        """
        Rewrites the Ether/IP/TCP headers of a dissected template packet in place.

        :param eth_frame: Ether frame dissected from a template pcap, carrying IP and TCP layers.
        :param sender: (MAC, IP) tuple written to the source fields.
        :param receiver: (MAC, IP) tuple written to the destination fields.
        :param port_field: Name of the TCP field ("sport" or "dport") that receives port.
        :param port: The port number to set.
        """
        ip_pkt = eth_frame.payload
        tcp_pkt = ip_pkt.payload
        eth_frame.setfieldval("src", sender[0])
        eth_frame.setfieldval("dst", receiver[0])
        ip_pkt.setfieldval("src", sender[1])
        ip_pkt.setfieldval("dst", receiver[1])
        tcp_pkt.setfieldval(port_field, port)
        # The checksums dissected from the template no longer match the new addresses/ports;
        # removing them makes scapy recompute both when the packet is built.
        ip_pkt.delfieldval("chksum")
        tcp_pkt.delfieldval("chksum")

    @staticmethod
    def _read_raw_packets(path):
        """
        Reads all records of a pcap file and closes the file afterwards.

        :param path: Path of the pcap file.
        :return: List of the records yielded by RawPcapReader; element 0 of each record is the raw frame.
        """
        reader = RawPcapReader(path)
        try:
            return list(reader)
        finally:
            reader.close()

    @staticmethod
    def _group_into_conversations(raw_packets):
        """
        Groups raw Ether/IP/TCP packets into conversations.

        A packet joins an existing conversation if its (ip.src, port.src, ip.dst, port.dst) tuple or the reversed
        tuple was seen before; otherwise it starts a new conversation keyed by its own tuple.

        :param raw_packets: Records as returned by RawPcapReader.
        :return: Tuple (conversations, order): conversations maps each key to the list of its raw packets,
                 order lists the keys in the order their first packet appeared.
        """
        conversations = {}
        order = []
        for raw in raw_packets:
            ip_pkt = Ether(raw[0]).payload
            tcp_pkt = ip_pkt.payload
            forward = (ip_pkt.getfieldval("src"), tcp_pkt.getfieldval("sport"),
                       ip_pkt.getfieldval("dst"), tcp_pkt.getfieldval("dport"))
            backward = (forward[2], forward[3], forward[0], forward[1])
            if forward in conversations:
                conversations[forward].append(raw)
            elif backward in conversations:
                conversations[backward].append(raw)
            else:
                conversations[forward] = [raw]
                order.append(forward)
        return conversations, order

    def generate_attack_pcap(self):
        """
        Creates the attack packets (EternalBlue scan followed by the exploit) and writes them to a pcap file.

        Request packets are stamped with the current timestamp, which is then advanced according to the complement
        packet rate of the background traffic; reply packets are stamped after advancing the timestamp by a random
        value in [reply delay, 2 * reply delay] of the victim.

        :return: Tuple (number of generated packets, path of the attack pcap or None if no packet was generated).
        """
        timestamp_next_pkt = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
        # TODO: find a better packet rate
        pps = self.get_param_value(Param.PACKETS_PER_SECOND)
        # Weighted choice of the maximum delay between a request and the next packet
        randomdelay = Lea.fromValFreqsDict({1 / pps: 70, 2 / pps: 30, 5 / pps: 15, 10 / pps: 3})
        # Complement packet rates of the background traffic per interval
        complement_interval_pps = self.statistics.calculate_complement_packet_rates(pps)

        mac_source = self.get_param_value(Param.MAC_SOURCE)
        ip_source = self.get_param_value(Param.IP_SOURCE)
        mac_destination = self.get_param_value(Param.MAC_DESTINATION)
        ip_destination = self.get_param_value(Param.IP_DESTINATION)

        if ip_source == ip_destination:
            print("\nERROR: Invalid IP addresses; source IP is the same as destination IP: " + str(ip_source) + ".")
            import sys
            # Non-zero exit status: this is an error, not a successful run.
            sys.exit(1)

        attacker = (mac_source, ip_source)
        victim = (mac_destination, ip_destination)
        reply_delay = self.get_reply_delay(ip_destination)
        packets = []

        def place(eth_frame, sender, receiver, is_request, port, timestamp):
            """
            Re-addresses eth_frame from sender to receiver, timestamps it and appends it to packets.

            Requests get their source port replaced and are stamped before the timestamp advances;
            replies get their destination port replaced and are stamped after it advances.

            :return: The timestamp to use for the next packet.
            """
            self._rewrite_packet(eth_frame, sender, receiver, "sport" if is_request else "dport", port)
            if is_request:
                eth_frame.time = timestamp
                timestamp = self._next_request_timestamp(timestamp, randomdelay, complement_interval_pps)
            else:
                timestamp = timestamp + uniform(reply_delay, 2 * reply_delay)
                eth_frame.time = timestamp
            # eth_frame was modified in place; stacking it again with "/" would duplicate its layers.
            packets.append(eth_frame)
            return timestamp

        def replay_attacker_initiated(raw_packets, orig_ip_dst, port, timestamp):
            """
            Replays packets of an exchange started by the attacker.

            Packets sent to the template's victim IP become attacker -> victim requests, all others become
            victim -> attacker replies. The template's victim IP is taken from the first packet if it targets
            the SMB port; otherwise the previously known value (orig_ip_dst) is kept.

            :return: Tuple (template victim IP, timestamp to use for the next packet).
            """
            for pkt_num, raw in enumerate(raw_packets):
                eth_frame = Ether(raw[0])
                ip_pkt = eth_frame.payload
                if pkt_num == 0 and ip_pkt.payload.getfieldval("dport") == self.smb_port:
                    orig_ip_dst = ip_pkt.getfieldval("dst")
                if ip_pkt.getfieldval("dst") == orig_ip_dst:
                    timestamp = place(eth_frame, attacker, victim, True, port, timestamp)
                else:
                    timestamp = place(eth_frame, victim, attacker, False, port, timestamp)
            return orig_ip_dst, timestamp

        # NOTE(review): the template pcaps are opened relative to the current working directory.

        # Scan (MS17) for EternalBlue, replayed from the Win7_eternalblue_scan template
        port_source = randint(self.minDefaultPort, self.maxDefaultPort)  # experiments show this range of ports
        orig_ip_dst, timestamp_next_pkt = replay_attacker_initiated(
            self._read_raw_packets("Win7_eternalblue_scan.pcap"), None, port_source, timestamp_next_pkt)

        # EternalBlue exploit, replayed conversation by conversation from the Win7_eternalblue_exploit template
        conversations, conversation_order = self._group_into_conversations(
            self._read_raw_packets("Win7_eternalblue_exploit.pcap"))
        port_source = randint(self.minDefaultPort, self.maxDefaultPort)  # experiments show this range of ports
        last_index = len(conversation_order) - 1
        for conv_index, conv in enumerate(conversation_order):
            conv_pkts = conversations[conv]
            if conv_index != last_index:
                # Every attacker-initiated conversation gets its own source port
                port_source += 2
                orig_ip_dst, timestamp_next_pkt = replay_attacker_initiated(
                    conv_pkts, orig_ip_dst, port_source, timestamp_next_pkt)
            else:
                # Last conversation: the victim starts a connection with the attacker, so the
                # request direction is victim -> attacker (identified by the destination port).
                port_source = randint(self.minDefaultPort, self.maxDefaultPort)
                for raw in conv_pkts:
                    eth_frame = Ether(raw[0])
                    if eth_frame.payload.payload.getfieldval("dport") == self.last_conn_dst_port:
                        timestamp_next_pkt = place(eth_frame, victim, attacker, True, port_source,
                                                   timestamp_next_pkt)
                    else:
                        timestamp_next_pkt = place(eth_frame, attacker, victim, False, port_source,
                                                   timestamp_next_pkt)

        path_attack_pcap = None
        if packets:
            packets.sort(key=lambda pkt: pkt.time)
            # Store timestamps of the first and last packet (for the attack label)
            self.attack_start_utime = packets[0].time
            self.attack_end_utime = packets[-1].time
            path_attack_pcap = self.write_attack_pcap(packets, True, path_attack_pcap)

        # Number of generated packets and the path of the written pcap
        return len(packets), path_attack_pcap