|
@@ -1,10 +1,15 @@
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.HashMap;
|
|
|
+import java.util.Iterator;
|
|
|
import java.util.LinkedList;
|
|
|
+import java.util.Map.Entry;
|
|
|
|
|
|
import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
|
|
|
import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
|
|
|
import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
|
|
|
import weka.clusterers.SimpleKMeans;
|
|
|
+import weka.core.Attribute;
|
|
|
+import weka.core.DenseInstance;
|
|
|
import weka.core.Instance;
|
|
|
import weka.core.Instances;
|
|
|
|
|
@@ -25,55 +30,268 @@ public class UnsupervisedAnomalyDetectionExample implements PacketSniffer {
|
|
|
*/
|
|
|
private boolean training = true;
|
|
|
|
|
|
+ /**
|
|
|
+ * Attributes which should be taken into account
|
|
|
+ */
|
|
|
+ private ArrayList<Attribute> atts = new ArrayList<Attribute>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Collected Packets
|
|
|
+ */
|
|
|
+ private Instances dataset;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * HashMap for calculating transmission delay
|
|
|
+ */
|
|
|
+ private HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Number of Clusters
|
|
|
+ */
|
|
|
+ private int NUMBER_OF_CLUSTERS = 2;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Number of packets used for number of packets per second
|
|
|
+ */
|
|
|
+ private int NUMBER_OF_PACKETS = 30;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Numeric ids assigned to link names ("unknown" is 0), used as the Link-Name attribute value
|
|
|
+ */
|
|
|
+ private HashMap<String,Integer> link_mappings = new HashMap<String, Integer>();
|
|
|
+
|
|
|
+ private HashMap<String,Integer> source_mappings = new HashMap<String, Integer>();
|
|
|
+
|
|
|
+ private HashMap<String,Integer> destination_mappings = new HashMap<String, Integer>();
|
|
|
+
|
|
|
+ private HashMap<String,Integer> protocol_mappings = new HashMap<String, Integer>();
|
|
|
/**
|
|
|
*
|
|
|
*/
|
|
|
public UnsupervisedAnomalyDetectionExample() {
|
|
|
+ // Initialize Attribute list
|
|
|
+ link_mappings.put("unknown", 0);
|
|
|
+ atts.add(new Attribute("Link-Name", false));//TODO:??
|
|
|
+ source_mappings.put("unknown", 0);
|
|
|
+ atts.add(new Attribute("Source-Device", false));
|
|
|
+ atts.add(new Attribute("Source-Port-number", false));
|
|
|
+ destination_mappings.put("unknown", 0);
|
|
|
+ atts.add(new Attribute("Destination-Device", false));
|
|
|
+ atts.add(new Attribute("Destination-Port-number", false));
|
|
|
+ protocol_mappings.put("unknown", 0);
|
|
|
+ atts.add(new Attribute("Protocol-name", false));
|
|
|
+ atts.add(new Attribute("Packets-per-second", false));
|
|
|
+ // Initialize data set
|
|
|
+ dataset = new Instances("Packets", atts, 100000);
|
|
|
+ // Initialize Clusterer
|
|
|
clusterer = new SimpleKMeans();
|
|
|
clusterer.setSeed(42);
|
|
|
try {
|
|
|
- clusterer.setNumClusters(20);
|
|
|
+ clusterer.setNumClusters(NUMBER_OF_CLUSTERS);
|
|
|
} catch (Exception e) {
|
|
|
System.out.println("Error while building cluster");
|
|
|
e.printStackTrace();
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
@Override
|
|
|
public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
- if(!packets.entrySet().isEmpty() && packets.entrySet().iterator().next().getValue().getFirst().getTimestamp()>10000)
|
|
|
+ if(!packets.entrySet().isEmpty() && packets.entrySet().iterator().next().getValue().getFirst().getTimestamp()>10000) {
|
|
|
training = false;
|
|
|
-
|
|
|
- Instances processed = preProcess(packets);
|
|
|
+ // Build Clusterer
|
|
|
+ try {
|
|
|
+ finishDataCollection();
|
|
|
+ } catch (Exception e) {
|
|
|
+ System.out.println("Clustering failed");
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
if(training)
|
|
|
try {
|
|
|
- training(processed);
|
|
|
+ training(packets);
|
|
|
} catch (Exception e) {
|
|
|
- // TODO Auto-generated catch block
|
|
|
e.printStackTrace();
|
|
|
}
|
|
|
else
|
|
|
- classify(processed);
|
|
|
+ classify(packets);
|
|
|
+ }
|
|
|
+ /**
|
|
|
+ * Estimates the current Packets per second (depending on the last 100 packets of the link)
|
|
|
+ * @param link Link which should be checked
|
|
|
+ * @param packet Packet which should investigated
|
|
|
+ * @return estimated number of packets per second
|
|
|
+ */
|
|
|
+ private double getEstimatedPacketsPerSecond(Link link, Packet packet) {
|
|
|
+ /**
|
|
|
+ * Packets used to calculated the packets per second
|
|
|
+ */
|
|
|
+ LinkedList<Packet> list = lastPackets.get(link);
|
|
|
+ if(list == null) {
|
|
|
+ /**
|
|
|
+ * Add list if not present
|
|
|
+ */
|
|
|
+ list = new LinkedList<Packet>();
|
|
|
+ lastPackets.put(link, list);
|
|
|
+ }
|
|
|
+ if(list.isEmpty()) {
|
|
|
+ list.addLast(packet);
|
|
|
+ // Default 1 packet per second
|
|
|
+ return 1.0;
|
|
|
+ }
|
|
|
+ if(list.size() == NUMBER_OF_PACKETS){
|
|
|
+ list.removeFirst();
|
|
|
+ }
|
|
|
+ list.addLast(packet);
|
|
|
+ /**
|
|
|
+ * elapsed time in milliseconds since last packet
|
|
|
+ */
|
|
|
+ long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
|
|
|
+ if(elapsed_time<=0)
|
|
|
+ return Double.POSITIVE_INFINITY;
|
|
|
+ /**
|
|
|
+ * Return number of packets per second
|
|
|
+ */
|
|
|
+ return 1000.0/elapsed_time;
|
|
|
+
|
|
|
}
|
|
|
|
|
|
- private Instances preProcess(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
+ /**
|
|
|
+ * Returns the instance representation of the given packet and link
|
|
|
+ * @param link link the packet was sent on
|
|
|
+ * @param packet packet which should be transformed
|
|
|
+ * @param dataset distribution the packet is part of
|
|
|
+ * @return instance representation
|
|
|
+ */
|
|
|
+ private Instance packet2Instance(Link link, Packet packet, Instances dataset) {
|
|
|
+ /**
|
|
|
+ * Instance for the given Packet
|
|
|
+ */
|
|
|
+ DenseInstance instance = new DenseInstance(dataset.numAttributes());
|
|
|
+ instance.setDataset(dataset);
|
|
|
+
|
|
|
+ // link
|
|
|
+ instance.setValue(0, link == null ? 0 : stringToNumber(link_mappings, link.getName()));
|
|
|
+
|
|
|
+ // source
|
|
|
+ if(packet.getSource()==null) {
|
|
|
+ instance.setValue(1, 0);
|
|
|
+ instance.setValue(2, Double.NEGATIVE_INFINITY);
|
|
|
+ }else if(packet.getSource().getOwner()==null){
|
|
|
+ instance.setValue(1, 0);
|
|
|
+ instance.setValue(2, packet.getSource().getPortNumber());
|
|
|
+ }else {
|
|
|
+ instance.setValue(1, stringToNumber(source_mappings, packet.getSource().getOwner().getName()));
|
|
|
+ instance.setValue(2, packet.getSource().getPortNumber());
|
|
|
+ }
|
|
|
|
|
|
+ // Destination
|
|
|
+ if(packet.getDestination()==null) {
|
|
|
+ instance.setValue(3, 0);
|
|
|
+ instance.setValue(4, Double.NEGATIVE_INFINITY);
|
|
|
+ }else if(packet.getDestination().getOwner()==null){
|
|
|
+ instance.setValue(3, 0);
|
|
|
+ instance.setValue(4, packet.getDestination().getPortNumber());
|
|
|
+ }else {
|
|
|
+ instance.setValue(3, stringToNumber(destination_mappings, packet.getDestination().getOwner().getName()));
|
|
|
+ instance.setValue(4, packet.getDestination().getPortNumber());
|
|
|
+ }
|
|
|
+
|
|
|
+ // Protocol name
|
|
|
+ instance.setValue(5, stringToNumber(protocol_mappings, packet.getProtocolName()));
|
|
|
|
|
|
- return null;
|
|
|
+ // Packets per second
|
|
|
+ instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
|
|
|
|
|
|
+ return instance;
|
|
|
}
|
|
|
|
|
|
- private void training(Instances processed) throws Exception {
|
|
|
- clusterer.buildClusterer(processed);
|
|
|
+ /**
|
|
|
+ * Transforms the String into an Number
|
|
|
+ * @param map
|
|
|
+ * @param s
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ double stringToNumber(HashMap<String, Integer> map, String s) {
|
|
|
+ Integer i = map.get(s);
|
|
|
+ if(i == null) {
|
|
|
+ int size = map.size();
|
|
|
+ map.put(s, size);
|
|
|
+ return size;
|
|
|
+ }else {
|
|
|
+ return i;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ /**
|
|
|
+ * Train the clusterer by collecting the packets
|
|
|
+ *
|
|
|
+ * @param packets packets to be learned
|
|
|
+ */
|
|
|
+ private void training(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
+ for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
|
|
|
+ Entry<Link, LinkedList<Packet>> entry = it.next();
|
|
|
+ /**
|
|
|
+ * Link the packet was captured on
|
|
|
+ */
|
|
|
+ Link l = entry.getKey();
|
|
|
+ for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
|
|
|
+ /**
|
|
|
+ * Packets to be added to the dataset
|
|
|
+ */
|
|
|
+ Packet packet = (Packet) itPacket.next();
|
|
|
+ dataset.add(packet2Instance(l, packet, dataset));
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- private void classify(Instances processed) {
|
|
|
- for(Instance i:processed)
|
|
|
- try {
|
|
|
- clusterer.clusterInstance(i);
|
|
|
- } catch (Exception e) {
|
|
|
- System.out.println("Anomaly "+i);
|
|
|
- e.printStackTrace();
|
|
|
+ /**
|
|
|
+ * Finishes the collection and trains the clusterer on the collected packets
|
|
|
+ *
|
|
|
+ * @throws Exception
|
|
|
+ */
|
|
|
+ private void finishDataCollection() throws Exception{
|
|
|
+ /**
|
|
|
+ * Build the clusterer for the given dataset
|
|
|
+ */
|
|
|
+ clusterer.buildClusterer(dataset);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Try to classify the given packets and detect anomalies
|
|
|
+ * @param packets packets to be classified
|
|
|
+ */
|
|
|
+ private void classify(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
+ for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
|
|
|
+ /**
|
|
|
+ * Link & its packets
|
|
|
+ */
|
|
|
+ Entry<Link, LinkedList<Packet>> entry = it.next();
|
|
|
+ /**
|
|
|
+ * Link the packets were captured on
|
|
|
+ */
|
|
|
+ Link l = entry.getKey();
|
|
|
+ for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
|
|
|
+ /**
|
|
|
+ * Packet which should be checked
|
|
|
+ */
|
|
|
+ Packet packet = (Packet) itPacket.next();
|
|
|
+ /**
|
|
|
+ * Instance Representation
|
|
|
+ */
|
|
|
+ Instance packet_instance = packet2Instance(l, packet, dataset);
|
|
|
+ try {
|
|
|
+ /**
|
|
|
+ * Try to classify (find appropriate cluster)
|
|
|
+ */
|
|
|
+ clusterer.clusterInstance(packet_instance);
|
|
|
+ } catch (Exception e) {
|
|
|
+ /**
|
|
|
+ * Anomaly found
|
|
|
+ */
|
|
|
+ System.out.println("Anomaly: "+packet.getTextualRepresentation());
|
|
|
+ //e.printStackTrace();
|
|
|
+ }
|
|
|
}
|
|
|
+ }
|
|
|
}
|
|
|
-
|
|
|
}
|