BasicPacketClassifierWitLabels.java 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. package de.tu_darmstadt.tk.SmartHomeNetworkSim.evaluation;
  2. import java.io.BufferedWriter;
  3. import java.io.File;
  4. import java.io.FileWriter;
  5. import java.io.IOException;
  6. import java.util.ArrayList;
  7. import java.util.HashMap;
  8. import java.util.HashSet;
  9. import java.util.Iterator;
  10. import java.util.LinkedList;
  11. import java.util.Map.Entry;
  12. import java.util.Set;
  13. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
  14. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  15. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
  16. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.protocols.packets.MQTTpublishPacket;
  17. import weka.core.Attribute;
  18. import weka.core.DenseInstance;
  19. import weka.core.Instance;
  20. import weka.core.Instances;
  21. /**
  22. * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s.
  23. *
  24. * @author Andreas T. Meyer-Berg
  25. */
  26. public abstract class BasicPacketClassifierWitLabels implements PacketSniffer {
  27. /**
  28. * True, if instances should be used for training
  29. */
  30. protected boolean training = true;
  31. /**
  32. * Attributes which should be taken into account
  33. */
  34. protected ArrayList<Attribute> atts = new ArrayList<Attribute>();
  35. /**
  36. * Collected Packets
  37. */
  38. protected Instances dataset;
  39. /**
  40. * CollectedPackets
  41. */
  42. protected HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
  43. /**
  44. * HashMap for calculating transmission delay
  45. */
  46. protected HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
  47. /**
  48. * Map for the different Link names
  49. */
  50. protected HashSet<String> link_mappings = new HashSet<String>();
  51. /**
  52. * Map for the difference source device names
  53. */
  54. protected HashSet<String> source_mappings = new HashSet<String>();
  55. /**
  56. * Map for the different destination device names
  57. */
  58. protected HashSet<String> destination_mappings = new HashSet<String>();
  59. /**
  60. * Map for the protocol names
  61. */
  62. protected HashSet<String> protocol_mappings = new HashSet<String>();
  63. /**
  64. * Map for the protocol names
  65. */
  66. protected HashSet<String> packet_mappings = new HashSet<String>();
  67. /**
  68. * Number of packets which are used to calculate the current transmission speed
  69. */
  70. protected int NUMBER_OF_PACKETS = 200;
  71. private String currentScenario = "";
  72. private int scenarioRun = 0;
  73. /**
  74. * Initializes the different maps
  75. */
  76. public BasicPacketClassifierWitLabels() {
  77. // Initialize Attribute list
  78. source_mappings.add("unknown");
  79. link_mappings.add("unknown");
  80. destination_mappings.add("unknown");
  81. protocol_mappings.add("unknown");
  82. packet_mappings.add("unknown");
  83. }
  84. @Override
  85. public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
  86. if(training)
  87. try {
  88. training(packets);
  89. } catch (Exception e) {
  90. e.printStackTrace();
  91. }
  92. else
  93. classify(packets);
  94. }
  95. /**
  96. * Estimates the current Packets per second (depending on the last 100 packets of the link)
  97. * @param link Link which should be checked
  98. * @param packet Packet which should investigated
  99. * @return estimated number of packets per second
  100. */
  101. protected double getEstimatedPacketsPerSecond(Link link, Packet packet) {
  102. /**
  103. * Packets used to calculated the packets per second
  104. */
  105. LinkedList<Packet> list = lastPackets.get(link);
  106. if(list == null) {
  107. /**
  108. * Add list if not present
  109. */
  110. list = new LinkedList<Packet>();
  111. lastPackets.put(link, list);
  112. }
  113. if(list.isEmpty()) {
  114. list.addLast(packet);
  115. // Default 1 packet per second
  116. return 1.0;
  117. }
  118. if(list.size() == NUMBER_OF_PACKETS){
  119. list.removeFirst();
  120. }
  121. list.addLast(packet);
  122. /**
  123. * elapsed time in milliseconds since last packet
  124. */
  125. long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
  126. if(elapsed_time<=0)
  127. return Double.POSITIVE_INFINITY;
  128. /**
  129. * Return number of packets per second
  130. */
  131. return 1000.0/elapsed_time;
  132. }
  133. /**
  134. * Returns the instance representation of the given packet and link
  135. * @param link link the packet was sent on
  136. * @param packet packet which should be transformed
  137. * @param dataset distribution the packet is part of
  138. * @return instance representation
  139. */
  140. protected Instance packet2Instance(Link link, Packet packet, Instances dataset) {
  141. /**
  142. * Instance for the given Packet
  143. */
  144. DenseInstance instance = new DenseInstance(dataset.numAttributes());
  145. instance.setDataset(dataset);
  146. /**
  147. // link
  148. instance.setValue(0, stringToNominal(link_mappings, link.getName()));
  149. // source
  150. if(packet.getSource()==null) {
  151. instance.setValue(1, "unknown");
  152. instance.setValue(2, Double.NEGATIVE_INFINITY);
  153. }else if(packet.getSource().getOwner()==null){
  154. instance.setValue(1, "unknown");
  155. instance.setValue(2, packet.getSource().getPortNumber());
  156. }else {
  157. instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
  158. instance.setValue(2, packet.getSource().getPortNumber());
  159. }
  160. // Destination
  161. if(packet.getDestination()==null) {
  162. instance.setValue(3, "unknown");
  163. instance.setValue(4, Double.NEGATIVE_INFINITY);
  164. }else if(packet.getDestination().getOwner()==null){
  165. instance.setValue(3, "unknown");
  166. instance.setValue(4, packet.getDestination().getPortNumber());
  167. }else {
  168. instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
  169. instance.setValue(4, packet.getDestination().getPortNumber());
  170. }
  171. // Protocol name
  172. instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
  173. // Packets per second
  174. //instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
  175. // MQTT Value*/
  176. if(packet instanceof MQTTpublishPacket) {
  177. MQTTpublishPacket mqttPack = (MQTTpublishPacket)packet;
  178. if(mqttPack.isBoolean()) {
  179. //System.out.println("MQTT PACK: " + mqttPack.getValue() + "," +mqttPack.getSensorValue());
  180. if(mqttPack.getValue() == 0) {
  181. instance.setValue(0, Settings.FALSE_VALUE);
  182. //System.out.println("False");
  183. }
  184. else {
  185. instance.setValue(0, Settings.TRUE_VALUE);
  186. //System.out.println("True");
  187. }
  188. if(mqttPack.getSensorValue() == 0) {
  189. instance.setValue(1, Settings.FALSE_VALUE);
  190. //System.out.println("False");
  191. } else {
  192. instance.setValue(1, Settings.TRUE_VALUE);
  193. //System.out.println("True");
  194. }
  195. }else {
  196. instance.setValue(0, ((MQTTpublishPacket)packet).getValue());
  197. instance.setValue(1, ((MQTTpublishPacket)packet).getSensorValue());
  198. }
  199. } else {
  200. instance.setValue(0, Settings.NO_VALUE);
  201. instance.setValue(1, Settings.NO_VALUE);
  202. }
  203. instance.setValue(2, stringToNominal(packet_mappings, packet.getPackageType()));
  204. return instance;
  205. }
  206. /**
  207. * Inserts the
  208. * @param map
  209. * @param nominal
  210. */
  211. protected void insertNominalIntoMap(HashSet<String> map, String nominal) {
  212. if(map == null || nominal == null)
  213. return;
  214. map.add(nominal);
  215. }
  216. /**
  217. * Transforms the String into an Number
  218. * @param map
  219. * @param s
  220. * @return
  221. */
  222. protected String stringToNominal(HashSet<String> map, String s) {
  223. return map.contains(s)?s:"unknown";
  224. }
  225. /**
  226. * Train the clusterer by collecting the packets
  227. *
  228. * @param packets packets to be learned
  229. */
  230. protected void training(HashMap<Link, LinkedList<Packet>> packets) {
  231. for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
  232. Link l = e.getKey();
  233. // TODO: ERROR ????????
  234. LinkedList<Packet> p = collectedPackets.get(l);
  235. if(p == null) {
  236. collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
  237. } else
  238. p.addAll(e.getValue());
  239. insertNominalIntoMap(link_mappings, l.getName());
  240. for(Packet pac: e.getValue()) {
  241. if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null)
  242. continue;
  243. if(!(pac instanceof MQTTpublishPacket))
  244. continue;
  245. insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName());
  246. insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName());
  247. insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName());
  248. insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName());
  249. insertNominalIntoMap(protocol_mappings, pac.getProtocolName());
  250. insertNominalIntoMap(packet_mappings, pac.getPackageType());
  251. }
  252. //TODO: Add packet/Link/Names etc. to mappings
  253. }
  254. }
  255. /**
  256. * Finishes the collection and trains the clusterer on the collected packets
  257. *
  258. * @throws Exception
  259. */
  260. protected void finishDataCollection() throws Exception{
  261. /**
  262. printHashSet("Link-Name", link_mappings);
  263. printHashSet("Source-Device", source_mappings);
  264. printHashSet("Destination-Port", destination_mappings);
  265. printHashSet("Protocol-name", protocol_mappings);
  266. */
  267. /*
  268. atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
  269. atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
  270. atts.add(new Attribute("Source-Port-number", false));
  271. atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
  272. atts.add(new Attribute("Destination-Port-number", false));
  273. Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
  274. //pn.setWeight(10);
  275. atts.add(pn);
  276. //Attribute pps = new Attribute("Packets-per-second", false);
  277. //pps.setWeight(20);
  278. //atts.add(pps);*/
  279. atts.add(new Attribute("PacketValue", false));
  280. //atts.add(new Attribute("Anomaly", false));
  281. // TODO: Sensor Attribute, given as side channel information
  282. atts.add(new Attribute("SensorValue", false));
  283. atts.add(new Attribute("PackageType",new LinkedList<String>(packet_mappings)));
  284. /*
  285. atts = new ArrayList<Attribute>();
  286. atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
  287. atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
  288. atts.add(new Attribute("SPN", false));
  289. atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
  290. atts.add(new Attribute("DPN", false));
  291. atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
  292. atts.add(new Attribute("PPS", false));
  293. atts.add(new Attribute("A", false));*/
  294. dataset = new Instances("Packets", atts, 100000);
  295. //dataset.setClassIndex(7);
  296. /**
  297. * Add Instances to dataset
  298. */
  299. for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
  300. Entry<Link, LinkedList<Packet>> entry = it.next();
  301. /**
  302. * Link the packet was captured on
  303. */
  304. Link l = entry.getKey();
  305. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  306. /**
  307. * Packets to be added to the dataset
  308. */
  309. Packet packet = (Packet) itPacket.next();
  310. dataset.add(packet2Instance(l, packet, dataset));
  311. }
  312. }
  313. trainModel(dataset);
  314. }
  315. private void printHashSet(String name, HashSet<String> toPrint) {
  316. System.out.println(name+":");
  317. for (Iterator<String> iterator = toPrint.iterator(); iterator.hasNext();) {
  318. String string = (String) iterator.next();
  319. System.out.print(string);
  320. if(iterator.hasNext())
  321. System.out.print(", ");
  322. }
  323. System.out.println();
  324. }
  325. /**
  326. * Try to classify the given packets and detect anomalies
  327. * @param packets packets to be classified
  328. */
  329. protected void classify(HashMap<Link, LinkedList<Packet>> packets) {
  330. File anomalyResults = new File("results/"+getCurrentScenario() + scenarioRun + ".csv");
  331. anomalyResults.getParentFile().mkdir();
  332. BufferedWriter writer = null;
  333. try {
  334. writer = new BufferedWriter(new FileWriter(anomalyResults));
  335. writer.write("PacketRepresentation,anomalyFPorTP,sensorInfo\n");
  336. } catch (IOException e1) {
  337. // TODO Auto-generated catch block
  338. e1.printStackTrace();
  339. }
  340. int tp = 0;
  341. int fp = 0;
  342. int tn = 0;
  343. int fn = 0;
  344. long start = Long.MAX_VALUE;
  345. long end = Long.MIN_VALUE;
  346. for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
  347. /**
  348. * Link & its packets
  349. */
  350. Entry<Link, LinkedList<Packet>> entry = it.next();
  351. /**
  352. * Link the packets were captured on
  353. */
  354. Link l = entry.getKey();
  355. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  356. /**
  357. * Packet which should be checked
  358. */
  359. Packet packet = (Packet) itPacket.next();
  360. if(!(packet instanceof MQTTpublishPacket))
  361. continue;
  362. start = Math.min(start, packet.getTimestamp());
  363. end = Math.max(end, packet.getTimestamp());
  364. /**
  365. * Instance Representation
  366. */
  367. Instance packet_instance = packet2Instance(l, packet, dataset);
  368. if(packet_instance == null)continue;
  369. String sensorLabel = "";
  370. if(packet instanceof MQTTpublishPacket) {
  371. MQTTpublishPacket mqttPac = (MQTTpublishPacket)packet;
  372. sensorLabel = ""+mqttPac.getSensorValue();
  373. if(mqttPac.isBoolean()) {
  374. if(mqttPac.getSensorValue() == 0)
  375. sensorLabel = "false";
  376. else
  377. sensorLabel = "true";
  378. }
  379. sensorLabel = ","+sensorLabel;
  380. }
  381. try {
  382. double dist = classifyInstance(packet_instance, packet);
  383. //System.out.println(packet.getTextualRepresentation()+": "+packet.getLabel() +":"+sensorLabel);
  384. if(dist<=Settings.SECOND_THRESHOLD) {
  385. if(packet.getLabel()==0) {
  386. tn++;
  387. writer.write(packet.getTextualRepresentation()+",TN"+sensorLabel+"\n");
  388. }
  389. else {
  390. fn++;
  391. writer.write(packet.getTextualRepresentation()+",FN"+sensorLabel+"\n");
  392. //System.out.println(packet.getTextualRepresentation()+",AnomalyNotFound"+sensorLabel);
  393. }
  394. }else {
  395. if(packet.getLabel()==0) {
  396. fp++;
  397. writer.write(packet.getTextualRepresentation()+",FP"+sensorLabel+"\n");
  398. } else {
  399. tp++;
  400. writer.write(packet.getTextualRepresentation()+",TP"+sensorLabel+"\n");
  401. }
  402. }
  403. } catch (Exception e) {
  404. System.out.println(e);
  405. e.printStackTrace();
  406. if(packet.getLabel()==0) {
  407. fp++;
  408. try {
  409. writer.write(packet.getTextualRepresentation()+",FP"+sensorLabel+"\n");
  410. } catch (IOException e1) {
  411. // TODO Auto-generated catch block
  412. e1.printStackTrace();
  413. }
  414. } else {
  415. tp++;
  416. try {
  417. writer.write(packet.getTextualRepresentation()+",TP"+sensorLabel+"\n");
  418. } catch (IOException e1) {
  419. // TODO Auto-generated catch block
  420. e1.printStackTrace();
  421. }
  422. }
  423. }
  424. }
  425. }
  426. int n = tp+tn+fp+fn;
  427. if(n!=0) {
  428. System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms] Scenario: " + getCurrentScenario() + scenarioRun);
  429. scenarioRun++;
  430. System.out.println("n: "+n);
  431. System.out.println("TP: "+tp);
  432. System.out.println("FP: "+fp);
  433. System.out.println("TN: "+tn);
  434. System.out.println("FN: "+fn);
  435. System.out.println("TPR: "+(tp/(tp+fn+0.0)));
  436. System.out.println("FPR: "+(fp/(fp+tn+0.0)));
  437. System.out.println("");
  438. }
  439. try {
  440. writer.close();
  441. } catch (IOException e) {
  442. // TODO Auto-generated catch block
  443. e.printStackTrace();
  444. }
  445. }
  446. /**
  447. * Train the model using the given instances
  448. * @param instances training set, which should be learned
  449. */
  450. public abstract void trainModel(Instances instances);
  451. /**
  452. * classifies the given instance
  453. * @param instance instance which should be classified
  454. * @param origin original packet, which was transformed into the instance
  455. * @return distance to next centroid
  456. * @throws Exception if anomaly was detected
  457. */
  458. public abstract double classifyInstance(Instance instance, Packet origin) throws Exception;
  459. /**
  460. * Returns the timestep, after which the classifier should start classifying instead of training.
  461. * @return timestep of the testing begin.
  462. */
  463. public abstract long getClassificationStart();
  464. @Override
  465. public void setMode(boolean testing) {
  466. training = !testing;
  467. if(testing) {
  468. try {
  469. finishDataCollection();
  470. } catch (Exception e) {
  471. System.out.println("Clustering failed");
  472. e.printStackTrace();
  473. }
  474. }
  475. }
  476. @Override
  477. public boolean getMode() {
  478. return !training;
  479. }
  480. /**
  481. * Short String representation of the classifier
  482. * @return
  483. */
  484. public abstract String getAlgoName();
  485. /**
  486. * @return the currentScenario
  487. */
  488. public String getCurrentScenario() {
  489. return currentScenario;
  490. }
  491. /**
  492. * @param currentScenario the currentScenario to set
  493. */
  494. public void setCurrentScenario(String currentScenario) {
  495. this.currentScenario = currentScenario;
  496. this.scenarioRun = 0;
  497. }
  498. }