import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map.Entry;

import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Unsupervised anomaly detection example based on k-means clustering.
 * <p>
 * In training mode every captured packet is converted into a numeric {@link Instance} and collected.
 * When switching to testing mode, a {@link SimpleKMeans} clusterer is built on the collected packets.
 * For every cluster, the largest distance between a training packet and its centroid is recorded.
 * In testing mode a packet is reported as an anomaly if it lies farther away from its nearest
 * centroid than any training packet of that cluster did.
 *
 * @author Andreas T. Meyer-Berg
 */
public class UnsupervisedAnomalyDetectionExample implements PacketSniffer {

	/**
	 * Number of clusters the k-means clusterer should build
	 */
	private static final int NUMBER_OF_CLUSTERS = 2;

	/**
	 * Number of most recent packets per link used to estimate the packets per second
	 */
	private static final int NUMBER_OF_PACKETS = 30;

	/**
	 * Port attribute value used if a packet has no source/destination port.
	 * It has to be finite: an infinite value would make the attribute range used for
	 * normalizing k-means distances infinite.
	 */
	private static final double NO_PORT = -1;

	/**
	 * Clusterer
	 */
	private final SimpleKMeans clusterer;

	/**
	 * True, if instances should be used for training
	 */
	private boolean training = true;

	/**
	 * Attributes which should be taken into account
	 */
	private final ArrayList<Attribute> atts = new ArrayList<Attribute>();

	/**
	 * Collected training packets
	 */
	private final Instances dataset;

	/**
	 * Most recent packets of each link, used for estimating the packets per second
	 */
	private final HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();

	/**
	 * Largest distance between a training instance and its cluster centroid, indexed by cluster.
	 * Null as long as no clusterer has been built successfully.
	 */
	private double[] maxClusterDistance = null;

	/**
	 * Numeric IDs of the link names (0 = "unknown")
	 */
	private final HashMap<String, Integer> linkMappings = new HashMap<String, Integer>();

	/**
	 * Numeric IDs of the source device names (0 = "unknown")
	 */
	private final HashMap<String, Integer> sourceMappings = new HashMap<String, Integer>();

	/**
	 * Numeric IDs of the destination device names (0 = "unknown")
	 */
	private final HashMap<String, Integer> destinationMappings = new HashMap<String, Integer>();

	/**
	 * Numeric IDs of the protocol names (0 = "unknown")
	 */
	private final HashMap<String, Integer> protocolMappings = new HashMap<String, Integer>();

	/**
	 * Creates the anomaly detector: initializes the attribute list, the (empty) data set and the clusterer.
	 */
	public UnsupervisedAnomalyDetectionExample() {
		// Initialize attribute list; ID 0 of every mapping is reserved for "unknown"
		linkMappings.put("unknown", 0);
		atts.add(new Attribute("Link-Name", false));//TODO:??
		sourceMappings.put("unknown", 0);
		atts.add(new Attribute("Source-Device", false));
		atts.add(new Attribute("Source-Port-number", false));
		destinationMappings.put("unknown", 0);
		atts.add(new Attribute("Destination-Device", false));
		atts.add(new Attribute("Destination-Port-number", false));
		protocolMappings.put("unknown", 0);
		atts.add(new Attribute("Protocol-name", false));
		atts.add(new Attribute("Packets-per-second", false));
		// Initialize data set
		dataset = new Instances("Packets", atts, 100000);
		// Initialize clusterer (fixed seed for reproducible clusters)
		clusterer = new SimpleKMeans();
		clusterer.setSeed(42);
		try {
			clusterer.setNumClusters(NUMBER_OF_CLUSTERS);
		} catch (Exception e) {
			System.out.println("Error while building cluster");
			e.printStackTrace();
		}
	}

	@Override
	public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
		if (training) {
			try {
				training(packets);
			} catch (Exception e) {
				e.printStackTrace();
			}
		} else {
			classify(packets);
		}
	}

	/**
	 * Estimates the current packets per second of a link, based on the last
	 * {@link #NUMBER_OF_PACKETS} packets captured on it. The given packet is added to that window.
	 * Timestamps are assumed to be in milliseconds.
	 *
	 * @param link Link which should be checked
	 * @param packet Packet which should be investigated
	 * @return estimated number of packets per second (1.0 if the packet is the first one of the link)
	 */
	private double getEstimatedPacketsPerSecond(Link link, Packet packet) {
		/**
		 * Packets used to calculate the packets per second
		 */
		LinkedList<Packet> window = lastPackets.get(link);
		if (window == null) {
			window = new LinkedList<Packet>();
			lastPackets.put(link, window);
		}
		// Keep at most NUMBER_OF_PACKETS packets, including the new one
		while (!window.isEmpty() && window.size() >= NUMBER_OF_PACKETS) {
			window.removeFirst();
		}
		window.addLast(packet);

		/**
		 * Number of inter-arrival intervals covered by the window
		 */
		int intervals = window.size() - 1;
		if (intervals < 1) {
			// Single packet: default 1 packet per second
			return 1.0;
		}
		/**
		 * Time span in milliseconds between the oldest packet of the window and the new packet
		 */
		long span = packet.getTimestamp() - window.getFirst().getTimestamp();
		// Packets sharing one timestamp (or out-of-order timestamps) are treated as 1 ms apart,
		// which keeps the feature finite for the clusterer
		return 1000.0 * intervals / Math.max(span, 1L);
	}

	/**
	 * Returns the instance representation of the given packet and link.
	 * Note: this updates the packets-per-second window of the link and may add new entries to the name mappings.
	 *
	 * @param link link the packet was sent on (may be null)
	 * @param packet packet which should be transformed
	 * @param dataset distribution the packet is part of
	 * @return instance representation
	 */
	private Instance packet2Instance(Link link, Packet packet, Instances dataset) {
		/**
		 * Instance for the given Packet
		 */
		DenseInstance instance = new DenseInstance(dataset.numAttributes());
		instance.setDataset(dataset);

		// Link
		instance.setValue(0, link == null ? 0 : stringToNumber(linkMappings, link.getName()));

		// Source device & port
		if (packet.getSource() == null) {
			instance.setValue(1, 0);
			instance.setValue(2, NO_PORT);
		} else if (packet.getSource().getOwner() == null) {
			instance.setValue(1, 0);
			instance.setValue(2, packet.getSource().getPortNumber());
		} else {
			instance.setValue(1, stringToNumber(sourceMappings, packet.getSource().getOwner().getName()));
			instance.setValue(2, packet.getSource().getPortNumber());
		}

		// Destination device & port
		if (packet.getDestination() == null) {
			instance.setValue(3, 0);
			instance.setValue(4, NO_PORT);
		} else if (packet.getDestination().getOwner() == null) {
			instance.setValue(3, 0);
			instance.setValue(4, packet.getDestination().getPortNumber());
		} else {
			instance.setValue(3, stringToNumber(destinationMappings, packet.getDestination().getOwner().getName()));
			instance.setValue(4, packet.getDestination().getPortNumber());
		}

		// Protocol name
		instance.setValue(5, stringToNumber(protocolMappings, packet.getProtocolName()));

		// Packets per second
		instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));

		return instance;
	}

	/**
	 * Transforms the String into a number. Unknown Strings are assigned the next free ID
	 * (the current size of the map) and stored in the map.
	 *
	 * @param map mapping from Strings to their numeric IDs, updated if s is new
	 * @param s String which should be transformed
	 * @return numeric ID of s
	 */
	double stringToNumber(HashMap<String, Integer> map, String s) {
		Integer i = map.get(s);
		if (i == null) {
			int size = map.size();
			map.put(s, size);
			return size;
		} else {
			return i;
		}
	}

	/**
	 * Train the clusterer by collecting the packets
	 *
	 * @param packets packets to be learned
	 */
	private void training(HashMap<Link, LinkedList<Packet>> packets) {
		for (Entry<Link, LinkedList<Packet>> entry : packets.entrySet()) {
			/**
			 * Link the packets were captured on
			 */
			Link link = entry.getKey();
			for (Packet packet : entry.getValue()) {
				dataset.add(packet2Instance(link, packet, dataset));
			}
		}
	}

	/**
	 * Finishes the collection and trains the clusterer on the collected packets.
	 * Afterwards, for every cluster, the largest distance of a training instance to its centroid is
	 * stored as the anomaly threshold.
	 *
	 * @throws Exception if the clusterer could not be built
	 */
	private void finishDataCollection() throws Exception {
		// Invalidate old thresholds, so a failed rebuild does not leave inconsistent state
		maxClusterDistance = null;
		/**
		 * Build the clusterer for the given dataset
		 */
		clusterer.buildClusterer(dataset);

		/**
		 * Centroids of the built clusters
		 */
		Instances centroids = clusterer.getClusterCentroids();
		double[] maxDistances = new double[centroids.numInstances()];
		for (int i = 0; i < dataset.numInstances(); i++) {
			Instance trainingInstance = dataset.instance(i);
			int cluster = clusterer.clusterInstance(trainingInstance);
			double distance = clusterer.getDistanceFunction().distance(centroids.instance(cluster), trainingInstance);
			if (distance > maxDistances[cluster]) {
				maxDistances[cluster] = distance;
			}
		}
		maxClusterDistance = maxDistances;
	}

	/**
	 * Checks whether the given instance lies farther away from its nearest centroid than
	 * any training instance of that cluster did.
	 *
	 * @param instance instance which should be checked
	 * @return true, if the instance is anomalous or no clusterer has been built successfully
	 * @throws Exception if the clusterer fails to cluster the instance
	 */
	private boolean isAnomaly(Instance instance) throws Exception {
		if (maxClusterDistance == null) {
			// Without a built clusterer no packet can be verified as normal
			return true;
		}
		int cluster = clusterer.clusterInstance(instance);
		double distance = clusterer.getDistanceFunction()
				.distance(clusterer.getClusterCentroids().instance(cluster), instance);
		return distance > maxClusterDistance[cluster];
	}

	/**
	 * Try to classify the given packets and report anomalies on System.out
	 *
	 * @param packets packets to be classified
	 */
	private void classify(HashMap<Link, LinkedList<Packet>> packets) {
		for (Entry<Link, LinkedList<Packet>> entry : packets.entrySet()) {
			/**
			 * Link the packets were captured on
			 */
			Link link = entry.getKey();
			for (Packet packet : entry.getValue()) {
				/**
				 * Instance representation of the packet
				 */
				Instance packetInstance = packet2Instance(link, packet, dataset);
				boolean anomaly;
				try {
					anomaly = isAnomaly(packetInstance);
				} catch (Exception e) {
					// Packets which cannot be clustered at all are reported as anomalies as well
					anomaly = true;
				}
				if (anomaly) {
					System.out.println("Anomaly: " + packet.getTextualRepresentation());
				}
			}
		}
	}

	/**
	 * Switches between training (false) and testing (true) mode.
	 * Switching to testing mode builds the clusterer on the packets collected so far.
	 *
	 * @param testing true, if the detector should classify packets from now on
	 */
	@Override
	public void setMode(boolean testing) {
		training = !testing;
		if (testing) {
			// Build clusterer
			try {
				finishDataCollection();
			} catch (Exception e) {
				System.out.println("Clustering failed");
				e.printStackTrace();
			}
		}
	}

	/**
	 * @return true, if the detector is in testing mode
	 */
	@Override
	public boolean getMode() {
		return !training;
	}
}