123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- package classifier;
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.HashSet;
- import java.util.Iterator;
- import java.util.LinkedList;
- import java.util.Map.Entry;
- import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
- import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
- import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
- import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.protocols.packets.MQTTpublishPacket;
- import weka.core.Attribute;
- import weka.core.DenseInstance;
- import weka.core.Instance;
- import weka.core.Instances;
- /**
- * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s.
- *
- * @author Andreas T. Meyer-Berg
- */
- public abstract class BasicPacketClassifier implements PacketSniffer {
- /**
- * True, if instances should be used for training
- */
- protected boolean training = true;
-
- /**
- * Attributes which should be taken into account
- */
- protected ArrayList<Attribute> atts = new ArrayList<Attribute>();
-
- /**
- * Collected Packets
- */
- protected Instances dataset;
-
- /**
- * CollectedPackets
- */
- protected HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
-
- /**
- * HashMap for calculating transmission delay
- */
- protected HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
-
- /**
- * Map for the different Link names
- */
- protected HashSet<String> link_mappings = new HashSet<String>();
- /**
- * Map for the difference source device names
- */
- protected HashSet<String> source_mappings = new HashSet<String>();
-
- /**
- * Map for the different destination device names
- */
- protected HashSet<String> destination_mappings = new HashSet<String>();
-
- /**
- * Map for the protocol names
- */
- protected HashSet<String> protocol_mappings = new HashSet<String>();
- /**
- * Number of packets which are used to calculate the current transmission speed
- */
- protected int NUMBER_OF_PACKETS = 200;
-
- /**
- * Initializes the different maps
- */
- public BasicPacketClassifier() {
- // Initialize Attribute list
- source_mappings.add("unknown");
- link_mappings.add("unknown");
- destination_mappings.add("unknown");
- protocol_mappings.add("unknown");
- }
-
- @Override
- public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
- if(training)
- try {
- training(packets);
- } catch (Exception e) {
- e.printStackTrace();
- }
- else
- classify(packets);
- }
-
- /**
- * Estimates the current Packets per second (depending on the last 100 packets of the link)
- * @param link Link which should be checked
- * @param packet Packet which should investigated
- * @return estimated number of packets per second
- */
- protected double getEstimatedPacketsPerSecond(Link link, Packet packet) {
- /**
- * Packets used to calculated the packets per second
- */
- LinkedList<Packet> list = lastPackets.get(link);
- if(list == null) {
- /**
- * Add list if not present
- */
- list = new LinkedList<Packet>();
- lastPackets.put(link, list);
- }
- if(list.isEmpty()) {
- list.addLast(packet);
- // Default 1 packet per second
- return 1.0;
- }
- if(list.size() == NUMBER_OF_PACKETS){
- list.removeFirst();
- }
- list.addLast(packet);
- /**
- * elapsed time in milliseconds since last packet
- */
- long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
- if(elapsed_time<=0)
- return Double.POSITIVE_INFINITY;
- /**
- * Return number of packets per second
- */
- return 1000.0/elapsed_time;
-
- }
-
- /**
- * Returns the instance representation of the given packet and link
- * @param link link the packet was sent on
- * @param packet packet which should be transformed
- * @param dataset distribution the packet is part of
- * @return instance representation
- */
- protected Instance packet2Instance(Link link, Packet packet, Instances dataset) {
- /**
- * Instance for the given Packet
- */
- DenseInstance instance = new DenseInstance(dataset.numAttributes());
- instance.setDataset(dataset);
-
- // link
- instance.setValue(0, stringToNominal(link_mappings, link.getName()));
-
- // source
- if(packet.getSource()==null) {
- instance.setValue(1, "unknown");
- instance.setValue(2, Double.NEGATIVE_INFINITY);
- }else if(packet.getSource().getOwner()==null){
- instance.setValue(1, "unknown");
- instance.setValue(2, packet.getSource().getPortNumber());
- }else {
- instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
- instance.setValue(2, packet.getSource().getPortNumber());
- }
-
- // Destination
- if(packet.getDestination()==null) {
- instance.setValue(3, "unknown");
- instance.setValue(4, Double.NEGATIVE_INFINITY);
- }else if(packet.getDestination().getOwner()==null){
- instance.setValue(3, "unknown");
- instance.setValue(4, packet.getDestination().getPortNumber());
- }else {
- instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
- instance.setValue(4, packet.getDestination().getPortNumber());
- }
-
- // Protocol name
- instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
-
- // Packets per second
- instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
- // MQTT Value
- if(packet instanceof MQTTpublishPacket)
- instance.setValue(7, ((MQTTpublishPacket)packet).getValue());
- else
- instance.setValue(7, -1);
-
- return instance;
- }
-
- /**
- * Inserts the
- * @param map
- * @param nominal
- */
- protected void insertNominalIntoMap(HashSet<String> map, String nominal) {
- if(map == null || nominal == null)
- return;
- map.add(nominal);
- }
- /**
- * Transforms the String into an Number
- * @param map
- * @param s
- * @return
- */
- protected String stringToNominal(HashSet<String> map, String s) {
- return map.contains(s)?s:"unknown";
- }
-
- /**
- * Train the clusterer by collecting the packets
- *
- * @param packets packets to be learned
- */
- protected void training(HashMap<Link, LinkedList<Packet>> packets) {
- for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
- Link l = e.getKey();
- // TODO: ERROR ????????
- LinkedList<Packet> p = collectedPackets.get(l);
- if(p == null) {
- collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
- } else
- p.addAll(e.getValue());
- insertNominalIntoMap(link_mappings, l.getName());
- for(Packet pac: e.getValue()) {
- if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null)
- continue;
- insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName());
- insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName());
- insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName());
- insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName());
- insertNominalIntoMap(protocol_mappings, pac.getProtocolName());
- }
- //TODO: Add packet/Link/Names etc. to mappings
- }
- }
-
- /**
- * Finishes the collection and trains the clusterer on the collected packets
- *
- * @throws Exception
- */
- protected void finishDataCollection() throws Exception{
- /**
- printHashSet("Link-Name", link_mappings);
- printHashSet("Source-Device", source_mappings);
- printHashSet("Destination-Port", destination_mappings);
- printHashSet("Protocol-name", protocol_mappings);
- */
- atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
- atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
- atts.add(new Attribute("Source-Port-number", false));
- atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
- atts.add(new Attribute("Destination-Port-number", false));
- Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
- //pn.setWeight(10);
- atts.add(pn);
- Attribute pps = new Attribute("Packets-per-second", false);
- //pps.setWeight(20);
- atts.add(pps);
- atts.add(new Attribute("PacketValue", false));
- //atts.add(new Attribute("Anomaly", false));
- /*
- atts = new ArrayList<Attribute>();
- atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
- atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
- atts.add(new Attribute("SPN", false));
- atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
- atts.add(new Attribute("DPN", false));
- atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
- atts.add(new Attribute("PPS", false));
- atts.add(new Attribute("A", false));*/
- dataset = new Instances("Packets", atts, 100000);
- //dataset.setClassIndex(7);
- /**
- * Add Instances to dataset
- */
- for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
- Entry<Link, LinkedList<Packet>> entry = it.next();
- /**
- * Link the packet was captured on
- */
- Link l = entry.getKey();
- for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
- /**
- * Packets to be added to the dataset
- */
- Packet packet = (Packet) itPacket.next();
- dataset.add(packet2Instance(l, packet, dataset));
- }
- }
-
- trainModel(dataset);
- }
-
- private void printHashSet(String name, HashSet<String> toPrint) {
- System.out.println(name+":");
- for (Iterator<String> iterator = toPrint.iterator(); iterator.hasNext();) {
- String string = (String) iterator.next();
- System.out.print(string);
- if(iterator.hasNext())
- System.out.print(", ");
- }
- System.out.println();
- }
- /**
- * Try to classify the given packets and detect anomalies
- * @param packets packets to be classified
- */
- protected void classify(HashMap<Link, LinkedList<Packet>> packets) {
- int tp = 0;
- int fp = 0;
- int tn = 0;
- int fn = 0;
- long start = Long.MAX_VALUE;
- long end = Long.MIN_VALUE;
- for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
- /**
- * Link & its packets
- */
- Entry<Link, LinkedList<Packet>> entry = it.next();
- /**
- * Link the packets were captured on
- */
- Link l = entry.getKey();
- for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
- /**
- * Packet which should be checked
- */
- Packet packet = (Packet) itPacket.next();
- start = Math.min(start, packet.getTimestamp());
- end = Math.max(end, packet.getTimestamp());
- /**
- * Instance Representation
- */
- Instance packet_instance = packet2Instance(l, packet, dataset);
-
- if(packet_instance == null)continue;
- try {
- double dist = classifyInstance(packet_instance, packet);
- if(dist<=1.0) {
- if(packet.getLabel()==0)
- tn++;
- else
- fn++;
- }else {
- if(packet.getLabel()==0)
- fp++;
- else
- tp++;
- }
- } catch (Exception e) {
- if(packet.getLabel()==0)
- fp++;
- else
- tp++;
- }
- }
- }
- int n = tp+tn+fp+fn;
- if(n!=0) {
- System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms]");
- System.out.println("n: "+n);
- System.out.println("TP: "+tp);
- System.out.println("FP: "+fp);
- System.out.println("TN: "+tn);
- System.out.println("FN: "+fn);
- System.out.println("TPR: "+(tp/(tp+fn+0.0)));
- System.out.println("FPR: "+(fp/(fp+tn+0.0)));
- System.out.println("");
- }
- }
-
- /**
- * Train the model using the given instances
- * @param instances training set, which should be learned
- */
- public abstract void trainModel(Instances instances);
-
- /**
- * classifies the given instance
- * @param instance instance which should be classified
- * @param origin original packet, which was transformed into the instance
- * @return distance to next centroid
- * @throws Exception if anomaly was detected
- */
- public abstract double classifyInstance(Instance instance, Packet origin) throws Exception;
-
- /**
- * Returns the timestep, after which the classifier should start classifying instead of training.
- * @return timestep of the testing begin.
- */
- public abstract long getClassificationStart();
-
- @Override
- public void setMode(boolean testing) {
- training = !testing;
- if(testing) {
- try {
- finishDataCollection();
- } catch (Exception e) {
- System.out.println("Clustering failed");
- e.printStackTrace();
- }
- }
- }
-
- @Override
- public boolean getMode() {
- return !training;
- }
-
- /**
- * Short String representation of the classifier
- * @return
- */
- public abstract String getAlgoName();
- }
|