|
@@ -0,0 +1,426 @@
|
|
|
+package classifier;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.HashSet;
|
|
|
+import java.util.Iterator;
|
|
|
+import java.util.LinkedList;
|
|
|
+import java.util.Map.Entry;
|
|
|
+
|
|
|
+import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
|
|
|
+import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
|
|
|
+import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
|
|
|
+import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.protocols.packets.MQTTpublishPacket;
|
|
|
+import weka.core.Attribute;
|
|
|
+import weka.core.DenseInstance;
|
|
|
+import weka.core.Instance;
|
|
|
+import weka.core.Instances;
|
|
|
+
|
|
|
+/**
|
|
|
+ * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s.
|
|
|
+ *
|
|
|
+ * @author Andreas T. Meyer-Berg
|
|
|
+ */
|
|
|
+public abstract class BasicPacketClassifier implements PacketSniffer {
|
|
|
+
|
|
|
+ /**
|
|
|
+ * True, if instances should be used for training
|
|
|
+ */
|
|
|
+ protected boolean training = true;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Attributes which should be taken into account
|
|
|
+ */
|
|
|
+ protected ArrayList<Attribute> atts = new ArrayList<Attribute>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Collected Packets
|
|
|
+ */
|
|
|
+ protected Instances dataset;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * CollectedPackets
|
|
|
+ */
|
|
|
+ protected HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * HashMap for calculating transmission delay
|
|
|
+ */
|
|
|
+ protected HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Map for the different Link names
|
|
|
+ */
|
|
|
+ protected HashSet<String> link_mappings = new HashSet<String>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Map for the difference source device names
|
|
|
+ */
|
|
|
+ protected HashSet<String> source_mappings = new HashSet<String>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Map for the different destination device names
|
|
|
+ */
|
|
|
+ protected HashSet<String> destination_mappings = new HashSet<String>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Map for the protocol names
|
|
|
+ */
|
|
|
+ protected HashSet<String> protocol_mappings = new HashSet<String>();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Number of packets which are used to calculate the current transmission speed
|
|
|
+ */
|
|
|
+ protected int NUMBER_OF_PACKETS = 200;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Initializes the different maps
|
|
|
+ */
|
|
|
+ public BasicPacketClassifier() {
|
|
|
+ // Initialize Attribute list
|
|
|
+ source_mappings.add("unknown");
|
|
|
+ link_mappings.add("unknown");
|
|
|
+ destination_mappings.add("unknown");
|
|
|
+ protocol_mappings.add("unknown");
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
+ if(training)
|
|
|
+ try {
|
|
|
+ training(packets);
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ else
|
|
|
+ classify(packets);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Estimates the current Packets per second (depending on the last 100 packets of the link)
|
|
|
+ * @param link Link which should be checked
|
|
|
+ * @param packet Packet which should investigated
|
|
|
+ * @return estimated number of packets per second
|
|
|
+ */
|
|
|
+ protected double getEstimatedPacketsPerSecond(Link link, Packet packet) {
|
|
|
+ /**
|
|
|
+ * Packets used to calculated the packets per second
|
|
|
+ */
|
|
|
+ LinkedList<Packet> list = lastPackets.get(link);
|
|
|
+ if(list == null) {
|
|
|
+ /**
|
|
|
+ * Add list if not present
|
|
|
+ */
|
|
|
+ list = new LinkedList<Packet>();
|
|
|
+ lastPackets.put(link, list);
|
|
|
+ }
|
|
|
+ if(list.isEmpty()) {
|
|
|
+ list.addLast(packet);
|
|
|
+ // Default 1 packet per second
|
|
|
+ return 1.0;
|
|
|
+ }
|
|
|
+ if(list.size() == NUMBER_OF_PACKETS){
|
|
|
+ list.removeFirst();
|
|
|
+ }
|
|
|
+ list.addLast(packet);
|
|
|
+ /**
|
|
|
+ * elapsed time in milliseconds since last packet
|
|
|
+ */
|
|
|
+ long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
|
|
|
+ if(elapsed_time<=0)
|
|
|
+ return Double.POSITIVE_INFINITY;
|
|
|
+ /**
|
|
|
+ * Return number of packets per second
|
|
|
+ */
|
|
|
+ return 1000.0/elapsed_time;
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Returns the instance representation of the given packet and link
|
|
|
+ * @param link link the packet was sent on
|
|
|
+ * @param packet packet which should be transformed
|
|
|
+ * @param dataset distribution the packet is part of
|
|
|
+ * @return instance representation
|
|
|
+ */
|
|
|
+ protected Instance packet2Instance(Link link, Packet packet, Instances dataset) {
|
|
|
+ /**
|
|
|
+ * Instance for the given Packet
|
|
|
+ */
|
|
|
+ DenseInstance instance = new DenseInstance(dataset.numAttributes());
|
|
|
+ instance.setDataset(dataset);
|
|
|
+
|
|
|
+ // link
|
|
|
+ instance.setValue(0, stringToNominal(link_mappings, link.getName()));
|
|
|
+
|
|
|
+ // source
|
|
|
+ if(packet.getSource()==null) {
|
|
|
+ instance.setValue(1, "unknown");
|
|
|
+ instance.setValue(2, Double.NEGATIVE_INFINITY);
|
|
|
+ }else if(packet.getSource().getOwner()==null){
|
|
|
+ instance.setValue(1, "unknown");
|
|
|
+ instance.setValue(2, packet.getSource().getPortNumber());
|
|
|
+ }else {
|
|
|
+ instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
|
|
|
+
|
|
|
+ instance.setValue(2, packet.getSource().getPortNumber());
|
|
|
+ }
|
|
|
+
|
|
|
+ // Destination
|
|
|
+ if(packet.getDestination()==null) {
|
|
|
+ instance.setValue(3, "unknown");
|
|
|
+ instance.setValue(4, Double.NEGATIVE_INFINITY);
|
|
|
+ }else if(packet.getDestination().getOwner()==null){
|
|
|
+ instance.setValue(3, "unknown");
|
|
|
+
|
|
|
+ instance.setValue(4, packet.getDestination().getPortNumber());
|
|
|
+ }else {
|
|
|
+ instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
|
|
|
+ instance.setValue(4, packet.getDestination().getPortNumber());
|
|
|
+ }
|
|
|
+
|
|
|
+ // Protocol name
|
|
|
+ instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
|
|
|
+
|
|
|
+ // Packets per second
|
|
|
+ instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
|
|
|
+ // MQTT Value
|
|
|
+ if(packet instanceof MQTTpublishPacket)
|
|
|
+ instance.setValue(7, ((MQTTpublishPacket)packet).getValue());
|
|
|
+ else
|
|
|
+ instance.setValue(7, -1);
|
|
|
+
|
|
|
+ return instance;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Inserts the
|
|
|
+ * @param map
|
|
|
+ * @param nominal
|
|
|
+ */
|
|
|
+ protected void insertNominalIntoMap(HashSet<String> map, String nominal) {
|
|
|
+ if(map == null || nominal == null)
|
|
|
+ return;
|
|
|
+ map.add(nominal);
|
|
|
+ }
|
|
|
+ /**
|
|
|
+ * Transforms the String into an Number
|
|
|
+ * @param map
|
|
|
+ * @param s
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ protected String stringToNominal(HashSet<String> map, String s) {
|
|
|
+ return map.contains(s)?s:"unknown";
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Train the clusterer by collecting the packets
|
|
|
+ *
|
|
|
+ * @param packets packets to be learned
|
|
|
+ */
|
|
|
+ protected void training(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
+ for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
|
|
|
+ Link l = e.getKey();
|
|
|
+ // TODO: ERROR ????????
|
|
|
+ LinkedList<Packet> p = collectedPackets.get(l);
|
|
|
+ if(p == null) {
|
|
|
+ collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
|
|
|
+ } else
|
|
|
+ p.addAll(e.getValue());
|
|
|
+ insertNominalIntoMap(link_mappings, l.getName());
|
|
|
+ for(Packet pac: e.getValue()) {
|
|
|
+ if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null)
|
|
|
+ continue;
|
|
|
+ insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName());
|
|
|
+ insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName());
|
|
|
+ insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName());
|
|
|
+ insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName());
|
|
|
+ insertNominalIntoMap(protocol_mappings, pac.getProtocolName());
|
|
|
+ }
|
|
|
+ //TODO: Add packet/Link/Names etc. to mappings
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Finishes the collection and trains the clusterer on the collected packets
|
|
|
+ *
|
|
|
+ * @throws Exception
|
|
|
+ */
|
|
|
+ protected void finishDataCollection() throws Exception{
|
|
|
+ /**
|
|
|
+ printHashSet("Link-Name", link_mappings);
|
|
|
+ printHashSet("Source-Device", source_mappings);
|
|
|
+ printHashSet("Destination-Port", destination_mappings);
|
|
|
+ printHashSet("Protocol-name", protocol_mappings);
|
|
|
+ */
|
|
|
+ atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
|
|
|
+ atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
|
|
|
+ atts.add(new Attribute("Source-Port-number", false));
|
|
|
+ atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
|
|
|
+ atts.add(new Attribute("Destination-Port-number", false));
|
|
|
+ Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
|
|
|
+ //pn.setWeight(10);
|
|
|
+ atts.add(pn);
|
|
|
+ Attribute pps = new Attribute("Packets-per-second", false);
|
|
|
+ //pps.setWeight(20);
|
|
|
+ atts.add(pps);
|
|
|
+ atts.add(new Attribute("PacketValue", false));
|
|
|
+ //atts.add(new Attribute("Anomaly", false));
|
|
|
+
|
|
|
+ /*
|
|
|
+ atts = new ArrayList<Attribute>();
|
|
|
+ atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
|
|
|
+ atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
|
|
|
+ atts.add(new Attribute("SPN", false));
|
|
|
+ atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
|
|
|
+ atts.add(new Attribute("DPN", false));
|
|
|
+ atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
|
|
|
+ atts.add(new Attribute("PPS", false));
|
|
|
+ atts.add(new Attribute("A", false));*/
|
|
|
+ dataset = new Instances("Packets", atts, 100000);
|
|
|
+ //dataset.setClassIndex(7);
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Add Instances to dataset
|
|
|
+ */
|
|
|
+ for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
|
|
|
+ Entry<Link, LinkedList<Packet>> entry = it.next();
|
|
|
+ /**
|
|
|
+ * Link the packet was captured on
|
|
|
+ */
|
|
|
+ Link l = entry.getKey();
|
|
|
+ for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
|
|
|
+ /**
|
|
|
+ * Packets to be added to the dataset
|
|
|
+ */
|
|
|
+ Packet packet = (Packet) itPacket.next();
|
|
|
+ dataset.add(packet2Instance(l, packet, dataset));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ trainModel(dataset);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void printHashSet(String name, HashSet<String> toPrint) {
|
|
|
+ System.out.println(name+":");
|
|
|
+ for (Iterator<String> iterator = toPrint.iterator(); iterator.hasNext();) {
|
|
|
+ String string = (String) iterator.next();
|
|
|
+ System.out.print(string);
|
|
|
+ if(iterator.hasNext())
|
|
|
+ System.out.print(", ");
|
|
|
+ }
|
|
|
+ System.out.println();
|
|
|
+ }
|
|
|
+ /**
|
|
|
+ * Try to classify the given packets and detect anomalies
|
|
|
+ * @param packets packets to be classified
|
|
|
+ */
|
|
|
+ protected void classify(HashMap<Link, LinkedList<Packet>> packets) {
|
|
|
+ int tp = 0;
|
|
|
+ int fp = 0;
|
|
|
+ int tn = 0;
|
|
|
+ int fn = 0;
|
|
|
+ long start = Long.MAX_VALUE;
|
|
|
+ long end = Long.MIN_VALUE;
|
|
|
+ for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
|
|
|
+ /**
|
|
|
+ * Link & its packets
|
|
|
+ */
|
|
|
+ Entry<Link, LinkedList<Packet>> entry = it.next();
|
|
|
+ /**
|
|
|
+ * Link the packets were captured on
|
|
|
+ */
|
|
|
+ Link l = entry.getKey();
|
|
|
+ for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
|
|
|
+ /**
|
|
|
+ * Packet which should be checked
|
|
|
+ */
|
|
|
+ Packet packet = (Packet) itPacket.next();
|
|
|
+
|
|
|
+ start = Math.min(start, packet.getTimestamp());
|
|
|
+ end = Math.max(end, packet.getTimestamp());
|
|
|
+ /**
|
|
|
+ * Instance Representation
|
|
|
+ */
|
|
|
+ Instance packet_instance = packet2Instance(l, packet, dataset);
|
|
|
+
|
|
|
+ if(packet_instance == null)continue;
|
|
|
+ try {
|
|
|
+ double dist = classifyInstance(packet_instance, packet);
|
|
|
+ if(dist<=1.0) {
|
|
|
+ if(packet.getLabel()==0)
|
|
|
+ tn++;
|
|
|
+ else
|
|
|
+ fn++;
|
|
|
+ }else {
|
|
|
+ if(packet.getLabel()==0)
|
|
|
+ fp++;
|
|
|
+ else
|
|
|
+ tp++;
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ if(packet.getLabel()==0)
|
|
|
+ fp++;
|
|
|
+ else
|
|
|
+ tp++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ int n = tp+tn+fp+fn;
|
|
|
+ if(n!=0) {
|
|
|
+ System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms]");
|
|
|
+ System.out.println("n: "+n);
|
|
|
+ System.out.println("TP: "+tp);
|
|
|
+ System.out.println("FP: "+fp);
|
|
|
+ System.out.println("TN: "+tn);
|
|
|
+ System.out.println("FN: "+fn);
|
|
|
+ System.out.println("TPR: "+(tp/(tp+fn+0.0)));
|
|
|
+ System.out.println("FPR: "+(fp/(fp+tn+0.0)));
|
|
|
+ System.out.println("");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Train the model using the given instances
|
|
|
+ * @param instances training set, which should be learned
|
|
|
+ */
|
|
|
+ public abstract void trainModel(Instances instances);
|
|
|
+
|
|
|
+ /**
|
|
|
+ * classifies the given instance
|
|
|
+ * @param instance instance which should be classified
|
|
|
+ * @param origin original packet, which was transformed into the instance
|
|
|
+ * @return distance to next centroid
|
|
|
+ * @throws Exception if anomaly was detected
|
|
|
+ */
|
|
|
+ public abstract double classifyInstance(Instance instance, Packet origin) throws Exception;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Returns the timestep, after which the classifier should start classifying instead of training.
|
|
|
+ * @return timestep of the testing begin.
|
|
|
+ */
|
|
|
+ public abstract long getClassificationStart();
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void setMode(boolean testing) {
|
|
|
+ training = !testing;
|
|
|
+ if(testing) {
|
|
|
+ try {
|
|
|
+ finishDataCollection();
|
|
|
+ } catch (Exception e) {
|
|
|
+ System.out.println("Clustering failed");
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean getMode() {
|
|
|
+ return !training;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Short String representation of the classifier
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ public abstract String getAlgoName();
|
|
|
+}
|