// UnsupervisedAnomalyDetectionExample2.java (9.8 KB)
  1. import java.util.ArrayList;
  2. import java.util.HashMap;
  3. import java.util.HashSet;
  4. import java.util.Iterator;
  5. import java.util.LinkedList;
  6. import java.util.Map.Entry;
  7. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
  8. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  9. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
  10. import weka.clusterers.SimpleKMeans;
  11. import weka.core.Attribute;
  12. import weka.core.DenseInstance;
  13. import weka.core.Instance;
  14. import weka.core.Instances;
  15. import weka.core.NominalAttributeInfo;
  16. /**
  17. * Unsupervised Example - maybe Clustering
  18. *
  19. * @author Andreas T. Meyer-Berg
  20. */
  21. public class UnsupervisedAnomalyDetectionExample2 implements PacketSniffer {
  22. /**
  23. * Clusterer
  24. */
  25. private SimpleKMeans clusterer;
  26. /**
  27. * True, if instances should be used for training
  28. */
  29. private boolean training = true;
  30. /**
  31. * Attributes which should be taken into account
  32. */
  33. private ArrayList<Attribute> atts = new ArrayList<Attribute>();
  34. /**
  35. * Collected Packets
  36. */
  37. private Instances dataset;
  38. /**
  39. * CollectedPackets
  40. */
  41. private HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
  42. /**
  43. * HashMap for calculating transmission delay
  44. */
  45. private HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
  46. /**
  47. * Number of Clusters
  48. */
  49. private int NUMBER_OF_CLUSTERS = 30;
  50. /**
  51. * Number of packets used for number of packets per second
  52. */
  53. private int NUMBER_OF_PACKETS = 30;
  54. /**
  55. *
  56. */
  57. private HashSet<String> link_mappings = new HashSet<String>();
  58. private HashSet<String> source_mappings = new HashSet<String>();
  59. private HashSet<String> destination_mappings = new HashSet<String>();
  60. private HashSet<String> protocol_mappings = new HashSet<String>();
  61. /**
  62. *
  63. */
  64. public UnsupervisedAnomalyDetectionExample2() {
  65. // Initialize Attribute list
  66. source_mappings.add("unknown");
  67. link_mappings.add("unknown");
  68. destination_mappings.add("unknown");
  69. protocol_mappings.add("unknown");
  70. // Initialize data set
  71. // Initialize Clusterer
  72. clusterer = new SimpleKMeans();
  73. clusterer.setSeed(42);
  74. try {
  75. clusterer.setNumClusters(NUMBER_OF_CLUSTERS);
  76. } catch (Exception e) {
  77. System.out.println("Error while building cluster");
  78. e.printStackTrace();
  79. }
  80. }
  81. @Override
  82. public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
  83. if(training && !packets.entrySet().isEmpty() && !packets.entrySet().iterator().next().getValue().isEmpty() && packets.entrySet().iterator().next().getValue().getFirst().getTimestamp()>10000) {
  84. training = false;
  85. // Build Clusterer
  86. try {
  87. finishDataCollection();
  88. } catch (Exception e) {
  89. System.out.println("Clustering failed");
  90. e.printStackTrace();
  91. }
  92. }
  93. if(training)
  94. try {
  95. training(packets);
  96. } catch (Exception e) {
  97. e.printStackTrace();
  98. }
  99. else
  100. classify(packets);
  101. }
  102. /**
  103. * Estimates the current Packets per second (depending on the last 100 packets of the link)
  104. * @param link Link which should be checked
  105. * @param packet Packet which should investigated
  106. * @return estimated number of packets per second
  107. */
  108. private double getEstimatedPacketsPerSecond(Link link, Packet packet) {
  109. /**
  110. * Packets used to calculated the packets per second
  111. */
  112. LinkedList<Packet> list = lastPackets.get(link);
  113. if(list == null) {
  114. /**
  115. * Add list if not present
  116. */
  117. list = new LinkedList<Packet>();
  118. lastPackets.put(link, list);
  119. }
  120. if(list.isEmpty()) {
  121. list.addLast(packet);
  122. // Default 1 packet per second
  123. return 1.0;
  124. }
  125. if(list.size() == NUMBER_OF_PACKETS){
  126. list.removeFirst();
  127. }
  128. list.addLast(packet);
  129. /**
  130. * elapsed time in milliseconds since last packet
  131. */
  132. long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
  133. if(elapsed_time<=0)
  134. return Double.POSITIVE_INFINITY;
  135. /**
  136. * Return number of packets per second
  137. */
  138. return 1000.0/elapsed_time;
  139. }
  140. /**
  141. * Returns the instance representation of the given packet and link
  142. * @param link link the packet was sent on
  143. * @param packet packet which should be transformed
  144. * @param dataset distribution the packet is part of
  145. * @return instance representation
  146. */
  147. private Instance packet2Instance(Link link, Packet packet, Instances dataset) {
  148. /**
  149. * Instance for the given Packet
  150. */
  151. DenseInstance instance = new DenseInstance(dataset.numAttributes());
  152. instance.setDataset(dataset);
  153. // link
  154. instance.setValue(0, stringToNominal(link_mappings, link.getName()));
  155. // source
  156. if(packet.getSource()==null) {
  157. instance.setValue(1, "unknown");
  158. instance.setValue(2, Double.NEGATIVE_INFINITY);
  159. }else if(packet.getSource().getOwner()==null){
  160. instance.setValue(1, "unknown");
  161. instance.setValue(2, packet.getSource().getPortNumber());
  162. }else {
  163. instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
  164. instance.setValue(2, packet.getSource().getPortNumber());
  165. }
  166. // Destination
  167. if(packet.getDestination()==null) {
  168. instance.setValue(3, "unknown");
  169. instance.setValue(4, Double.NEGATIVE_INFINITY);
  170. }else if(packet.getDestination().getOwner()==null){
  171. instance.setValue(3, "unknown");
  172. instance.setValue(4, packet.getDestination().getPortNumber());
  173. }else {
  174. instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
  175. instance.setValue(4, packet.getDestination().getPortNumber());
  176. }
  177. // Protocol name
  178. instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
  179. // Packets per second
  180. instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
  181. return instance;
  182. }
  183. /**
  184. * Transforms the String into an Number
  185. * @param map
  186. * @param s
  187. * @return
  188. */
  189. String stringToNominal(HashSet<String> map, String s) {
  190. return map.contains(s)?s:"unknown";
  191. }
  192. /**
  193. * Train the clusterer by collecting the packets
  194. *
  195. * @param packets packets to be learned
  196. */
  197. private void training(HashMap<Link, LinkedList<Packet>> packets) {
  198. for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
  199. Link l = e.getKey();
  200. LinkedList<Packet> p = collectedPackets.get(l);
  201. if(p == null)
  202. collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
  203. else
  204. p.addAll(e.getValue());
  205. }
  206. }
  207. /**
  208. * Finishes the collection and trains the clusterer on the collected packets
  209. *
  210. * @throws Exception
  211. */
  212. private void finishDataCollection() throws Exception{
  213. atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
  214. atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
  215. atts.add(new Attribute("Source-Port-number", false));
  216. atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
  217. atts.add(new Attribute("Destination-Port-number", false));
  218. Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
  219. //pn.setWeight(10);
  220. atts.add(pn);
  221. Attribute pps = new Attribute("Packets-per-second", false);
  222. //pps.setWeight(20);
  223. atts.add(pps);
  224. //atts.add(new Attribute("Anomaly", false));
  225. /*
  226. atts = new ArrayList<Attribute>();
  227. atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
  228. atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
  229. atts.add(new Attribute("SPN", false));
  230. atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
  231. atts.add(new Attribute("DPN", false));
  232. atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
  233. atts.add(new Attribute("PPS", false));
  234. atts.add(new Attribute("A", false));*/
  235. dataset = new Instances("Packets", atts, 100000);
  236. //dataset.setClassIndex(7);
  237. /**
  238. * Add Instances to dataset
  239. */
  240. for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
  241. Entry<Link, LinkedList<Packet>> entry = it.next();
  242. /**
  243. * Link the packet was captured on
  244. */
  245. Link l = entry.getKey();
  246. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  247. /**
  248. * Packets to be added to the dataset
  249. */
  250. Packet packet = (Packet) itPacket.next();
  251. dataset.add(packet2Instance(l, packet, dataset));
  252. }
  253. }
  254. /**
  255. * Build the clusterer for the given dataset
  256. */
  257. clusterer.buildClusterer(dataset);
  258. }
  259. /**
  260. * Try to classify the given packets and detect anomalies
  261. * @param packets packets to be classified
  262. */
  263. private void classify(HashMap<Link, LinkedList<Packet>> packets) {
  264. for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
  265. /**
  266. * Link & its packets
  267. */
  268. Entry<Link, LinkedList<Packet>> entry = it.next();
  269. /**
  270. * Link the packets were captured on
  271. */
  272. Link l = entry.getKey();
  273. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  274. /**
  275. * Packet which should be checked
  276. */
  277. Packet packet = (Packet) itPacket.next();
  278. /**
  279. * Instance Representation
  280. */
  281. Instance packet_instance = packet2Instance(l, packet, dataset);
  282. if(packet_instance == null)continue;
  283. try {
  284. /**
  285. * Try to classify (find appropriate cluster)
  286. */
  287. int c = clusterer.clusterInstance(packet_instance);
  288. System.out.println("Cluster "+c+": "+packet.getTextualRepresentation());
  289. } catch (Exception e) {
  290. /**
  291. * Anomaly found
  292. */
  293. System.out.println("Anomaly: "+packet.getTextualRepresentation());
  294. e.printStackTrace();
  295. }
  296. }
  297. }
  298. }
  299. }