import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.Map.Entry; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer; import weka.clusterers.SimpleKMeans; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; /** * Unsupervised Example - maybe Clustering * * @author Andreas T. Meyer-Berg */ public class UnsupervisedAnomalyDetectionExample implements PacketSniffer { /** * Clusterer */ private SimpleKMeans clusterer; /** * True, if instances should be used for training */ private boolean training = true; /** * Attributes which should be taken into account */ private ArrayList atts = new ArrayList(); /** * Collected Packets */ private Instances dataset; /** * HashMap for calculating transmission delay */ private HashMap> lastPackets = new HashMap>(); /** * Number of Clusters */ private int NUMBER_OF_CLUSTERS = 2; /** * Number of packets used for number of packets per second */ private int NUMBER_OF_PACKETS = 30; /** * */ private HashMap link_mappings = new HashMap(); private HashMap source_mappings = new HashMap(); private HashMap destination_mappings = new HashMap(); private HashMap protocol_mappings = new HashMap(); /** * */ public UnsupervisedAnomalyDetectionExample() { // Initialize Attribute list link_mappings.put("unknown", 0); atts.add(new Attribute("Link-Name", false));//TODO:?? source_mappings.put("unknown", 0); atts.add(new Attribute("Source-Device", false)); atts.add(new Attribute("Source-Port-number", false)); destination_mappings.put("unknown", 0); atts.add(new Attribute("Destination-Device", false)); atts.add(new Attribute("Destination-Port-number", false)); protocol_mappings.put("unknown", 0); atts.add(new Attribute("Protocol-name", false)); atts.add(new Attribute("Packets-per-second", false)); // Initialize data set dataset = new Instances("Packets", atts, 100000); // Initialize Clusterer clusterer = new SimpleKMeans(); clusterer.setSeed(42); try { clusterer.setNumClusters(NUMBER_OF_CLUSTERS); } catch (Exception e) { System.out.println("Error while building cluster"); e.printStackTrace(); } } @Override public void processPackets(HashMap> packets) { if(training) try { training(packets); } catch (Exception e) { e.printStackTrace(); } else classify(packets); } /** * Estimates the current Packets per second (depending on the last 100 packets of the link) * @param link Link which should be checked * @param packet Packet which should investigated * @return estimated number of packets per second */ private double getEstimatedPacketsPerSecond(Link link, Packet packet) { /** * Packets used to calculated the packets per second */ LinkedList list = lastPackets.get(link); if(list == null) { /** * Add list if not present */ list = new LinkedList(); lastPackets.put(link, list); } if(list.isEmpty()) { list.addLast(packet); // Default 1 packet per second return 1.0; } if(list.size() == NUMBER_OF_PACKETS){ list.removeFirst(); } list.addLast(packet); /** * elapsed time in milliseconds since last packet */ long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size(); if(elapsed_time<=0) return Double.POSITIVE_INFINITY; /** * Return number of packets per second */ return 1000.0/elapsed_time; } /** * Returns the instance representation of the given packet and link * @param link link the packet was sent on * @param packet packet which should be transformed * @param dataset distribution the packet is part of * @return instance representation */ private Instance packet2Instance(Link link, Packet packet, Instances dataset) { /** * Instance for the given Packet */ DenseInstance instance = new DenseInstance(dataset.numAttributes()); instance.setDataset(dataset); // link instance.setValue(0, link == null ? 0 : stringToNumber(link_mappings, link.getName())); // source if(packet.getSource()==null) { instance.setValue(1, 0); instance.setValue(2, Double.NEGATIVE_INFINITY); }else if(packet.getSource().getOwner()==null){ instance.setValue(1, 0); instance.setValue(2, packet.getSource().getPortNumber()); }else { instance.setValue(1, stringToNumber(source_mappings, packet.getSource().getOwner().getName())); instance.setValue(2, packet.getSource().getPortNumber()); } // Destination if(packet.getDestination()==null) { instance.setValue(3, 0); instance.setValue(4, Double.NEGATIVE_INFINITY); }else if(packet.getDestination().getOwner()==null){ instance.setValue(3, 0); instance.setValue(4, packet.getDestination().getPortNumber()); }else { instance.setValue(3, stringToNumber(destination_mappings, packet.getDestination().getOwner().getName())); instance.setValue(4, packet.getDestination().getPortNumber()); } // Protocol name instance.setValue(5, stringToNumber(protocol_mappings, packet.getProtocolName())); // Packets per second instance.setValue(6, getEstimatedPacketsPerSecond(link, packet)); return instance; } /** * Transforms the String into an Number * @param map * @param s * @return */ double stringToNumber(HashMap map, String s) { Integer i = map.get(s); if(i == null) { int size = map.size(); map.put(s, size); return size; }else { return i; } } /** * Train the clusterer by collecting the packets * * @param packets packets to be learned */ private void training(HashMap> packets) { for (Iterator>> it = packets.entrySet().iterator(); it.hasNext();) { Entry> entry = it.next(); /** * Link the packet was captured on */ Link l = entry.getKey(); for (Iterator itPacket = entry.getValue().iterator(); itPacket.hasNext();) { /** * Packets to be added to the dataset */ Packet packet = (Packet) itPacket.next(); dataset.add(packet2Instance(l, packet, dataset)); } } } /** * Finishes the collection and trains the clusterer on the collected packets * * @throws Exception */ private void finishDataCollection() throws Exception{ /** * Build the clusterer for the given dataset */ clusterer.buildClusterer(dataset); } /** * Try to classify the given packets and detect anomalies * @param packets packets to be classified */ private void classify(HashMap> packets) { for (Iterator>> it = packets.entrySet().iterator(); it.hasNext();) { /** * Link & its packets */ Entry> entry = it.next(); /** * Link the packets were captured on */ Link l = entry.getKey(); for (Iterator itPacket = entry.getValue().iterator(); itPacket.hasNext();) { /** * Packet which should be checked */ Packet packet = (Packet) itPacket.next(); /** * Instance Representation */ Instance packet_instance = packet2Instance(l, packet, dataset); try { /** * Try to classify (find appropriate cluster) */ clusterer.clusterInstance(packet_instance); } catch (Exception e) { /** * Anomaly found */ System.out.println("Anomaly: "+packet.getTextualRepresentation()); //e.printStackTrace(); } } } } @Override public void setMode(boolean testing) { training = !testing; if(testing) { // Build Clusterer try { finishDataCollection(); } catch (Exception e) { System.out.println("Clustering failed"); e.printStackTrace(); } } } @Override public boolean getMode() { return !training; } }