from enum import Enum
from random import randint, randrange, choice, uniform
from collections import deque
from scipy.stats import gamma
from lea import Lea
from datetime import datetime
import os

from Attack import BaseAttack
from Attack.AttackParameters import Parameter as Param
from Attack.AttackParameters import ParameterTypes
# from ID2TLib import PcapFile
# from ID2TLib.PcapFile import PcapFile

class MessageType(Enum):
    """
    Enumeration of the botnet message types known to this module.
    """

    TIMEOUT = 3
    SALITY_NL_REQUEST = 101
    SALITY_NL_REPLY = 102
    SALITY_HELLO = 103
    SALITY_HELLO_REPLY = 104

    def is_request(mtype):
        """
        Tells whether the given message type is one of the request types.

        :param mtype: the message type to check
        :return: True for SALITY_HELLO and SALITY_NL_REQUEST, False otherwise
        """
        request_types = {MessageType.SALITY_NL_REQUEST, MessageType.SALITY_HELLO}
        return mtype in request_types

    def is_response(mtype):
        """
        Tells whether the given message type is one of the reply types.

        :param mtype: the message type to check
        :return: True for SALITY_HELLO_REPLY and SALITY_NL_REPLY, False otherwise
        """
        response_types = {MessageType.SALITY_NL_REPLY, MessageType.SALITY_HELLO_REPLY}
        return mtype in response_types

class Message():
    """
    Defines a compact message type that contains all necessary information.
    """

    # default value of line_no; presumably marks a message without a known line in the input file
    INVALID_LINENO = -1

    def __init__(self, msg_id: int, src, dst, type_: MessageType, time: float, refer_msg_id: int=-1, line_no = INVALID_LINENO):
        """
        Constructs a message with the given parameters.

        :param msg_id: the ID of the message
        :param src: something identifying the source, e.g. ID or configuration
        :param dst: something identifying the destination, e.g. ID or configuration
        :param type_: the type of the message
        :param time: the timestamp of the message
        :param refer_msg_id: the ID this message is a request for or reply to. -1 if there is no related message.
        :param line_no: The line number this message appeared in the original file
        """
        self.msg_id = msg_id
        self.src = src
        self.dst = dst
        self.type = type_
        self.time = time
        self.refer_msg_id = refer_msg_id
        # if similar fields to line_no should be added consider a separate class
        self.line_no = line_no

    def __str__(self):
        """
        Returns a one-line, human-readable summary of the message (ID, time, endpoints, type and referred ID).
        """
        return "{0}. at {1}: {2}-->{3}, {4}, refer:{5}".format(self.msg_id, self.time, self.src, self.dst,
                                                               self.type, self.refer_msg_id)


from ID2TLib import FileUtils, Generator
from ID2TLib.IPv4 import IPAddress
from ID2TLib.PcapAddressOperations import PcapAddressOperations
from ID2TLib.CommunicationProcessor import CommunicationProcessor
from ID2TLib.Botnet.MessageMapping import MessageMapping
from ID2TLib.PcapFile import PcapFile
from ID2TLib.Statistics import Statistics


class MembersMgmtCommAttack(BaseAttack.BaseAttack):
    def __init__(self):
        """
        Sets up the Membership Management Communication attack.

        Hands three descriptive strings to the base class (presumably attack name, description and
        attack type), declares the parameters this attack supports together with their types, and
        prepares a lookup table from MessageType values to MessageType members.
        """
        super(MembersMgmtCommAttack, self).__init__(
            "Membership Management Communication Attack (MembersMgmtCommAttack)",
            "Injects Membership Management Communication",
            "Botnet communication")

        # Supported parameters and their types, collected group by group (insertion order is kept)
        params = {}

        # general attack parameters
        params[Param.INJECT_AT_TIMESTAMP] = ParameterTypes.TYPE_FLOAT
        params[Param.INJECT_AFTER_PACKET] = ParameterTypes.TYPE_PACKET_POSITION
        params[Param.PACKETS_PER_SECOND] = ParameterTypes.TYPE_FLOAT
        params[Param.PACKETS_LIMIT] = ParameterTypes.TYPE_INTEGER_POSITIVE
        params[Param.ATTACK_DURATION] = ParameterTypes.TYPE_INTEGER_POSITIVE

        # number of communication-initiating bots
        # (use num_attackers to specify number of communicating devices?)
        params[Param.NUMBER_INITIATOR_BOTS] = ParameterTypes.TYPE_INTEGER_POSITIVE

        # input files containing the botnet communication
        params[Param.FILE_CSV] = ParameterTypes.TYPE_FILEPATH
        params[Param.FILE_XML] = ParameterTypes.TYPE_FILEPATH

        # percentage of IP reuse (if total and other is specified, percentages are multiplied)
        params[Param.IP_REUSE_TOTAL] = ParameterTypes.TYPE_PERCENTAGE
        params[Param.IP_REUSE_LOCAL] = ParameterTypes.TYPE_PERCENTAGE
        params[Param.IP_REUSE_EXTERNAL] = ParameterTypes.TYPE_PERCENTAGE

        # user-selected padding that is added to every packet
        params[Param.PACKET_PADDING] = ParameterTypes.TYPE_PADDING

        # whether NAT is present at the gateway of the network
        params[Param.NAT_PRESENT] = ParameterTypes.TYPE_BOOLEAN

        self.supported_params = params

        # map every MessageType value to its enum member for fast lookup
        self.msg_types = {member.value: member for member in MessageType}

    def init_params(self):
        """
        Initializes the parameters of this communication-attack with their default values.

        The values set here are overwritten if the user specifies them. The remaining parameters are
        implicitly set in the provided data file. Note: the timestamps in the file have to be sorted
        in ascending order.
        """
        # set class constants
        self.DEFAULT_XML_PATH = "resources/MembersMgmtComm_example.xml"
        # probability for responder ID to be local if comm_type is mixed
        self.PROB_RESPND_IS_LOCAL = 0

        # PARAMETERS: initialize with default values
        # (values are overwritten if user specifies them)

        # pick a random position within the first fifth of the capture; the upper bound is clamped
        # to 1 because randint(1, 0) raises a ValueError for captures with fewer than five packets
        max_inject_position = max(1, int(self.statistics.get_packet_count() / 5))
        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(1, max_inject_position))

        self.add_param_value(Param.PACKETS_PER_SECOND, 0)
        self.add_param_value(Param.FILE_XML, self.DEFAULT_XML_PATH)

        # Alternatively new attack parameter?
        # by default the attack lasts as long as the input capture (whole seconds)
        duration = int(float(self._get_capture_duration()))
        self.add_param_value(Param.ATTACK_DURATION, duration)
        self.add_param_value(Param.NUMBER_INITIATOR_BOTS, 1)
        # NAT on by default
        self.add_param_value(Param.NAT_PRESENT, True)

        # default locality behavior
        # TODO: change 1 to something better
        self.add_param_value(Param.IP_REUSE_TOTAL, 1)
        self.add_param_value(Param.IP_REUSE_LOCAL, 0.5)
        self.add_param_value(Param.IP_REUSE_EXTERNAL, 0.5)

        # add default additional padding
        self.add_param_value(Param.PACKET_PADDING, 20)



    def generate_attack_pcap(self, context):
        """
        Creates one packet per botnet message and writes the packets to the attack PCAP.

        Packets are collected in a buffer and written out (appending) every BUFFER_SIZE packets and once
        more for the remainder. Generation stops as soon as PACKETS_LIMIT packets were created or a
        message lies ATTACK_DURATION or more after the first message. The message-to-packet mapping is
        written to a file allocated through context, and the timestamps of the first and last generated
        packet are stored in attack_start_utime and attack_end_utime.

        :param context: passed on to _create_messages and used to allocate the mapping file
        :return: tuple of the number of generated packets and the path of the attack PCAP;
                 (0, []) if there are no messages or no packet could be generated
        """
        # create the final messages that have to be sent, including all bot configurations
        messages = self._create_messages(context)

        if not messages:
            return 0, []

        # Setup (initial) parameters for packet creation loop
        BUFFER_SIZE = 1000
        pkt_gen = Generator.PacketGenerator()
        padding = self.get_param_value(Param.PACKET_PADDING)
        packets = deque(maxlen=BUFFER_SIZE)
        total_pkts = 0
        limit_packetcount = self.get_param_value(Param.PACKETS_LIMIT)
        limit_duration = self.get_param_value(Param.ATTACK_DURATION)
        path_attack_pcap = None
        last_packet = None

        msg_packet_mapping = MessageMapping(messages)

        # the attack duration is measured relative to the first message
        first_msg_time = messages[0].time

        # create packets to write to PCAP file
        for msg in messages:
            # update duration
            duration = msg.time - first_msg_time

            # if total number of packets has been sent or the attack duration has been exceeded, stop
            if ((limit_packetcount is not None and total_pkts >= limit_packetcount) or
                    (limit_duration is not None and duration >= limit_duration)):
                break

            # retrieve the source and destination configurations
            ip_src, ip_dst = msg.src["IP"], msg.dst["IP"]
            mac_src, mac_dst = msg.src["MAC"], msg.dst["MAC"]
            port_src, port_dst = msg.src["Port"], msg.dst["Port"]
            ttl = msg.src["TTL"]

            # if the type of the message is a NL reply, determine the number of entries
            nl_size = 0
            if msg.type == MessageType.SALITY_NL_REPLY:
                nl_size = randint(1, 25)    # what is max NL entries?

            # create suitable IP/UDP packet and add to packets list
            packet = pkt_gen.generate_mmcom_packet(ip_src=ip_src, ip_dst=ip_dst, ttl=ttl, mac_src=mac_src, mac_dst=mac_dst,
                port_src=port_src, port_dst=port_dst, message_type=msg.type, neighborlist_entries=nl_size)
            Generator.add_padding(packet, padding,True, True)

            packet.time = msg.time
            packets.append(packet)
            msg_packet_mapping.map_message(msg, packet)
            total_pkts += 1

            # Store timestamp of first packet (for attack label)
            if total_pkts <= 1:
                self.attack_start_utime = packets[0].time
            elif total_pkts % BUFFER_SIZE == 0: # every 1000 packets write them to the PCAP file (append)
                packets = list(packets)
                Generator.equal_length(packets, padding = padding)
                last_packet = packets[-1]
                path_attack_pcap = self.write_attack_pcap(packets, True, path_attack_pcap)
                packets = deque(maxlen=BUFFER_SIZE)

        # the limits may stop the loop before a single packet was generated; without this guard
        # last_packet would be unbound below
        if total_pkts == 0:
            return 0, []

        # if there are unwritten packets remaining, write them to the PCAP file
        if len(packets) > 0:
            packets = list(packets)
            Generator.equal_length(packets, padding = padding)
            path_attack_pcap = self.write_attack_pcap(packets, True, path_attack_pcap)
            last_packet = packets[-1]

        # write the mapping to a file
        msg_packet_mapping.write_to(context.allocate_file("_mapping.xml"))

        # Store timestamp of last packet
        self.attack_end_utime = last_packet.time

        # Return the total number of generated packets and the path of the attack PCAP
        return total_pkts , path_attack_pcap


    def _create_messages(self, context):
        """
        Builds the list of botnet messages to inject, each carrying the full sender and receiver configuration.

        The communication is parsed from the XML file; if a CSV file is given and the XML path is still the
        default one, the CSV is converted to XML first and the result is moved to the output directory. A
        communication interval is then selected for the requested number of initiator bots and attack
        duration, IP, MAC, port and TTL are assigned to every bot ID, the message timestamps are rewritten
        relative to INJECT_AT_TIMESTAMP, and finally the IDs in each message are replaced by the bot
        configurations. Messages between two external bots are dropped.

        NOTE(review): the result depends on the order of the random calls made here (choice, randrange,
        uniform), so statements should not be reordered if results are to stay reproducible under a seed.

        :param context: used to query the output directory and to register the moved XML file
        :return: the messages sorted by time, with src/dst replaced by bot configuration dicts and with
                 newly assigned consecutive msg_ids; an empty list if no suitable interval was found
        """
        def add_ids_to_config(ids_to_add: list, existing_ips: list, new_ips: list, bot_configs: dict, idtype:str="local", router_mac:str=""):
            """
            Creates IP and MAC configurations for the given IDs and adds them to the existing configurations object.

            :param ids_to_add: all sorted IDs that have to be configured and added
            :param existing_ips: the existing IPs in the PCAP file that should be assigned to some, or all, IDs
            :param new_ips: the newly generated IPs that should be assigned to some, or all, IDs
            :param bot_configs: the existing configurations for the bots
            :param idtype: the locality type of the IDs
            :param router_mac: the MAC address of the router in the PCAP
            """

            # work on a copy, since IDs are removed from the list once they are configured
            ids = ids_to_add.copy()
            # macgen only needed, when IPs are new local IPs (therefore creating the object here suffices for the current callers
            # to not end up with the same MAC paired with different IPs)
            macgen = Generator.MacAddressGenerator()

            # assign existing IPs and the corresponding MAC addresses in the PCAP to the IDs
            for ip in existing_ips:
                random_id = choice(ids)
                mac = self.statistics.process_db_query("macAddress(IPAddress=%s)" % ip)
                bot_configs[random_id] = {"Type": idtype, "IP": ip, "MAC": mac}
                ids.remove(random_id)

            # assign new IPs and for local IPs new MACs or for external IPs the router MAC to the IDs
            # NOTE(review): for any idtype other than "local"/"external", mac is not assigned in this loop
            for ip in new_ips:
                random_id = choice(ids)
                if idtype == "local":
                    mac = macgen.random_mac()
                elif idtype == "external":
                    mac = router_mac
                bot_configs[random_id] = {"Type": idtype, "IP": ip, "MAC": mac}
                ids.remove(random_id)

        # NOTE(review): index_increment is not called anywhere within this method
        def index_increment(number: int, max: int):
            """
            Number increment with rollover.

            :param number: the current number
            :param max: the exclusive upper bound; reaching it wraps around to 0
            :return: number + 1 if that is still below max, otherwise 0
            """
            if number + 1 < max:
                return number + 1
            else:
                return 0

        def assign_realistic_ttls(bot_configs):
            """
            Assigns a realistic TTL to each bot in bot_configs.

            Local bots get the fixed TTL 128. For every other bot the TTL is drawn from the TTL distribution
            the statistics report for the bot's IP; if that distribution is empty, the result of the
            "most_used(ttlValue)" statistics query is used instead.

            :param bot_configs: the existing configurations for the bots (modified in place)
            """
            # iterate in sorted ID order so that the random draws happen in a fixed order
            ids = sorted(bot_configs.keys())
            for pos,bot in enumerate(ids):
                bot_type = bot_configs[bot]["Type"]
                # print(bot_type)
                if(bot_type == "local"): # Set fix TTL for local Bots
                    bot_configs[bot]["TTL"] = 128
                    # Set TTL based on TTL distribution of IP address
                else: # Set varying TTl for external Bots
                    bot_ttl_dist = self.statistics.get_ttl_distribution(bot_configs[bot]["IP"])
                    if len(bot_ttl_dist) > 0:
                         source_ttl_prob_dict = Lea.fromValFreqsDict(bot_ttl_dist)
                         bot_configs[bot]["TTL"] = source_ttl_prob_dict.random()
                    else:
                         bot_configs[bot]["TTL"] = self.statistics.process_db_query("most_used(ttlValue)")

        # NOTE(review): assign_realworld_ttls is currently unused; its only call below is commented out
        def assign_realworld_ttls(bot_configs):
            """
            Assigns realistic TTL values to each bot from a real-world pcap file.

            Local bots get the fixed TTL 128; every other bot gets a TTL drawn from the TTL distribution
            of the real-world capture.

            :param bot_configs: the existing configurations for the bots
            """

            # create a PcapFile
            pcap = PcapFile("resources/oc48-mfn.dirB.20030424-074500.UTC.anon.pcap")
            # create new instance of an Statistics Object
            stat = Statistics(pcap)
            # recalculate the statistic, because there doesn't exist one
            stat.load_pcap_statistics(False, True, False) # does not work! Why? Won't create DB
            bot_ttl_dist = stat.get_ttl_distribution("*")
            # assign local and external TTL randomly
            for pos,bot in enumerate(sorted(bot_configs.keys())):
                bot_type = bot_configs[bot]["Type"]
                if bot_type == "local":
                    bot_configs[bot]["TTL"] = 128
                else:
                    source_ttl_prob_dict = Lea.fromValFreqsDict(bot_ttl_dist)
                    bot_configs[bot]["TTL"] = source_ttl_prob_dict.random()


        def move_xml_to_outdir(filepath_xml: str):
            """
            Moves the XML file at filepath_xml to the output directory of the PCAP
            and registers the moved file with the context.

            :param filepath_xml: the filepath to the XML file
            :return: the new filepath to the XML file
            """

            pcap_dir = context.get_output_dir()
            xml_name = os.path.basename(filepath_xml)
            # make sure exactly one "/" separates directory and file name
            if pcap_dir.endswith("/"):
                new_xml_path = pcap_dir + xml_name
            else:
                new_xml_path = pcap_dir + "/" + xml_name
            os.rename(filepath_xml, new_xml_path)
            context.add_other_created_file(new_xml_path)
            return new_xml_path

        # parse input CSV or XML
        filepath_xml = self.get_param_value(Param.FILE_XML)
        filepath_csv = self.get_param_value(Param.FILE_CSV)

        # prefer XML input over CSV input (in case both are given):
        # the CSV is only used if the XML path is still the default one
        if filepath_csv and filepath_xml == self.DEFAULT_XML_PATH:
            filepath_xml = FileUtils.parse_csv_to_xml(filepath_csv)
            filepath_xml = move_xml_to_outdir(filepath_xml)


        abstract_packets = FileUtils.parse_xml(filepath_xml)

        # find a good communication mapping in the input file that matches the users parameters
        duration = self.get_param_value(Param.ATTACK_DURATION)
        number_init_bots = self.get_param_value(Param.NUMBER_INITIATOR_BOTS)
        nat = self.get_param_value(Param.NAT_PRESENT)
        comm_proc = CommunicationProcessor(abstract_packets, self.msg_types, nat)

        comm_intervals = comm_proc.find_interval_most_comm(number_init_bots, duration)
        if comm_intervals == []:
            print("Error: There is no interval in the given CSV/XML that has enough communication initiating bots.")
            return []
        # pick one of the candidate intervals at random
        comm_interval = comm_intervals[randrange(0, len(comm_intervals))]

        # retrieve the mapping information
        mapped_ids, packet_start_idx, packet_end_idx = comm_interval["IDs"], comm_interval["Start"], comm_interval["End"]
        # print(mapped_ids)
        # randomly drop IDs until only number_init_bots remain
        # (this deletes from the list stored in comm_interval["IDs"] itself)
        while len(mapped_ids) > number_init_bots:
            rm_idx = randrange(0, len(mapped_ids))
            del mapped_ids[rm_idx]

        # assign the communication processor this mapping for further processing
        comm_proc.set_mapping(abstract_packets[packet_start_idx:packet_end_idx+1], mapped_ids)
        # print start and end time of mapped interval
        # print(abstract_packets[packet_start_idx]["Time"])
        # print(abstract_packets[packet_end_idx]["Time"])
        # print(mapped_ids)

        # determine number of reused local and external IPs
        reuse_percent_total = self.get_param_value(Param.IP_REUSE_TOTAL)
        reuse_percent_external = self.get_param_value(Param.IP_REUSE_EXTERNAL)
        reuse_percent_local = self.get_param_value(Param.IP_REUSE_LOCAL)
        # NOTE(review): both counts are recomputed per locality further below before they are used
        reuse_count_external = int(reuse_percent_total * reuse_percent_external * len(mapped_ids))
        reuse_count_local = int(reuse_percent_total * reuse_percent_local * len(mapped_ids))

        # create locality, IP and MAC configurations for the IDs/Bots
        ipgen = Generator.IPGenerator()
        pcapops = PcapAddressOperations(self.statistics)
        router_mac = pcapops.get_probable_router_mac()
        bot_configs = {}
        # determine the roles of the IDs in the mapping communication-{initiator, responder}
        # NOTE(review): of the returned values only messages is used in the rest of this method
        local_init_ids, external_init_ids, respnd_ids, messages = comm_proc.det_id_roles_and_msgs()
        # use these roles to determine which IDs are to be local and which external
        local_ids, external_ids = comm_proc.det_ext_and_local_ids()

        # retrieve and assign the IPs and MACs for the bots with respect to the given parameters
        # (IDs are always added to bot_configs in the same order under a given seed)
        number_local_ids, number_external_ids = len(local_ids), len(external_ids)
        # assign addresses for local IDs
        if number_local_ids > 0:
            reuse_count_local = int(reuse_percent_total * reuse_percent_local * number_local_ids)
            existing_local_ips = sorted(pcapops.get_existing_local_ips(reuse_count_local))
            # fill up with new local IPs for the IDs that did not get an existing one
            new_local_ips = sorted(pcapops.get_new_local_ips(number_local_ids - len(existing_local_ips)))
            add_ids_to_config(sorted(local_ids), existing_local_ips, new_local_ips, bot_configs)

        # assign addresses for external IDs
        if number_external_ids > 0:
            reuse_count_external = int(reuse_percent_total * reuse_percent_external * number_external_ids)
            existing_external_ips = sorted(pcapops.get_existing_external_ips(reuse_count_external))
            # fill up with randomly generated IPs for the IDs that did not get an existing one
            remaining = len(external_ids) - len(existing_external_ips)
            new_external_ips = sorted([ipgen.random_ip() for _ in range(remaining)])
            add_ids_to_config(sorted(external_ids), existing_external_ips, new_external_ips, bot_configs, idtype="external", router_mac=router_mac)

        #### Set realistic timestamps for messages ####

        # this is the timestamp at which the first packet should be injected, the packets have to be shifted to the beginning of the
        # pcap file (INJECT_AT_TIMESTAMP) and then the offset of the packets have to be compensated to start at the given point in time
        zero_reference = self.get_param_value(Param.INJECT_AT_TIMESTAMP) - messages[0].time

        # NOTE(review): membership tests on this list (see the loop below) compare against every stored message
        updated_msgs = []
        last_response = {}      # Dict, takes a tuple of 2 Bot_IDs as a key (requester, responder), returns the time of the last response, the requester received
                                # necessary in order to make sure, that additional requests are sent only after the response to the last one was received
        for msg in messages:    # init
            last_response[(msg.src, msg.dst)] = -1

        # calculate the average delay values for local and external responses
        avg_delay_local, avg_delay_external = self.statistics.get_avg_delay_local_ext()

        # update all timestamps
        for req_msg in messages:

            if(req_msg in updated_msgs):
                # message already updated
                continue

            # if req_msg.timestamp would be before the timestamp of the response to the last request, req_msg needs to be sent later (else branch)
            if last_response[(req_msg.src, req_msg.dst)] == -1 or last_response[(req_msg.src, req_msg.dst)] < (zero_reference + req_msg.time - 0.05):
                ## update req_msg timestamp with a variation of up to 50ms
                req_msg.time = zero_reference + req_msg.time + uniform(-0.05, 0.05)
                updated_msgs.append(req_msg)

            else:
                # send the request 10ms to 110ms after the last response was received
                req_msg.time = last_response[(req_msg.src, req_msg.dst)] + 0.06 + uniform(-0.05, 0.05)

            # update response if necessary
            if req_msg.refer_msg_id != -1:
                # refer_msg_id is used as an index into messages here
                respns_msg = messages[req_msg.refer_msg_id]

                # check for local or external communication and update response timestamp with the respective avg delay
                # (the delay is varied by up to 10% in either direction)
                if req_msg.src in external_ids or req_msg.dst in external_ids:
                    #external communication
                    respns_msg.time = req_msg.time + avg_delay_external + uniform(-0.1*avg_delay_external, 0.1*avg_delay_external)

                else:
                    #local communication
                    respns_msg.time = req_msg.time + avg_delay_local + uniform(-0.1*avg_delay_local, 0.1*avg_delay_local)

                updated_msgs.append(respns_msg)
                last_response[(req_msg.src, req_msg.dst)] = respns_msg.time

        # create port configurations for the bots
        for bot in bot_configs:
            bot_configs[bot]["Port"] = Generator.gen_random_server_port()

        # print(local_init_ids)
        # print(bot_configs)

        # assign realistic TTL for every bot
        assign_realistic_ttls(bot_configs)
        # assign_realworld_ttls(bot_configs)

        # put together the final messages including the full sender and receiver
        # configurations (i.e. IP, MAC, port, ...) for easier later use
        final_messages = []
        messages = sorted(messages, key=lambda msg: msg.time)
        new_id = 0

        for msg in messages:
            type_src, type_dst = bot_configs[msg.src]["Type"], bot_configs[msg.dst]["Type"]
            id_src, id_dst = msg.src, msg.dst

            # sort out messages that do not have a suitable locality setting
            if type_src == "external" and type_dst == "external":
                continue

            # replace the IDs by the full configurations, but keep the ID available inside the configuration
            msg.src, msg.dst = bot_configs[id_src], bot_configs[id_dst]
            msg.src["ID"], msg.dst["ID"] = id_src, id_dst
            # renumber the remaining messages consecutively in time order
            msg.msg_id = new_id
            new_id += 1
            ### Important here to update refers, if needed later?
            # NOTE(review): refer_msg_id is not adjusted to the new msg_ids assigned above
            final_messages.append(msg)

        return final_messages


    def _get_capture_duration(self):
        """
        Returns the duration of the input PCAP (since statistics duration seems to be incorrect)
        """
        ts_date_format = "%Y-%m-%d %H:%M:%S.%f"
        ts_first_date = datetime.strptime(self.statistics.get_pcap_timestamp_start(), ts_date_format)
        ts_last_date = datetime.strptime(self.statistics.get_pcap_timestamp_end(), ts_date_format)
        diff_date = ts_last_date - ts_first_date
        duration = "%d.%d" % (diff_date.total_seconds(), diff_date.microseconds)
        return duration