from enum import Enum

class MessageType(Enum):
    """
    Defines possible botnet message types.

    The value of each member is the integer code of that message type, so a
    member can be looked up from its code (e.g. via MessageType(101)).
    """

    # a timeout, i.e. not one of the Sality message types below
    TIMEOUT = 3
    # Sality NL exchange: request and the corresponding reply
    SALITY_NL_REQUEST = 101
    SALITY_NL_REPLY = 102
    # Sality hello exchange: hello and the corresponding reply
    SALITY_HELLO = 103
    SALITY_HELLO_REPLY = 104

class Message:
    """
    Compact record that bundles everything needed to describe one message.
    """
    def __init__(self, msg_id: int, src, dst, type_: MessageType, time: float, refer_msg_id: int=-1):
        """
        Builds a message from the given parameters.

        :param msg_id: the ID of the message
        :param src: something identifying the source, e.g. ID or configuration
        :param dst: something identifying the destination, e.g. ID or configuration
        :param type_: the type of the message
        :param time: the timestamp of the message
        :param refer_msg_id: the ID of the message this one is a request for or a reply to;
                             -1 if there is no related message
        """
        # who talks to whom, under which ID
        self.msg_id, self.src, self.dst = msg_id, src, dst
        # what is sent and when
        self.type, self.time = type_, time
        # link to the related request/reply
        self.refer_msg_id = refer_msg_id

    def __str__(self):
        """
        Returns a one-line, human-readable summary of the message.
        """
        template = "{ident}. at {when}: {sender}-->{receiver}, {kind}, refer:{related}"
        return template.format(ident=self.msg_id, when=self.time, sender=self.src,
                               receiver=self.dst, kind=self.type, related=self.refer_msg_id)

from random import randint, randrange, choice, uniform
from collections import deque
from scipy.stats import gamma
from lea import Lea
from datetime import datetime

from Attack import BaseAttack
from Attack.AttackParameters import Parameter as Param
from Attack.AttackParameters import ParameterTypes

from ID2TLib import FileUtils, PaddingGenerator
from ID2TLib.PacketGenerator import PacketGenerator
from ID2TLib.IPGenerator import IPGenerator
from ID2TLib.PcapAddressOperations import PcapAddressOperations
from ID2TLib.CommunicationProcessor import CommunicationProcessor
from ID2TLib.MacAddressGenerator import MacAddressGenerator
from ID2TLib.PortGenerator import gen_random_server_port


class MembersMgmtCommAttack(BaseAttack.BaseAttack):
    def __init__(self):
        """
        Creates a new instance of the Membership Management Communication attack.

        Calls the base-class constructor with the descriptive strings of this attack,
        declares which parameters the attack accepts (and of which type) and builds a
        lookup table from message type values to MessageType members.
        """
        super().__init__(
            "Membership Management Communication Attack (MembersMgmtCommAttack)",
            "Injects Membership Management Communication",
            "Botnet communication")

        # every parameter the attack understands, mapped to its expected type
        self.supported_params = {
            # --- general attack parameters ---
            Param.INJECT_AT_TIMESTAMP: ParameterTypes.TYPE_FLOAT,
            Param.INJECT_AFTER_PACKET: ParameterTypes.TYPE_PACKET_POSITION,
            Param.PACKETS_PER_SECOND: ParameterTypes.TYPE_FLOAT,
            Param.PACKETS_LIMIT: ParameterTypes.TYPE_INTEGER_POSITIVE,
            Param.ATTACK_DURATION: ParameterTypes.TYPE_INTEGER_POSITIVE,

            # --- number of bots initiating communication ---
            # (open question: use num_attackers to specify number of communicating devices?)
            Param.NUMBER_INITIATOR_BOTS: ParameterTypes.TYPE_INTEGER_POSITIVE,

            # --- input files containing the botnet communication ---
            Param.FILE_CSV: ParameterTypes.TYPE_FILEPATH,
            Param.FILE_XML: ParameterTypes.TYPE_FILEPATH,

            # --- percentage of IP reuse ---
            # (if total and other is specified, percentages are multiplied)
            Param.IP_REUSE_TOTAL: ParameterTypes.TYPE_PERCENTAGE,
            Param.IP_REUSE_LOCAL: ParameterTypes.TYPE_PERCENTAGE,
            Param.IP_REUSE_EXTERNAL: ParameterTypes.TYPE_PERCENTAGE,

            # --- user-selected padding added to every packet ---
            Param.PACKET_PADDING: ParameterTypes.TYPE_PADDING,

            # --- whether a NAT is present at the gateway of the network ---
            Param.NAT_PRESENT: ParameterTypes.TYPE_BOOLEAN
        }

        # value -> MessageType member, for fast lookup of a type by its value
        self.msg_types = {member.value: member for member in MessageType}

    def init_params(self):
        """
        Initialize some parameters of this communication-attack with default values.
        Values given by the user on the command line overwrite these defaults.
        The remaining parameters are implicitly set in the provided data file. Note: the timestamps in the file
        have to be sorted in ascending order.

        Reads the packet count and the capture timestamps from self.statistics.
        """
        # set class constants
        self.DEFAULT_XML_PATH = "resources/MembersMgmtComm_example.xml"
        # probability for responder ID to be local if comm_type is mixed
        self.PROB_RESPND_IS_LOCAL = 0

        # PARAMETERS: initialize with default values
        # (values are overwritten if user specifies them)

        # inject somewhere within the first fifth of the capture; the upper bound is
        # clamped to 1 because randint(1, 0) raises a ValueError for captures with
        # fewer than five packets
        max_inject_position = max(1, int(self.statistics.get_packet_count() / 5))
        self.add_param_value(Param.INJECT_AFTER_PACKET, randint(1, max_inject_position))

        self.add_param_value(Param.PACKETS_PER_SECOND, 0)
        self.add_param_value(Param.FILE_XML, self.DEFAULT_XML_PATH)

        # Alternatively new attack parameter?
        # NOTE: the duration is truncated to whole seconds, so a capture shorter
        # than one second results in a default attack duration of 0
        duration = int(float(self._get_capture_duration()))
        self.add_param_value(Param.ATTACK_DURATION, duration)
        self.add_param_value(Param.NUMBER_INITIATOR_BOTS, 1)
        # NAT on by default
        self.add_param_value(Param.NAT_PRESENT, True)

        # default locality behavior
        # TODO: change 1 to something better
        self.add_param_value(Param.IP_REUSE_TOTAL, 1)
        self.add_param_value(Param.IP_REUSE_LOCAL, 0.5)
        self.add_param_value(Param.IP_REUSE_EXTERNAL, 0.5)

        # add default additional padding
        self.add_param_value(Param.PACKET_PADDING, 20)


        
    def generate_attack_pcap(self):
        """
        Creates one packet per message returned by _create_messages() and writes the
        packets to the attack PCAP file in chunks.

        Packet timestamps start at INJECT_AT_TIMESTAMP and advance by the time difference
        between consecutive messages. Generation stops as soon as PACKETS_LIMIT packets
        were created or the elapsed time reaches ATTACK_DURATION. The timestamps of the
        first and last packet are stored in self.attack_start_utime and
        self.attack_end_utime.

        :return: a tuple (number of generated packets, path of the attack PCAP file);
                 (0, []) if there are no messages or no packet could be generated
        """
        # create the final messages that have to be sent, including all bot configurations
        messages = self._create_messages()

        if not messages:
            return 0, []

        # Setup (initial) parameters for packet creation loop
        BUFFER_SIZE = 1000
        pkt_gen = PacketGenerator()
        file_timestamp_prv = messages[0].time
        pcap_timestamp = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
        padding = self.get_param_value(Param.PACKET_PADDING)
        packets = deque(maxlen=BUFFER_SIZE)
        total_pkts = 0
        limit_packetcount = self.get_param_value(Param.PACKETS_LIMIT)
        limit_duration = self.get_param_value(Param.ATTACK_DURATION)
        duration = 0
        path_attack_pcap = None
        # stays None as long as no packet has been written to the PCAP file
        last_packet = None

        # create packets to write to PCAP file
        for msg in messages:
            # retrieve the source and destination configurations
            ip_src, ip_dst = msg.src["IP"], msg.dst["IP"]
            mac_src, mac_dst = msg.src["MAC"], msg.dst["MAC"]
            port_src, port_dst = msg.src["Port"], msg.dst["Port"]
            ttl = msg.src["TTL"]

            # update timestamps and duration
            file_timestamp = msg.time
            file_time_delta = file_timestamp - file_timestamp_prv
            pcap_timestamp += file_time_delta
            duration += file_time_delta
            file_timestamp_prv = file_timestamp

            # if total number of packets has been sent or the attack duration has been exceeded, stop
            if ((limit_packetcount is not None and total_pkts >= limit_packetcount) or
                    (limit_duration is not None and duration >= limit_duration)):
                break

            # if the type of the message is a NL reply, determine the number of entries
            nl_size = 0
            if msg.type == MessageType.SALITY_NL_REPLY:
                nl_size = randint(1, 25)    # what is max NL entries?

            # create suitable IP/UDP packet and add to packets list
            packet = pkt_gen.generate_mmcom_packet(ip_src=ip_src, ip_dst=ip_dst, ttl=ttl, mac_src=mac_src, mac_dst=mac_dst,
                port_src=port_src, port_dst=port_dst, message_type=msg.type, neighborlist_entries=nl_size)
            PaddingGenerator.add_padding(packet, padding,True, True)

            packet.time = pcap_timestamp
            packets.append(packet)
            total_pkts += 1

            # Store timestamp of first packet (for attack label)
            if total_pkts == 1:
                self.attack_start_utime = packet.time
            elif total_pkts % BUFFER_SIZE == 0: # every 1000 packets write them to the PCAP file (append)
                packets = list(packets)
                PaddingGenerator.equal_length(packets, padding = padding)
                last_packet = packets[-1]
                path_attack_pcap = self.write_attack_pcap(packets, True, path_attack_pcap)
                packets = deque(maxlen=BUFFER_SIZE)

        # if there are unwritten packets remaining, write them to the PCAP file
        if packets:
            packets = list(packets)
            PaddingGenerator.equal_length(packets, padding = padding)
            path_attack_pcap = self.write_attack_pcap(packets, True, path_attack_pcap)
            last_packet = packets[-1]

        # a limit was already reached before the first packet (e.g. a duration or packet
        # limit of 0): nothing was generated, so report that like the no-messages case
        if last_packet is None:
            return 0, []

        # Store timestamp of last packet
        self.attack_end_utime = last_packet.time

        # Return total number of packets (sent) and the path of the written PCAP file
        return total_pkts , path_attack_pcap


    def _create_messages(self):
        """
        Creates the messages that are to be injected, including the configuration of every bot.

        Parses the input CSV/XML file, picks a communication interval that matches the user's
        parameters, assigns IP, MAC, port and TTL to every bot ID, sets the timestamps of the
        messages and finally drops all messages exchanged between two external bots.

        :return: the remaining messages sorted by timestamp, with src and dst replaced by the
                 configuration dicts of the bots (keys: Type, IP, MAC, Port, TTL, ID) and with
                 msg_id renumbered consecutively (refer_msg_id is not adjusted);
                 an empty list if no suitable interval or no message was found
        """
        def add_ids_to_config(ids_to_add: list, existing_ips: list, new_ips: list, bot_configs: dict, idtype:str="local", router_mac:str=""):
            """
            Creates IP and MAC configurations for the given IDs and adds them to the existing configurations object.

            :param ids_to_add: all sorted IDs that have to be configured and added
            :param existing_ips: the existing IPs in the PCAP file that should be assigned to some, or all, IDs
            :param new_ips: the newly generated IPs that should be assigned to some, or all, IDs
            :param bot_configs: the existing configurations for the bots
            :param idtype: the locality type of the IDs
            :param router_mac: the MAC address of the router in the PCAP
            """

            ids = ids_to_add.copy()
            # macgen only needed, when IPs are new local IPs (therefore creating the object here suffices for the current callers
            # to not end up with the same MAC paired with different IPs)
            macgen = MacAddressGenerator()

            # assign existing IPs and the corresponding MAC addresses in the PCAP to the IDs
            for ip in existing_ips:
                random_id = choice(ids)
                mac = self.statistics.process_db_query("macAddress(IPAddress=%s)" % ip)
                bot_configs[random_id] = {"Type": idtype, "IP": ip, "MAC": mac}
                ids.remove(random_id)

            # assign new IPs and for local IPs new MACs or for external IPs the router MAC to the IDs
            for ip in new_ips:
                random_id = choice(ids)
                if idtype == "local":
                    mac = macgen.random_mac()
                elif idtype == "external":
                    mac = router_mac
                bot_configs[random_id] = {"Type": idtype, "IP": ip, "MAC": mac}
                ids.remove(random_id)

        def assign_realistic_ttls(bot_configs):
            """
            Assigns a TTL to each bot in bot_configs. Local bots get a fixed TTL; for external
            bots the TTL is drawn from the TTL distribution of the bot's IP in the statistics,
            falling back to the most used TTL if there is no distribution for that IP.

            :param bot_configs: the configurations of the bots; each gets a "TTL" entry
            """
            # iterate in sorted order so the random draws are reproducible under a given seed
            for bot in sorted(bot_configs):
                config = bot_configs[bot]
                if config["Type"] == "local":  # Set fix TTL for local Bots
                    config["TTL"] = 128
                    continue

                # Set varying TTL for external Bots, based on TTL distribution of IP address
                bot_ttl_dist = self.statistics.get_ttl_distribution(config["IP"])
                if len(bot_ttl_dist) > 0:
                    source_ttl_prob_dict = Lea.fromValFreqsDict(bot_ttl_dist)
                    config["TTL"] = source_ttl_prob_dict.random()
                else:
                    config["TTL"] = self.statistics.process_db_query("most_used(ttlValue)")

        def add_delay(timestamp, minDelay, delay):
            """
            Adds delay to a timestamp, with a minimum value of minDelay. But usually a value close to delay

            :param timestamp: the timestamp that is to be increased
            :param minDelay: the minimum value that is to add to the timestamp
            :param delay: The general size of the delay. Statistically speaking: the expected value
            :return: the updated timestamp
            """

            randomdelay = Lea.fromValFreqsDict({0.15*delay: 7, 0.3*delay: 10, 0.7*delay:20,
                                delay:33, 1.2*delay:20, 1.6*delay: 10, 1.9*delay: 7, 2.5*delay: 3, 4*delay: 1})
            if 0.1*delay < minDelay:
                print("Warning: minDelay probably too big when computing time_stamps")

            general_offset = randomdelay.random()
            unique_offset = uniform(-0.1*general_offset, 0.1*general_offset)
            return timestamp + minDelay + general_offset + unique_offset

        def hello_key(msg):
            """
            Returns the key under which the hello exchange of the two bots of msg is stored:
            the tuple of both bot IDs with the numerically lower ID first.
            """
            if int(msg.src) < int(msg.dst):
                return (msg.src, msg.dst)
            return (msg.dst, msg.src)

        # parse input CSV or XML
        filepath_xml = self.get_param_value(Param.FILE_XML)
        filepath_csv = self.get_param_value(Param.FILE_CSV)

        # prefer XML input over CSV input (in case both are given)
        if filepath_csv and filepath_xml == self.DEFAULT_XML_PATH:
            filepath_xml = FileUtils.parse_csv_to_xml(filepath_csv)

        abstract_packets = FileUtils.parse_xml(filepath_xml)

        # find a good communication mapping in the input file that matches the users parameters
        duration = self.get_param_value(Param.ATTACK_DURATION)
        number_init_bots = self.get_param_value(Param.NUMBER_INITIATOR_BOTS)
        nat = self.get_param_value(Param.NAT_PRESENT)
        comm_proc = CommunicationProcessor(abstract_packets, self.msg_types, nat)

        comm_intervals = comm_proc.find_interval_most_comm(number_init_bots, duration)
        if not comm_intervals:
            print("Error: There is no interval in the given CSV/XML that has enough communication initiating bots.")
            return []
        comm_interval = comm_intervals[randrange(0, len(comm_intervals))]

        # retrieve the mapping information
        mapped_ids, packet_start_idx, packet_end_idx = comm_interval["IDs"], comm_interval["Start"], comm_interval["End"]
        # randomly drop mapped IDs until only the requested number of initiators is left
        while len(mapped_ids) > number_init_bots:
            rm_idx = randrange(0, len(mapped_ids))
            del(mapped_ids[sorted(mapped_ids)[rm_idx]])

        # assign the communication processor this mapping for further processing
        comm_proc.set_mapping(abstract_packets[packet_start_idx:packet_end_idx+1], mapped_ids)

        # the percentages of reused local and external IPs
        reuse_percent_total = self.get_param_value(Param.IP_REUSE_TOTAL)
        reuse_percent_external = self.get_param_value(Param.IP_REUSE_EXTERNAL)
        reuse_percent_local = self.get_param_value(Param.IP_REUSE_LOCAL)

        # create locality, IP and MAC configurations for the IDs/Bots
        ipgen = IPGenerator()
        pcapops = PcapAddressOperations(self.statistics)
        router_mac = pcapops.get_probable_router_mac()
        bot_configs = {}
        # determine the roles of the IDs in the mapping communication-{initiator, responder};
        # only the messages are needed here
        _, _, _, messages = comm_proc.det_id_roles_and_msgs()
        # use these roles to determine which IDs are to be local and which external
        local_ids, external_ids = comm_proc.det_ext_and_local_ids(self.PROB_RESPND_IS_LOCAL)

        # retrieve and assign the IPs and MACs for the bots with respect to the given parameters
        # (IDs are always added to bot_configs in the same order under a given seed)
        number_local_ids, number_external_ids = len(local_ids), len(external_ids)
        # assign addresses for local IDs
        if number_local_ids > 0:
            reuse_count_local = int(reuse_percent_total * reuse_percent_local * number_local_ids)
            existing_local_ips = sorted(pcapops.get_existing_local_ips(reuse_count_local))
            new_local_ips = sorted(pcapops.get_new_local_ips(number_local_ids - len(existing_local_ips)))
            add_ids_to_config(sorted(local_ids), existing_local_ips, new_local_ips, bot_configs)

        # assign addresses for external IDs
        if number_external_ids > 0:
            reuse_count_external = int(reuse_percent_total * reuse_percent_external * number_external_ids)
            existing_external_ips = sorted(pcapops.get_existing_external_ips(reuse_count_external))
            remaining = len(external_ids) - len(existing_external_ips)
            new_external_ips = sorted([ipgen.random_ip() for _ in range(remaining)])
            add_ids_to_config(sorted(external_ids), existing_external_ips, new_external_ips, bot_configs, idtype="external", router_mac=router_mac)

        #### Set realistic timestamps for messages ####

        # without messages there is nothing to inject (and the time slice below would
        # divide by zero)
        if not messages:
            return []

        most_used_ip_address = self.statistics.get_most_used_ip_address()
        minDelay, _ = self.get_reply_delay(most_used_ip_address)
        next_timestamp = self.get_param_value(Param.INJECT_AT_TIMESTAMP)
        pcap_duration = float(self._get_capture_duration())
        # the expected delay between messages if they were spread evenly over the capture
        equi_timeslice = pcap_duration/len(messages)

        # Dict, takes a tuple of 2 Bots as a key (ID with lower number first), returns the time when the Hello_reply came in
        Hello_times = {}
        # msg_IDs with already updated timestamps (a set, since it is only used for membership tests)
        updated_msgs = set()

        for req_msg in messages:
            if req_msg.msg_id in updated_msgs:
                # message already updated (as the response of an earlier request)
                continue

            if req_msg.refer_msg_id == -1:
                # message has no corresponding request/response; refer_msg_id must not be
                # used as an index here, since -1 would silently select the last message
                req_msg.time = next_timestamp
                next_timestamp = add_delay(next_timestamp, minDelay, equi_timeslice)
                updated_msgs.add(req_msg.msg_id)
                continue

            respns_msg = messages[req_msg.refer_msg_id]
            updated = False

            if req_msg.type != MessageType.SALITY_HELLO:
                # Hello msg should have preceded, so make sure the timestamp of this msg is after the HELLO_REPLY.
                # If no hello exchange is recorded for this pair (e.g. the hello had no reply),
                # there is no constraint and the pair is updated normally.
                hello_time = Hello_times.get(hello_key(req_msg))

                if hello_time is not None and next_timestamp < hello_time:
                    # use the time of the hello_reply instead of next_timestamp to update this pair of messages
                    post_hello = add_delay(hello_time, minDelay, equi_timeslice)
                    respns_msg.time = add_delay(post_hello, minDelay, equi_timeslice)
                    req_msg.time = post_hello
                    updated = True

            if not updated:
                # update normally
                respns_msg.time = add_delay(next_timestamp, minDelay, equi_timeslice)
                req_msg.time = next_timestamp
                next_timestamp = add_delay(next_timestamp, minDelay, equi_timeslice)

            updated_msgs.add(req_msg.msg_id)
            updated_msgs.add(req_msg.refer_msg_id)

            if req_msg.type == MessageType.SALITY_HELLO:
                # remember when the hello reply arrived for this pair of bots
                Hello_times[hello_key(req_msg)] = respns_msg.time

        # create port configurations for the bots
        for bot in bot_configs:
            bot_configs[bot]["Port"] = gen_random_server_port()

        # assign realistic TTL for every bot
        assign_realistic_ttls(bot_configs)
        # put together the final messages including the full sender and receiver
        # configurations (i.e. IP, MAC, port, ...) for easier later use
        final_messages = []
        messages = sorted(messages, key=lambda msg: msg.time)
        new_id = 0

        for msg in messages:
            type_src, type_dst = bot_configs[msg.src]["Type"], bot_configs[msg.dst]["Type"]
            id_src, id_dst = msg.src, msg.dst

            # sort out messages that do not have a suitable locality setting
            if type_src == "external" and type_dst == "external":
                continue

            msg.src, msg.dst = bot_configs[id_src], bot_configs[id_dst]
            msg.src["ID"], msg.dst["ID"] = id_src, id_dst
            msg.msg_id = new_id
            new_id += 1
            ### Important here to update refers, if needed later?
            final_messages.append(msg)

        return final_messages


    def _get_capture_duration(self):
        """
        Returns the duration of the input PCAP (since statistics duration seems to be incorrect)

        The duration is the difference between the start and end timestamp of the PCAP
        as reported by self.statistics.

        :return: the duration in seconds as a string of the form "<seconds>.<microseconds>",
                 with the microseconds zero-padded to six digits so that float() yields
                 the correct value
        """
        ts_date_format = "%Y-%m-%d %H:%M:%S.%f"
        ts_first_date = datetime.strptime(self.statistics.get_pcap_timestamp_start(), ts_date_format)
        ts_last_date = datetime.strptime(self.statistics.get_pcap_timestamp_end(), ts_date_format)
        diff_date = ts_last_date - ts_first_date
        # the microseconds are the fractional part and therefore have to be zero-padded:
        # without padding, 5 s and 50000 us would read "5.50000" (5.5 s) instead of "5.050000"
        duration = "%d.%06d" % (diff_date.total_seconds(), diff_date.microseconds)
        return duration