BasicPacketClassifier.java 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. package de.tu_darmstadt.tk.SmartHomeNetworkSim.evaluation;
  2. import java.util.ArrayList;
  3. import java.util.HashMap;
  4. import java.util.HashSet;
  5. import java.util.Iterator;
  6. import java.util.LinkedList;
  7. import java.util.Map.Entry;
  8. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
  9. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  10. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
  11. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.protocols.packets.MQTTpublishPacket;
  12. import weka.core.Attribute;
  13. import weka.core.DenseInstance;
  14. import weka.core.Instance;
  15. import weka.core.Instances;
  16. /**
  17. * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s.
  18. *
  19. * @author Andreas T. Meyer-Berg
  20. */
  21. public abstract class BasicPacketClassifier implements PacketSniffer {
  22. /**
  23. * True, if instances should be used for training
  24. */
  25. protected boolean training = true;
  26. /**
  27. * Attributes which should be taken into account
  28. */
  29. protected ArrayList<Attribute> atts = new ArrayList<Attribute>();
  30. /**
  31. * Collected Packets
  32. */
  33. protected Instances dataset;
  34. /**
  35. * CollectedPackets
  36. */
  37. protected HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
  38. /**
  39. * HashMap for calculating transmission delay
  40. */
  41. protected HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
  42. /**
  43. * Map for the different Link names
  44. */
  45. protected HashSet<String> link_mappings = new HashSet<String>();
  46. /**
  47. * Map for the difference source device names
  48. */
  49. protected HashSet<String> source_mappings = new HashSet<String>();
  50. /**
  51. * Map for the different destination device names
  52. */
  53. protected HashSet<String> destination_mappings = new HashSet<String>();
  54. /**
  55. * Map for the protocol names
  56. */
  57. protected HashSet<String> protocol_mappings = new HashSet<String>();
  58. /**
  59. * Number of packets which are used to calculate the current transmission speed
  60. */
  61. protected int NUMBER_OF_PACKETS = 200;
  62. /**
  63. * Initializes the different maps
  64. */
  65. public BasicPacketClassifier() {
  66. // Initialize Attribute list
  67. source_mappings.add("unknown");
  68. link_mappings.add("unknown");
  69. destination_mappings.add("unknown");
  70. protocol_mappings.add("unknown");
  71. }
  72. @Override
  73. public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
  74. if(training)
  75. try {
  76. training(packets);
  77. } catch (Exception e) {
  78. e.printStackTrace();
  79. }
  80. else
  81. classify(packets);
  82. }
  83. /**
  84. * Estimates the current Packets per second (depending on the last 100 packets of the link)
  85. * @param link Link which should be checked
  86. * @param packet Packet which should investigated
  87. * @return estimated number of packets per second
  88. */
  89. protected double getEstimatedPacketsPerSecond(Link link, Packet packet) {
  90. /**
  91. * Packets used to calculated the packets per second
  92. */
  93. LinkedList<Packet> list = lastPackets.get(link);
  94. if(list == null) {
  95. /**
  96. * Add list if not present
  97. */
  98. list = new LinkedList<Packet>();
  99. lastPackets.put(link, list);
  100. }
  101. if(list.isEmpty()) {
  102. list.addLast(packet);
  103. // Default 1 packet per second
  104. return 1.0;
  105. }
  106. if(list.size() == NUMBER_OF_PACKETS){
  107. list.removeFirst();
  108. }
  109. list.addLast(packet);
  110. /**
  111. * elapsed time in milliseconds since last packet
  112. */
  113. long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
  114. if(elapsed_time<=0)
  115. return Double.POSITIVE_INFINITY;
  116. /**
  117. * Return number of packets per second
  118. */
  119. return 1000.0/elapsed_time;
  120. }
  121. /**
  122. * Returns the instance representation of the given packet and link
  123. * @param link link the packet was sent on
  124. * @param packet packet which should be transformed
  125. * @param dataset distribution the packet is part of
  126. * @return instance representation
  127. */
  128. protected Instance packet2Instance(Link link, Packet packet, Instances dataset) {
  129. /**
  130. * Instance for the given Packet
  131. */
  132. DenseInstance instance = new DenseInstance(dataset.numAttributes());
  133. instance.setDataset(dataset);
  134. // link
  135. instance.setValue(0, stringToNominal(link_mappings, link.getName()));
  136. // source
  137. if(packet.getSource()==null) {
  138. instance.setValue(1, "unknown");
  139. instance.setValue(2, Double.NEGATIVE_INFINITY);
  140. }else if(packet.getSource().getOwner()==null){
  141. instance.setValue(1, "unknown");
  142. instance.setValue(2, packet.getSource().getPortNumber());
  143. }else {
  144. instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
  145. instance.setValue(2, packet.getSource().getPortNumber());
  146. }
  147. // Destination
  148. if(packet.getDestination()==null) {
  149. instance.setValue(3, "unknown");
  150. instance.setValue(4, Double.NEGATIVE_INFINITY);
  151. }else if(packet.getDestination().getOwner()==null){
  152. instance.setValue(3, "unknown");
  153. instance.setValue(4, packet.getDestination().getPortNumber());
  154. }else {
  155. instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
  156. instance.setValue(4, packet.getDestination().getPortNumber());
  157. }
  158. // Protocol name
  159. instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
  160. // Packets per second
  161. instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
  162. // MQTT Value
  163. if(packet instanceof MQTTpublishPacket) {
  164. instance.setValue(7, ((MQTTpublishPacket)packet).getValue());
  165. instance.setValue(8, ((MQTTpublishPacket)packet).getSensorValue());
  166. } else {
  167. instance.setValue(7, -1);
  168. instance.setValue(8, -1);
  169. }
  170. return instance;
  171. }
  172. /**
  173. * Inserts the
  174. * @param map
  175. * @param nominal
  176. */
  177. protected void insertNominalIntoMap(HashSet<String> map, String nominal) {
  178. if(map == null || nominal == null)
  179. return;
  180. map.add(nominal);
  181. }
  182. /**
  183. * Transforms the String into an Number
  184. * @param map
  185. * @param s
  186. * @return
  187. */
  188. protected String stringToNominal(HashSet<String> map, String s) {
  189. return map.contains(s)?s:"unknown";
  190. }
  191. /**
  192. * Train the clusterer by collecting the packets
  193. *
  194. * @param packets packets to be learned
  195. */
  196. protected void training(HashMap<Link, LinkedList<Packet>> packets) {
  197. for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
  198. Link l = e.getKey();
  199. // TODO: ERROR ????????
  200. LinkedList<Packet> p = collectedPackets.get(l);
  201. if(p == null) {
  202. collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
  203. } else
  204. p.addAll(e.getValue());
  205. insertNominalIntoMap(link_mappings, l.getName());
  206. for(Packet pac: e.getValue()) {
  207. if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null)
  208. continue;
  209. insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName());
  210. insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName());
  211. insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName());
  212. insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName());
  213. insertNominalIntoMap(protocol_mappings, pac.getProtocolName());
  214. }
  215. //TODO: Add packet/Link/Names etc. to mappings
  216. }
  217. }
  218. /**
  219. * Finishes the collection and trains the clusterer on the collected packets
  220. *
  221. * @throws Exception
  222. */
  223. protected void finishDataCollection() throws Exception{
  224. /**
  225. printHashSet("Link-Name", link_mappings);
  226. printHashSet("Source-Device", source_mappings);
  227. printHashSet("Destination-Port", destination_mappings);
  228. printHashSet("Protocol-name", protocol_mappings);
  229. */
  230. atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
  231. atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
  232. atts.add(new Attribute("Source-Port-number", false));
  233. atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
  234. atts.add(new Attribute("Destination-Port-number", false));
  235. Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
  236. //pn.setWeight(10);
  237. atts.add(pn);
  238. Attribute pps = new Attribute("Packets-per-second", false);
  239. //pps.setWeight(20);
  240. atts.add(pps);
  241. atts.add(new Attribute("PacketValue", false));
  242. //atts.add(new Attribute("Anomaly", false));
  243. // TODO: Sensor Attribute, given as side channel information
  244. atts.add(new Attribute("SensorValue", false));
  245. /*
  246. atts = new ArrayList<Attribute>();
  247. atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
  248. atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
  249. atts.add(new Attribute("SPN", false));
  250. atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
  251. atts.add(new Attribute("DPN", false));
  252. atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
  253. atts.add(new Attribute("PPS", false));
  254. atts.add(new Attribute("A", false));*/
  255. dataset = new Instances("Packets", atts, 100000);
  256. //dataset.setClassIndex(7);
  257. /**
  258. * Add Instances to dataset
  259. */
  260. for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
  261. Entry<Link, LinkedList<Packet>> entry = it.next();
  262. /**
  263. * Link the packet was captured on
  264. */
  265. Link l = entry.getKey();
  266. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  267. /**
  268. * Packets to be added to the dataset
  269. */
  270. Packet packet = (Packet) itPacket.next();
  271. dataset.add(packet2Instance(l, packet, dataset));
  272. }
  273. }
  274. trainModel(dataset);
  275. }
  276. private void printHashSet(String name, HashSet<String> toPrint) {
  277. System.out.println(name+":");
  278. for (Iterator<String> iterator = toPrint.iterator(); iterator.hasNext();) {
  279. String string = (String) iterator.next();
  280. System.out.print(string);
  281. if(iterator.hasNext())
  282. System.out.print(", ");
  283. }
  284. System.out.println();
  285. }
  286. /**
  287. * Try to classify the given packets and detect anomalies
  288. * @param packets packets to be classified
  289. */
  290. protected void classify(HashMap<Link, LinkedList<Packet>> packets) {
  291. int tp = 0;
  292. int fp = 0;
  293. int tn = 0;
  294. int fn = 0;
  295. long start = Long.MAX_VALUE;
  296. long end = Long.MIN_VALUE;
  297. for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
  298. /**
  299. * Link & its packets
  300. */
  301. Entry<Link, LinkedList<Packet>> entry = it.next();
  302. /**
  303. * Link the packets were captured on
  304. */
  305. Link l = entry.getKey();
  306. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  307. /**
  308. * Packet which should be checked
  309. */
  310. Packet packet = (Packet) itPacket.next();
  311. start = Math.min(start, packet.getTimestamp());
  312. end = Math.max(end, packet.getTimestamp());
  313. /**
  314. * Instance Representation
  315. */
  316. Instance packet_instance = packet2Instance(l, packet, dataset);
  317. if(packet_instance == null)continue;
  318. try {
  319. double dist = classifyInstance(packet_instance, packet);
  320. if(dist<=1.0) {
  321. if(packet.getLabel()==0)
  322. tn++;
  323. else {
  324. fn++;
  325. System.out.println(packet.getTextualRepresentation());
  326. }
  327. }else {
  328. if(packet.getLabel()==0)
  329. fp++;
  330. else
  331. tp++;
  332. }
  333. } catch (Exception e) {
  334. if(packet.getLabel()==0)
  335. fp++;
  336. else
  337. tp++;
  338. }
  339. }
  340. }
  341. int n = tp+tn+fp+fn;
  342. if(n!=0) {
  343. System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms]");
  344. System.out.println("n: "+n);
  345. System.out.println("TP: "+tp);
  346. System.out.println("FP: "+fp);
  347. System.out.println("TN: "+tn);
  348. System.out.println("FN: "+fn);
  349. System.out.println("TPR: "+(tp/(tp+fn+0.0)));
  350. System.out.println("FPR: "+(fp/(fp+tn+0.0)));
  351. System.out.println("");
  352. }
  353. }
  354. /**
  355. * Train the model using the given instances
  356. * @param instances training set, which should be learned
  357. */
  358. public abstract void trainModel(Instances instances);
  359. /**
  360. * classifies the given instance
  361. * @param instance instance which should be classified
  362. * @param origin original packet, which was transformed into the instance
  363. * @return distance to next centroid
  364. * @throws Exception if anomaly was detected
  365. */
  366. public abstract double classifyInstance(Instance instance, Packet origin) throws Exception;
  367. /**
  368. * Returns the timestep, after which the classifier should start classifying instead of training.
  369. * @return timestep of the testing begin.
  370. */
  371. public abstract long getClassificationStart();
  372. @Override
  373. public void setMode(boolean testing) {
  374. training = !testing;
  375. if(testing) {
  376. try {
  377. finishDataCollection();
  378. } catch (Exception e) {
  379. System.out.println("Clustering failed");
  380. e.printStackTrace();
  381. }
  382. }
  383. }
  384. @Override
  385. public boolean getMode() {
  386. return !training;
  387. }
  388. /**
  389. * Short String representation of the classifier
  390. * @return
  391. */
  392. public abstract String getAlgoName();
  393. }