// UnsupervisedAnomalyDetectionExample2.java
  import java.util.ArrayList;
  import java.util.HashMap;
  import java.util.HashSet;
  import java.util.Iterator;
  import java.util.LinkedList;
  import java.util.List;
  import java.util.Map;
  import java.util.Map.Entry;

  import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
  import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;

  import weka.clusterers.SimpleKMeans;
  import weka.core.Attribute;
  import weka.core.DenseInstance;
  import weka.core.Instance;
  import weka.core.Instances;
  15. /**
  16. * Unsupervised Example - maybe Clustering
  17. *
  18. * @author Andreas T. Meyer-Berg
  19. */
  20. public class UnsupervisedAnomalyDetectionExample2 implements PacketSniffer {
  21. /**
  22. * Clusterer
  23. */
  24. private SimpleKMeans clusterer;
  25. /**
  26. * True, if instances should be used for training
  27. */
  28. private boolean training = true;
  29. /**
  30. * Attributes which should be taken into account
  31. */
  32. private ArrayList<Attribute> atts = new ArrayList<Attribute>();
  33. /**
  34. * Collected Packets
  35. */
  36. private Instances dataset;
  37. /**
  38. * CollectedPackets
  39. */
  40. private HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
  41. /**
  42. * HashMap for calculating transmission delay
  43. */
  44. private HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
  45. /**
  46. * Number of Clusters
  47. */
  48. private int NUMBER_OF_CLUSTERS = 2;
  49. /**
  50. * Number of packets used for number of packets per second
  51. */
  52. private int NUMBER_OF_PACKETS = 30;
  53. /**
  54. *
  55. */
  56. private HashSet<String> link_mappings = new HashSet<String>();
  57. private HashSet<String> source_mappings = new HashSet<String>();
  58. private HashSet<String> destination_mappings = new HashSet<String>();
  59. private HashSet<String> protocol_mappings = new HashSet<String>();
  60. /**
  61. *
  62. */
  63. public UnsupervisedAnomalyDetectionExample2() {
  64. // Initialize Attribute list
  65. source_mappings.add("unknown");
  66. link_mappings.add("unknown");
  67. destination_mappings.add("unknown");
  68. protocol_mappings.add("unknown");
  69. // Initialize data set
  70. // Initialize Clusterer
  71. clusterer = new SimpleKMeans();
  72. clusterer.setSeed(42);
  73. try {
  74. clusterer.setNumClusters(NUMBER_OF_CLUSTERS);
  75. } catch (Exception e) {
  76. System.out.println("Error while building cluster");
  77. e.printStackTrace();
  78. }
  79. }
  80. @Override
  81. public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
  82. if(training)
  83. try {
  84. training(packets);
  85. } catch (Exception e) {
  86. e.printStackTrace();
  87. }
  88. else
  89. classify(packets);
  90. }
  91. /**
  92. * Estimates the current Packets per second (depending on the last 100 packets of the link)
  93. * @param link Link which should be checked
  94. * @param packet Packet which should investigated
  95. * @return estimated number of packets per second
  96. */
  97. private double getEstimatedPacketsPerSecond(Link link, Packet packet) {
  98. /**
  99. * Packets used to calculated the packets per second
  100. */
  101. LinkedList<Packet> list = lastPackets.get(link);
  102. if(list == null) {
  103. /**
  104. * Add list if not present
  105. */
  106. list = new LinkedList<Packet>();
  107. lastPackets.put(link, list);
  108. }
  109. if(list.isEmpty()) {
  110. list.addLast(packet);
  111. // Default 1 packet per second
  112. return 1.0;
  113. }
  114. if(list.size() == NUMBER_OF_PACKETS){
  115. list.removeFirst();
  116. }
  117. list.addLast(packet);
  118. /**
  119. * elapsed time in milliseconds since last packet
  120. */
  121. long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
  122. if(elapsed_time<=0)
  123. return Double.POSITIVE_INFINITY;
  124. /**
  125. * Return number of packets per second
  126. */
  127. return 1000.0/elapsed_time;
  128. }
  129. /**
  130. * Returns the instance representation of the given packet and link
  131. * @param link link the packet was sent on
  132. * @param packet packet which should be transformed
  133. * @param dataset distribution the packet is part of
  134. * @return instance representation
  135. */
  136. private Instance packet2Instance(Link link, Packet packet, Instances dataset) {
  137. /**
  138. * Instance for the given Packet
  139. */
  140. DenseInstance instance = new DenseInstance(dataset.numAttributes());
  141. instance.setDataset(dataset);
  142. // link
  143. instance.setValue(0, stringToNominal(link_mappings, link.getName()));
  144. // source
  145. if(packet.getSource()==null) {
  146. instance.setValue(1, "unknown");
  147. instance.setValue(2, Double.NEGATIVE_INFINITY);
  148. }else if(packet.getSource().getOwner()==null){
  149. instance.setValue(1, "unknown");
  150. instance.setValue(2, packet.getSource().getPortNumber());
  151. }else {
  152. instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
  153. instance.setValue(2, packet.getSource().getPortNumber());
  154. }
  155. // Destination
  156. if(packet.getDestination()==null) {
  157. instance.setValue(3, "unknown");
  158. instance.setValue(4, Double.NEGATIVE_INFINITY);
  159. }else if(packet.getDestination().getOwner()==null){
  160. instance.setValue(3, "unknown");
  161. instance.setValue(4, packet.getDestination().getPortNumber());
  162. }else {
  163. instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
  164. instance.setValue(4, packet.getDestination().getPortNumber());
  165. }
  166. // Protocol name
  167. instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
  168. // Packets per second
  169. instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
  170. return instance;
  171. }
  172. /**
  173. * Transforms the String into an Number
  174. * @param map
  175. * @param s
  176. * @return
  177. */
  178. String stringToNominal(HashSet<String> map, String s) {
  179. return map.contains(s)?s:"unknown";
  180. }
  181. /**
  182. * Train the clusterer by collecting the packets
  183. *
  184. * @param packets packets to be learned
  185. */
  186. private void training(HashMap<Link, LinkedList<Packet>> packets) {
  187. for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
  188. Link l = e.getKey();
  189. LinkedList<Packet> p = collectedPackets.get(l);
  190. if(p == null)
  191. collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
  192. else
  193. p.addAll(e.getValue());
  194. }
  195. }
  196. /**
  197. * Finishes the collection and trains the clusterer on the collected packets
  198. *
  199. * @throws Exception
  200. */
  201. private void finishDataCollection() throws Exception{
  202. atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
  203. atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
  204. atts.add(new Attribute("Source-Port-number", false));
  205. atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
  206. atts.add(new Attribute("Destination-Port-number", false));
  207. Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
  208. //pn.setWeight(10);
  209. atts.add(pn);
  210. Attribute pps = new Attribute("Packets-per-second", false);
  211. //pps.setWeight(20);
  212. atts.add(pps);
  213. //atts.add(new Attribute("Anomaly", false));
  214. /*
  215. atts = new ArrayList<Attribute>();
  216. atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
  217. atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
  218. atts.add(new Attribute("SPN", false));
  219. atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
  220. atts.add(new Attribute("DPN", false));
  221. atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
  222. atts.add(new Attribute("PPS", false));
  223. atts.add(new Attribute("A", false));*/
  224. dataset = new Instances("Packets", atts, 100000);
  225. //dataset.setClassIndex(7);
  226. /**
  227. * Add Instances to dataset
  228. */
  229. for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
  230. Entry<Link, LinkedList<Packet>> entry = it.next();
  231. /**
  232. * Link the packet was captured on
  233. */
  234. Link l = entry.getKey();
  235. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  236. /**
  237. * Packets to be added to the dataset
  238. */
  239. Packet packet = (Packet) itPacket.next();
  240. dataset.add(packet2Instance(l, packet, dataset));
  241. }
  242. }
  243. /**
  244. * Build the clusterer for the given dataset
  245. */
  246. clusterer.buildClusterer(dataset);
  247. }
  248. /**
  249. * Try to classify the given packets and detect anomalies
  250. * @param packets packets to be classified
  251. */
  252. private void classify(HashMap<Link, LinkedList<Packet>> packets) {
  253. for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
  254. /**
  255. * Link & its packets
  256. */
  257. Entry<Link, LinkedList<Packet>> entry = it.next();
  258. /**
  259. * Link the packets were captured on
  260. */
  261. Link l = entry.getKey();
  262. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  263. /**
  264. * Packet which should be checked
  265. */
  266. Packet packet = (Packet) itPacket.next();
  267. /**
  268. * Instance Representation
  269. */
  270. Instance packet_instance = packet2Instance(l, packet, dataset);
  271. if(packet_instance == null)continue;
  272. try {
  273. /**
  274. * Try to classify (find appropriate cluster)
  275. */
  276. int c = clusterer.clusterInstance(packet_instance);
  277. System.out.println("Cluster "+c+": "+packet.getTextualRepresentation());
  278. } catch (Exception e) {
  279. /**
  280. * Anomaly found
  281. */
  282. System.out.println("Anomaly: "+packet.getTextualRepresentation());
  283. e.printStackTrace();
  284. }
  285. }
  286. }
  287. }
  288. @Override
  289. public void setMode(boolean testing) {
  290. training = !testing;
  291. if(testing) {
  292. // Build Clusterer
  293. try {
  294. finishDataCollection();
  295. } catch (Exception e) {
  296. System.out.println("Clustering failed");
  297. e.printStackTrace();
  298. }
  299. }
  300. }
  301. @Override
  302. public boolean getMode() {
  303. return !training;
  304. }
  305. }