Browse Source

First AnomalyDetection Version

* All numerical -> probably bad distinction between names
 -> Nominal Values
Andreas T. Meyer-Berg 4 years ago
parent
commit
80966d0d44
1 changed files with 237 additions and 19 deletions
  1. 237 19
      examples/UnsupervisedAnomalyDetectionExample.java

+ 237 - 19
examples/UnsupervisedAnomalyDetectionExample.java

@@ -1,10 +1,15 @@
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.LinkedList;
+import java.util.Map.Entry;
 
 import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
 import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
 import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
 import weka.clusterers.SimpleKMeans;
+import weka.core.Attribute;
+import weka.core.DenseInstance;
 import weka.core.Instance;
 import weka.core.Instances;
 
@@ -25,55 +30,268 @@ public class UnsupervisedAnomalyDetectionExample implements PacketSniffer {
 	 */
 	private boolean training = true;
 	
+	/**
+	 * Attributes which should be taken into account
+	 */
+	private ArrayList<Attribute> atts = new ArrayList<Attribute>();
+	
+	/**
+	 * Collected Packets
+	 */
+	private Instances dataset;
+	
+	/**
+	 * HashMap for calculating transmission delay
+	 */
+	private HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
+	
+	/**
+	 * Number of Clusters
+	 */
+	private int NUMBER_OF_CLUSTERS = 2;
+	
+	/**
+	 * Number of packets used for number of packets per second
+	 */
+	private int NUMBER_OF_PACKETS = 30;
+	
+	/**
+	 * 
+	 */
+	private HashMap<String,Integer> link_mappings = new HashMap<String, Integer>();
+
+	private HashMap<String,Integer> source_mappings = new HashMap<String, Integer>();
+	
+	private HashMap<String,Integer> destination_mappings = new HashMap<String, Integer>();
+	
+	private HashMap<String,Integer> protocol_mappings = new HashMap<String, Integer>();
 	/**
 	 * 
 	 */
 	public UnsupervisedAnomalyDetectionExample() {
+		// Initialize Attribute list
+		link_mappings.put("unknown", 0);
+		atts.add(new Attribute("Link-Name", false));//TODO:??
+		source_mappings.put("unknown", 0);
+		atts.add(new Attribute("Source-Device", false));
+		atts.add(new Attribute("Source-Port-number", false));
+		destination_mappings.put("unknown", 0);
+		atts.add(new Attribute("Destination-Device", false));
+		atts.add(new Attribute("Destination-Port-number", false));
+		protocol_mappings.put("unknown", 0);
+		atts.add(new Attribute("Protocol-name", false));
+		atts.add(new Attribute("Packets-per-second", false));
+		// Initialize data set
+		dataset = new Instances("Packets", atts, 100000);
+		// Initialize Clusterer
 		clusterer = new SimpleKMeans();
 		clusterer.setSeed(42);
 		try {
-			clusterer.setNumClusters(20);
+			clusterer.setNumClusters(NUMBER_OF_CLUSTERS);
 		} catch (Exception e) {
 			System.out.println("Error while building cluster");
 			e.printStackTrace();
 		}
 	}
+	
 	@Override
 	public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
-		if(!packets.entrySet().isEmpty() && packets.entrySet().iterator().next().getValue().getFirst().getTimestamp()>10000)
+		if(!packets.entrySet().isEmpty() && packets.entrySet().iterator().next().getValue().getFirst().getTimestamp()>10000) {
 			training = false;
-			
-		Instances processed = preProcess(packets);
+			// Build Clusterer
+			try {
+				finishDataCollection();
+			} catch (Exception e) {
+				System.out.println("Clustering failed");
+				e.printStackTrace();
+			}
+		}
+		
 		if(training)
 			try {
-				training(processed);
+				training(packets);
 			} catch (Exception e) {
-				// TODO Auto-generated catch block
 				e.printStackTrace();
 			}
 		else
-			classify(processed);
+			classify(packets);
+	}
+	/**
+	 * Estimates the current Packets per second (depending on the last 100 packets of the link)
+	 * @param link Link which should be checked
+	 * @param packet Packet which should investigated
+	 * @return estimated number of packets per second
+	 */
+	private double getEstimatedPacketsPerSecond(Link link, Packet packet) {
+		/**
+		 * Packets used to calculated the packets per second
+		 */
+		LinkedList<Packet> list = lastPackets.get(link);
+		if(list == null) {
+			/**
+			 * Add list if not present
+			 */
+			list = new LinkedList<Packet>();
+			lastPackets.put(link, list);
+		}
+		if(list.isEmpty()) {
+			list.addLast(packet);
+			// Default 1 packet per second
+			return 1.0;
+		}
+		if(list.size() == NUMBER_OF_PACKETS){
+			list.removeFirst();	
+		}
+		list.addLast(packet);
+		/**
+		 * elapsed time in milliseconds since last packet
+		 */
+		long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
+		if(elapsed_time<=0)
+			return Double.POSITIVE_INFINITY;
+		/**
+		 * Return number of packets per second
+		 */
+		return 1000.0/elapsed_time;
+		
 	}
 	
-	private Instances preProcess(HashMap<Link, LinkedList<Packet>> packets) {
+	/**
+	 * Returns the instance representation of the given packet and link
+	 * @param link link the packet was sent on
+	 * @param packet packet which should be transformed
+	 * @param dataset distribution the packet is part of
+	 * @return instance representation
+	 */
+	private Instance packet2Instance(Link link, Packet packet, Instances dataset) {
+		/**
+		 * Instance for the given Packet
+		 */
+		DenseInstance instance = new DenseInstance(dataset.numAttributes());
+		instance.setDataset(dataset);
+		
+		// link
+		instance.setValue(0, link == null ? 0 : stringToNumber(link_mappings, link.getName()));
+		
+		// source
+		if(packet.getSource()==null) {
+			instance.setValue(1, 0);
+			instance.setValue(2, Double.NEGATIVE_INFINITY);
+		}else if(packet.getSource().getOwner()==null){
+			instance.setValue(1, 0);
+			instance.setValue(2, packet.getSource().getPortNumber());
+		}else {
+			instance.setValue(1, stringToNumber(source_mappings, packet.getSource().getOwner().getName()));
+			instance.setValue(2, packet.getSource().getPortNumber());
+		}
 		
+		// Destination
+		if(packet.getDestination()==null) {
+			instance.setValue(3, 0);
+			instance.setValue(4, Double.NEGATIVE_INFINITY);
+		}else if(packet.getDestination().getOwner()==null){
+			instance.setValue(3, 0);
+			instance.setValue(4, packet.getDestination().getPortNumber());
+		}else {
+			instance.setValue(3, stringToNumber(destination_mappings, packet.getDestination().getOwner().getName()));
+			instance.setValue(4, packet.getDestination().getPortNumber());
+		}
+		
+		// Protocol name
+		instance.setValue(5, stringToNumber(protocol_mappings, packet.getProtocolName()));
 		
-		return null;
+		// Packets per second
+		instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
 		
+		return instance;
 	}
 	
-	private void training(Instances processed) throws Exception {
-		clusterer.buildClusterer(processed);
+	/**
+	 * Transforms the String into an Number
+	 * @param map
+	 * @param s
+	 * @return
+	 */
+	double stringToNumber(HashMap<String, Integer> map, String s) {
+		Integer i = map.get(s);
+		if(i == null) {
+			int size = map.size();
+			map.put(s, size);
+			return size;
+		}else {
+			return i;
+		}
+	}
+	/**
+	 * Train the clusterer by collecting the packets
+	 * 
+	 * @param packets packets to be learned
+	 */
+	private void training(HashMap<Link, LinkedList<Packet>> packets) {
+		for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
+			Entry<Link, LinkedList<Packet>> entry = it.next();
+			/**
+			 * Link the packet was captured on
+			 */
+			Link l = entry.getKey();
+			for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
+				/**
+				 * Packets to be added to the dataset
+				 */
+				Packet packet = (Packet) itPacket.next();
+				dataset.add(packet2Instance(l, packet, dataset));
+			}
+		}
 	}
 	
-	private void classify(Instances processed) {
-		for(Instance i:processed)
-			try {
-				clusterer.clusterInstance(i);
-			} catch (Exception e) {
-				System.out.println("Anomaly "+i);
-				e.printStackTrace();
+	/**
+	 * Finishes the collection and trains the clusterer on the collected packets
+	 * 
+	 * @throws Exception
+	 */
+	private void finishDataCollection() throws Exception{
+		/**
+		 * Build the clusterer for the given dataset
+		 */
+		clusterer.buildClusterer(dataset);
+	}
+	
+	/**
+	 * Try to classify the given packets and detect anomalies
+	 * @param packets packets to be classified
+	 */
+	private void classify(HashMap<Link, LinkedList<Packet>> packets) {
+		for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
+			/**
+			 * Link & its packets
+			 */
+			Entry<Link, LinkedList<Packet>> entry = it.next();
+			/**
+			 * Link the packets were captured on
+			 */
+			Link l = entry.getKey();
+			for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
+				/**
+				 * Packet which should be checked
+				 */
+				Packet packet = (Packet) itPacket.next();
+				/**
+				 * Instance Representation
+				 */
+				Instance packet_instance = packet2Instance(l, packet, dataset);
+				try {
+					/**
+					 * Try to classify (find appropriate cluster)
+					 */
+					clusterer.clusterInstance(packet_instance);
+				} catch (Exception e) {
+					/**
+					 * Anomaly found
+					 */
+					System.out.println("Anomaly: "+packet.getTextualRepresentation());
+					//e.printStackTrace();
+				}
 			}
+		}
 	}
-
 }