package classifier; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.Map.Entry; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.protocols.packets.MQTTpublishPacket; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; /** * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s. * * @author Andreas T. Meyer-Berg */ public abstract class BasicPacketClassifier implements PacketSniffer { /** * True, if instances should be used for training */ protected boolean training = true; /** * Attributes which should be taken into account */ protected ArrayList atts = new ArrayList(); /** * Collected Packets */ protected Instances dataset; /** * CollectedPackets */ protected HashMap> collectedPackets = new HashMap>(); /** * HashMap for calculating transmission delay */ protected HashMap> lastPackets = new HashMap>(); /** * Map for the different Link names */ protected HashSet link_mappings = new HashSet(); /** * Map for the difference source device names */ protected HashSet source_mappings = new HashSet(); /** * Map for the different destination device names */ protected HashSet destination_mappings = new HashSet(); /** * Map for the protocol names */ protected HashSet protocol_mappings = new HashSet(); /** * Number of packets which are used to calculate the current transmission speed */ protected int NUMBER_OF_PACKETS = 200; /** * Initializes the different maps */ public BasicPacketClassifier() { // Initialize Attribute list source_mappings.add("unknown"); link_mappings.add("unknown"); destination_mappings.add("unknown"); protocol_mappings.add("unknown"); } @Override public void processPackets(HashMap> packets) { if(training) try { training(packets); } catch (Exception e) { e.printStackTrace(); } else classify(packets); } /** * Estimates the current Packets per second (depending on the last 100 packets of the link) * @param link Link which should be checked * @param packet Packet which should investigated * @return estimated number of packets per second */ protected double getEstimatedPacketsPerSecond(Link link, Packet packet) { /** * Packets used to calculated the packets per second */ LinkedList list = lastPackets.get(link); if(list == null) { /** * Add list if not present */ list = new LinkedList(); lastPackets.put(link, list); } if(list.isEmpty()) { list.addLast(packet); // Default 1 packet per second return 1.0; } if(list.size() == NUMBER_OF_PACKETS){ list.removeFirst(); } list.addLast(packet); /** * elapsed time in milliseconds since last packet */ long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size(); if(elapsed_time<=0) return Double.POSITIVE_INFINITY; /** * Return number of packets per second */ return 1000.0/elapsed_time; } /** * Returns the instance representation of the given packet and link * @param link link the packet was sent on * @param packet packet which should be transformed * @param dataset distribution the packet is part of * @return instance representation */ protected Instance packet2Instance(Link link, Packet packet, Instances dataset) { /** * Instance for the given Packet */ DenseInstance instance = new DenseInstance(dataset.numAttributes()); instance.setDataset(dataset); // link instance.setValue(0, stringToNominal(link_mappings, link.getName())); // source if(packet.getSource()==null) { instance.setValue(1, "unknown"); instance.setValue(2, Double.NEGATIVE_INFINITY); }else if(packet.getSource().getOwner()==null){ instance.setValue(1, "unknown"); instance.setValue(2, packet.getSource().getPortNumber()); }else { instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName())); instance.setValue(2, packet.getSource().getPortNumber()); } // Destination if(packet.getDestination()==null) { instance.setValue(3, "unknown"); instance.setValue(4, Double.NEGATIVE_INFINITY); }else if(packet.getDestination().getOwner()==null){ instance.setValue(3, "unknown"); instance.setValue(4, packet.getDestination().getPortNumber()); }else { instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName())); instance.setValue(4, packet.getDestination().getPortNumber()); } // Protocol name instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName())); // Packets per second instance.setValue(6, getEstimatedPacketsPerSecond(link, packet)); // MQTT Value if(packet instanceof MQTTpublishPacket) instance.setValue(7, ((MQTTpublishPacket)packet).getValue()); else instance.setValue(7, -1); return instance; } /** * Inserts the * @param map * @param nominal */ protected void insertNominalIntoMap(HashSet map, String nominal) { if(map == null || nominal == null) return; map.add(nominal); } /** * Transforms the String into an Number * @param map * @param s * @return */ protected String stringToNominal(HashSet map, String s) { return map.contains(s)?s:"unknown"; } /** * Train the clusterer by collecting the packets * * @param packets packets to be learned */ protected void training(HashMap> packets) { for(Entry> e:packets.entrySet()) { Link l = e.getKey(); // TODO: ERROR ???????? LinkedList p = collectedPackets.get(l); if(p == null) { collectedPackets.put(l, new LinkedList(e.getValue())); } else p.addAll(e.getValue()); insertNominalIntoMap(link_mappings, l.getName()); for(Packet pac: e.getValue()) { if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null) continue; insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName()); insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName()); insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName()); insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName()); insertNominalIntoMap(protocol_mappings, pac.getProtocolName()); } //TODO: Add packet/Link/Names etc. to mappings } } /** * Finishes the collection and trains the clusterer on the collected packets * * @throws Exception */ protected void finishDataCollection() throws Exception{ /** printHashSet("Link-Name", link_mappings); printHashSet("Source-Device", source_mappings); printHashSet("Destination-Port", destination_mappings); printHashSet("Protocol-name", protocol_mappings); */ atts.add(new Attribute("Link-Name", new LinkedList(link_mappings)));//TODO:?? atts.add(new Attribute("Source-Device", new LinkedList(source_mappings))); atts.add(new Attribute("Source-Port-number", false)); atts.add(new Attribute("Destination-Device", new LinkedList(destination_mappings))); atts.add(new Attribute("Destination-Port-number", false)); Attribute pn = new Attribute("Protocol-name", new LinkedList(protocol_mappings)); //pn.setWeight(10); atts.add(pn); Attribute pps = new Attribute("Packets-per-second", false); //pps.setWeight(20); atts.add(pps); atts.add(new Attribute("PacketValue", false)); //atts.add(new Attribute("Anomaly", false)); /* atts = new ArrayList(); atts.add(new Attribute("LN", new LinkedList(link_mappings)));//TODO:?? atts.add(new Attribute("SD", new LinkedList(source_mappings))); atts.add(new Attribute("SPN", false)); atts.add(new Attribute("DD", new LinkedList(destination_mappings))); atts.add(new Attribute("DPN", false)); atts.add(new Attribute("PN", new LinkedList(protocol_mappings))); atts.add(new Attribute("PPS", false)); atts.add(new Attribute("A", false));*/ dataset = new Instances("Packets", atts, 100000); //dataset.setClassIndex(7); /** * Add Instances to dataset */ for (Iterator>> it = collectedPackets.entrySet().iterator(); it.hasNext();) { Entry> entry = it.next(); /** * Link the packet was captured on */ Link l = entry.getKey(); for (Iterator itPacket = entry.getValue().iterator(); itPacket.hasNext();) { /** * Packets to be added to the dataset */ Packet packet = (Packet) itPacket.next(); dataset.add(packet2Instance(l, packet, dataset)); } } trainModel(dataset); } private void printHashSet(String name, HashSet toPrint) { System.out.println(name+":"); for (Iterator iterator = toPrint.iterator(); iterator.hasNext();) { String string = (String) iterator.next(); System.out.print(string); if(iterator.hasNext()) System.out.print(", "); } System.out.println(); } /** * Try to classify the given packets and detect anomalies * @param packets packets to be classified */ protected void classify(HashMap> packets) { int tp = 0; int fp = 0; int tn = 0; int fn = 0; long start = Long.MAX_VALUE; long end = Long.MIN_VALUE; for (Iterator>> it = packets.entrySet().iterator(); it.hasNext();) { /** * Link & its packets */ Entry> entry = it.next(); /** * Link the packets were captured on */ Link l = entry.getKey(); for (Iterator itPacket = entry.getValue().iterator(); itPacket.hasNext();) { /** * Packet which should be checked */ Packet packet = (Packet) itPacket.next(); start = Math.min(start, packet.getTimestamp()); end = Math.max(end, packet.getTimestamp()); /** * Instance Representation */ Instance packet_instance = packet2Instance(l, packet, dataset); if(packet_instance == null)continue; try { double dist = classifyInstance(packet_instance, packet); if(dist<=1.0) { if(packet.getLabel()==0) tn++; else fn++; }else { if(packet.getLabel()==0) fp++; else tp++; } } catch (Exception e) { if(packet.getLabel()==0) fp++; else tp++; } } } int n = tp+tn+fp+fn; if(n!=0) { System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms]"); System.out.println("n: "+n); System.out.println("TP: "+tp); System.out.println("FP: "+fp); System.out.println("TN: "+tn); System.out.println("FN: "+fn); System.out.println("TPR: "+(tp/(tp+fn+0.0))); System.out.println("FPR: "+(fp/(fp+tn+0.0))); System.out.println(""); } } /** * Train the model using the given instances * @param instances training set, which should be learned */ public abstract void trainModel(Instances instances); /** * classifies the given instance * @param instance instance which should be classified * @param origin original packet, which was transformed into the instance * @return distance to next centroid * @throws Exception if anomaly was detected */ public abstract double classifyInstance(Instance instance, Packet origin) throws Exception; /** * Returns the timestep, after which the classifier should start classifying instead of training. * @return timestep of the testing begin. */ public abstract long getClassificationStart(); @Override public void setMode(boolean testing) { training = !testing; if(testing) { try { finishDataCollection(); } catch (Exception e) { System.out.println("Clustering failed"); e.printStackTrace(); } } } @Override public boolean getMode() { return !training; } /** * Short String representation of the classifier * @return */ public abstract String getAlgoName(); }