import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.Map.Entry; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet; import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer; import weka.clusterers.SimpleKMeans; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instance; import weka.core.Instances; /** * Unsupervised Example - maybe Clustering * * @author Andreas T. Meyer-Berg */ public class UnsupervisedAnomalyDetectionExample2 implements PacketSniffer { /** * Clusterer */ private SimpleKMeans clusterer; /** * True, if instances should be used for training */ private boolean training = true; /** * Attributes which should be taken into account */ private ArrayList atts = new ArrayList(); /** * Collected Packets */ private Instances dataset; /** * CollectedPackets */ private HashMap> collectedPackets = new HashMap>(); /** * HashMap for calculating transmission delay */ private HashMap> lastPackets = new HashMap>(); /** * Number of Clusters */ private int NUMBER_OF_CLUSTERS = 2; /** * Number of packets used for number of packets per second */ private int NUMBER_OF_PACKETS = 30; /** * */ private HashSet link_mappings = new HashSet(); private HashSet source_mappings = new HashSet(); private HashSet destination_mappings = new HashSet(); private HashSet protocol_mappings = new HashSet(); /** * */ public UnsupervisedAnomalyDetectionExample2() { // Initialize Attribute list source_mappings.add("unknown"); link_mappings.add("unknown"); destination_mappings.add("unknown"); protocol_mappings.add("unknown"); // Initialize data set // Initialize Clusterer clusterer = new SimpleKMeans(); clusterer.setSeed(42); try { clusterer.setNumClusters(NUMBER_OF_CLUSTERS); } catch (Exception e) { System.out.println("Error while building cluster"); e.printStackTrace(); } } @Override public void processPackets(HashMap> packets) { if(training) try { training(packets); } catch (Exception e) { e.printStackTrace(); } else classify(packets); } /** * Estimates the current Packets per second (depending on the last 100 packets of the link) * @param link Link which should be checked * @param packet Packet which should investigated * @return estimated number of packets per second */ private double getEstimatedPacketsPerSecond(Link link, Packet packet) { /** * Packets used to calculated the packets per second */ LinkedList list = lastPackets.get(link); if(list == null) { /** * Add list if not present */ list = new LinkedList(); lastPackets.put(link, list); } if(list.isEmpty()) { list.addLast(packet); // Default 1 packet per second return 1.0; } if(list.size() == NUMBER_OF_PACKETS){ list.removeFirst(); } list.addLast(packet); /** * elapsed time in milliseconds since last packet */ long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size(); if(elapsed_time<=0) return Double.POSITIVE_INFINITY; /** * Return number of packets per second */ return 1000.0/elapsed_time; } /** * Returns the instance representation of the given packet and link * @param link link the packet was sent on * @param packet packet which should be transformed * @param dataset distribution the packet is part of * @return instance representation */ private Instance packet2Instance(Link link, Packet packet, Instances dataset) { /** * Instance for the given Packet */ DenseInstance instance = new DenseInstance(dataset.numAttributes()); instance.setDataset(dataset); // link instance.setValue(0, stringToNominal(link_mappings, link.getName())); // source if(packet.getSource()==null) { instance.setValue(1, "unknown"); instance.setValue(2, Double.NEGATIVE_INFINITY); }else if(packet.getSource().getOwner()==null){ instance.setValue(1, "unknown"); instance.setValue(2, packet.getSource().getPortNumber()); }else { instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName())); instance.setValue(2, packet.getSource().getPortNumber()); } // Destination if(packet.getDestination()==null) { instance.setValue(3, "unknown"); instance.setValue(4, Double.NEGATIVE_INFINITY); }else if(packet.getDestination().getOwner()==null){ instance.setValue(3, "unknown"); instance.setValue(4, packet.getDestination().getPortNumber()); }else { instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName())); instance.setValue(4, packet.getDestination().getPortNumber()); } // Protocol name instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName())); // Packets per second instance.setValue(6, getEstimatedPacketsPerSecond(link, packet)); return instance; } /** * Transforms the String into an Number * @param map * @param s * @return */ String stringToNominal(HashSet map, String s) { return map.contains(s)?s:"unknown"; } /** * Train the clusterer by collecting the packets * * @param packets packets to be learned */ private void training(HashMap> packets) { for(Entry> e:packets.entrySet()) { Link l = e.getKey(); LinkedList p = collectedPackets.get(l); if(p == null) collectedPackets.put(l, new LinkedList(e.getValue())); else p.addAll(e.getValue()); } } /** * Finishes the collection and trains the clusterer on the collected packets * * @throws Exception */ private void finishDataCollection() throws Exception{ atts.add(new Attribute("Link-Name", new LinkedList(link_mappings)));//TODO:?? atts.add(new Attribute("Source-Device", new LinkedList(source_mappings))); atts.add(new Attribute("Source-Port-number", false)); atts.add(new Attribute("Destination-Device", new LinkedList(destination_mappings))); atts.add(new Attribute("Destination-Port-number", false)); Attribute pn = new Attribute("Protocol-name", new LinkedList(protocol_mappings)); //pn.setWeight(10); atts.add(pn); Attribute pps = new Attribute("Packets-per-second", false); //pps.setWeight(20); atts.add(pps); //atts.add(new Attribute("Anomaly", false)); /* atts = new ArrayList(); atts.add(new Attribute("LN", new LinkedList(link_mappings)));//TODO:?? atts.add(new Attribute("SD", new LinkedList(source_mappings))); atts.add(new Attribute("SPN", false)); atts.add(new Attribute("DD", new LinkedList(destination_mappings))); atts.add(new Attribute("DPN", false)); atts.add(new Attribute("PN", new LinkedList(protocol_mappings))); atts.add(new Attribute("PPS", false)); atts.add(new Attribute("A", false));*/ dataset = new Instances("Packets", atts, 100000); //dataset.setClassIndex(7); /** * Add Instances to dataset */ for (Iterator>> it = collectedPackets.entrySet().iterator(); it.hasNext();) { Entry> entry = it.next(); /** * Link the packet was captured on */ Link l = entry.getKey(); for (Iterator itPacket = entry.getValue().iterator(); itPacket.hasNext();) { /** * Packets to be added to the dataset */ Packet packet = (Packet) itPacket.next(); dataset.add(packet2Instance(l, packet, dataset)); } } /** * Build the clusterer for the given dataset */ clusterer.buildClusterer(dataset); } /** * Try to classify the given packets and detect anomalies * @param packets packets to be classified */ private void classify(HashMap> packets) { for (Iterator>> it = packets.entrySet().iterator(); it.hasNext();) { /** * Link & its packets */ Entry> entry = it.next(); /** * Link the packets were captured on */ Link l = entry.getKey(); for (Iterator itPacket = entry.getValue().iterator(); itPacket.hasNext();) { /** * Packet which should be checked */ Packet packet = (Packet) itPacket.next(); /** * Instance Representation */ Instance packet_instance = packet2Instance(l, packet, dataset); if(packet_instance == null)continue; try { /** * Try to classify (find appropriate cluster) */ int c = clusterer.clusterInstance(packet_instance); System.out.println("Cluster "+c+": "+packet.getTextualRepresentation()); } catch (Exception e) { /** * Anomaly found */ System.out.println("Anomaly: "+packet.getTextualRepresentation()); e.printStackTrace(); } } } } @Override public void setMode(boolean testing) { training = !testing; if(testing) { // Build Clusterer try { finishDataCollection(); } catch (Exception e) { System.out.println("Clustering failed"); e.printStackTrace(); } } } @Override public boolean getMode() { return !training; } }