BasicPacketClassifierWitLabels.java 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. package de.tu_darmstadt.tk.SmartHomeNetworkSim.evaluation;
  2. import java.io.BufferedWriter;
  3. import java.io.File;
  4. import java.io.FileWriter;
  5. import java.io.IOException;
  6. import java.util.ArrayList;
  7. import java.util.HashMap;
  8. import java.util.HashSet;
  9. import java.util.Iterator;
  10. import java.util.LinkedList;
  11. import java.util.Map.Entry;
  12. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
  13. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  14. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
  15. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.protocols.packets.MQTTpublishPacket;
  16. import weka.core.Attribute;
  17. import weka.core.DenseInstance;
  18. import weka.core.Instance;
  19. import weka.core.Instances;
  20. /**
  21. * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s.
  22. *
  23. * @author Andreas T. Meyer-Berg
  24. */
  25. public abstract class BasicPacketClassifierWitLabels implements PacketSniffer {
  26. /**
  27. * True, if instances should be used for training
  28. */
  29. protected boolean training = true;
  30. /**
  31. * Attributes which should be taken into account
  32. */
  33. protected ArrayList<Attribute> atts = new ArrayList<Attribute>();
  34. /**
  35. * Collected Packets
  36. */
  37. protected Instances dataset;
  38. /**
  39. * CollectedPackets
  40. */
  41. protected HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
  42. /**
  43. * HashMap for calculating transmission delay
  44. */
  45. protected HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
  46. /**
  47. * Map for the different Link names
  48. */
  49. protected HashSet<String> link_mappings = new HashSet<String>();
  50. /**
  51. * Map for the difference source device names
  52. */
  53. protected HashSet<String> source_mappings = new HashSet<String>();
  54. /**
  55. * Map for the different destination device names
  56. */
  57. protected HashSet<String> destination_mappings = new HashSet<String>();
  58. /**
  59. * Map for the protocol names
  60. */
  61. protected HashSet<String> protocol_mappings = new HashSet<String>();
  62. /**
  63. * Map for the protocol names
  64. */
  65. protected HashSet<String> packet_mappings = new HashSet<String>();
  66. /**
  67. * Number of packets which are used to calculate the current transmission speed
  68. */
  69. protected int NUMBER_OF_PACKETS = 200;
  70. private String currentScenario = "";
  71. private int scenarioRun = 0;
  72. /**
  73. * Initializes the different maps
  74. */
  75. public BasicPacketClassifierWitLabels() {
  76. // Initialize Attribute list
  77. source_mappings.add("unknown");
  78. link_mappings.add("unknown");
  79. destination_mappings.add("unknown");
  80. protocol_mappings.add("unknown");
  81. packet_mappings.add("unknown");
  82. }
  83. @Override
  84. public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
  85. if(training)
  86. try {
  87. training(packets);
  88. } catch (Exception e) {
  89. e.printStackTrace();
  90. }
  91. else
  92. classify(packets);
  93. }
  94. /**
  95. * Estimates the current Packets per second (depending on the last 100 packets of the link)
  96. * @param link Link which should be checked
  97. * @param packet Packet which should investigated
  98. * @return estimated number of packets per second
  99. */
  100. protected double getEstimatedPacketsPerSecond(Link link, Packet packet) {
  101. /**
  102. * Packets used to calculated the packets per second
  103. */
  104. LinkedList<Packet> list = lastPackets.get(link);
  105. if(list == null) {
  106. /**
  107. * Add list if not present
  108. */
  109. list = new LinkedList<Packet>();
  110. lastPackets.put(link, list);
  111. }
  112. if(list.isEmpty()) {
  113. list.addLast(packet);
  114. // Default 1 packet per second
  115. return 1.0;
  116. }
  117. if(list.size() == NUMBER_OF_PACKETS){
  118. list.removeFirst();
  119. }
  120. list.addLast(packet);
  121. /**
  122. * elapsed time in milliseconds since last packet
  123. */
  124. long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
  125. if(elapsed_time<=0)
  126. return Double.POSITIVE_INFINITY;
  127. /**
  128. * Return number of packets per second
  129. */
  130. return 1000.0/elapsed_time;
  131. }
  132. /**
  133. * Returns the instance representation of the given packet and link
  134. * @param link link the packet was sent on
  135. * @param packet packet which should be transformed
  136. * @param dataset distribution the packet is part of
  137. * @return instance representation
  138. */
  139. protected Instance packet2Instance(Link link, Packet packet, Instances dataset) {
  140. /**
  141. * Instance for the given Packet
  142. */
  143. DenseInstance instance = new DenseInstance(dataset.numAttributes());
  144. instance.setDataset(dataset);
  145. // link
  146. instance.setValue(0, stringToNominal(link_mappings, link.getName()));
  147. // source
  148. if(packet.getSource()==null) {
  149. instance.setValue(1, "unknown");
  150. instance.setValue(2, Double.NEGATIVE_INFINITY);
  151. }else if(packet.getSource().getOwner()==null){
  152. instance.setValue(1, "unknown");
  153. instance.setValue(2, packet.getSource().getPortNumber());
  154. }else {
  155. instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
  156. instance.setValue(2, packet.getSource().getPortNumber());
  157. }
  158. // Destination
  159. if(packet.getDestination()==null) {
  160. instance.setValue(3, "unknown");
  161. instance.setValue(4, Double.NEGATIVE_INFINITY);
  162. }else if(packet.getDestination().getOwner()==null){
  163. instance.setValue(3, "unknown");
  164. instance.setValue(4, packet.getDestination().getPortNumber());
  165. }else {
  166. instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
  167. instance.setValue(4, packet.getDestination().getPortNumber());
  168. }
  169. // Protocol name
  170. instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
  171. // Packets per second
  172. //instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
  173. // MQTT Value
  174. if(packet instanceof MQTTpublishPacket) {
  175. MQTTpublishPacket mqttPack = (MQTTpublishPacket)packet;
  176. if(mqttPack.isBoolean()) {
  177. //System.out.println("MQTT PACK: " + mqttPack.getValue() + "," +mqttPack.getSensorValue());
  178. if(mqttPack.getValue() == 0) {
  179. instance.setValue(6,0);
  180. //System.out.println("False");
  181. }
  182. else {
  183. instance.setValue(6, 1);
  184. //System.out.println("True");
  185. }
  186. if(mqttPack.getSensorValue() == 0) {
  187. instance.setValue(7,0);
  188. //System.out.println("False");
  189. } else {
  190. instance.setValue(7, 1);
  191. //System.out.println("True");
  192. }
  193. }else {
  194. instance.setValue(6, ((MQTTpublishPacket)packet).getValue());
  195. instance.setValue(7, ((MQTTpublishPacket)packet).getSensorValue());
  196. }
  197. } else {
  198. instance.setValue(6, -100);
  199. instance.setValue(7, -100);
  200. }
  201. instance.setValue(8, stringToNominal(packet_mappings, packet.getPackageType()));
  202. return instance;
  203. }
  204. /**
  205. * Inserts the
  206. * @param map
  207. * @param nominal
  208. */
  209. protected void insertNominalIntoMap(HashSet<String> map, String nominal) {
  210. if(map == null || nominal == null)
  211. return;
  212. map.add(nominal);
  213. }
  214. /**
  215. * Transforms the String into an Number
  216. * @param map
  217. * @param s
  218. * @return
  219. */
  220. protected String stringToNominal(HashSet<String> map, String s) {
  221. return map.contains(s)?s:"unknown";
  222. }
  223. /**
  224. * Train the clusterer by collecting the packets
  225. *
  226. * @param packets packets to be learned
  227. */
  228. protected void training(HashMap<Link, LinkedList<Packet>> packets) {
  229. for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
  230. Link l = e.getKey();
  231. // TODO: ERROR ????????
  232. LinkedList<Packet> p = collectedPackets.get(l);
  233. if(p == null) {
  234. collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
  235. } else
  236. p.addAll(e.getValue());
  237. insertNominalIntoMap(link_mappings, l.getName());
  238. for(Packet pac: e.getValue()) {
  239. if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null)
  240. continue;
  241. insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName());
  242. insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName());
  243. insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName());
  244. insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName());
  245. insertNominalIntoMap(protocol_mappings, pac.getProtocolName());
  246. insertNominalIntoMap(packet_mappings, pac.getPackageType());
  247. }
  248. //TODO: Add packet/Link/Names etc. to mappings
  249. }
  250. }
  251. /**
  252. * Finishes the collection and trains the clusterer on the collected packets
  253. *
  254. * @throws Exception
  255. */
  256. protected void finishDataCollection() throws Exception{
  257. /**
  258. printHashSet("Link-Name", link_mappings);
  259. printHashSet("Source-Device", source_mappings);
  260. printHashSet("Destination-Port", destination_mappings);
  261. printHashSet("Protocol-name", protocol_mappings);
  262. */
  263. atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
  264. atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
  265. atts.add(new Attribute("Source-Port-number", false));
  266. atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
  267. atts.add(new Attribute("Destination-Port-number", false));
  268. Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
  269. //pn.setWeight(10);
  270. atts.add(pn);
  271. //Attribute pps = new Attribute("Packets-per-second", false);
  272. //pps.setWeight(20);
  273. //atts.add(pps);
  274. atts.add(new Attribute("PacketValue", false));
  275. //atts.add(new Attribute("Anomaly", false));
  276. // TODO: Sensor Attribute, given as side channel information
  277. atts.add(new Attribute("SensorValue", false));
  278. atts.add(new Attribute("PackageType",new LinkedList<String>(packet_mappings)));
  279. /*
  280. atts = new ArrayList<Attribute>();
  281. atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
  282. atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
  283. atts.add(new Attribute("SPN", false));
  284. atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
  285. atts.add(new Attribute("DPN", false));
  286. atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
  287. atts.add(new Attribute("PPS", false));
  288. atts.add(new Attribute("A", false));*/
  289. dataset = new Instances("Packets", atts, 100000);
  290. //dataset.setClassIndex(7);
  291. /**
  292. * Add Instances to dataset
  293. */
  294. for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
  295. Entry<Link, LinkedList<Packet>> entry = it.next();
  296. /**
  297. * Link the packet was captured on
  298. */
  299. Link l = entry.getKey();
  300. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  301. /**
  302. * Packets to be added to the dataset
  303. */
  304. Packet packet = (Packet) itPacket.next();
  305. dataset.add(packet2Instance(l, packet, dataset));
  306. }
  307. }
  308. trainModel(dataset);
  309. }
  310. private void printHashSet(String name, HashSet<String> toPrint) {
  311. System.out.println(name+":");
  312. for (Iterator<String> iterator = toPrint.iterator(); iterator.hasNext();) {
  313. String string = (String) iterator.next();
  314. System.out.print(string);
  315. if(iterator.hasNext())
  316. System.out.print(", ");
  317. }
  318. System.out.println();
  319. }
  320. /**
  321. * Try to classify the given packets and detect anomalies
  322. * @param packets packets to be classified
  323. */
  324. protected void classify(HashMap<Link, LinkedList<Packet>> packets) {
  325. File anomalyResults = new File("results/"+getCurrentScenario() + scenarioRun + ".csv");
  326. anomalyResults.getParentFile().mkdir();
  327. BufferedWriter writer = null;
  328. try {
  329. writer = new BufferedWriter(new FileWriter(anomalyResults));
  330. writer.write("PacketRepresentation,anomalyFPorTP,sensorInfo\n");
  331. } catch (IOException e1) {
  332. // TODO Auto-generated catch block
  333. e1.printStackTrace();
  334. }
  335. int tp = 0;
  336. int fp = 0;
  337. int tn = 0;
  338. int fn = 0;
  339. long start = Long.MAX_VALUE;
  340. long end = Long.MIN_VALUE;
  341. for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
  342. /**
  343. * Link & its packets
  344. */
  345. Entry<Link, LinkedList<Packet>> entry = it.next();
  346. /**
  347. * Link the packets were captured on
  348. */
  349. Link l = entry.getKey();
  350. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  351. /**
  352. * Packet which should be checked
  353. */
  354. Packet packet = (Packet) itPacket.next();
  355. start = Math.min(start, packet.getTimestamp());
  356. end = Math.max(end, packet.getTimestamp());
  357. /**
  358. * Instance Representation
  359. */
  360. Instance packet_instance = packet2Instance(l, packet, dataset);
  361. if(packet_instance == null)continue;
  362. String sensorLabel = "";
  363. if(packet instanceof MQTTpublishPacket) {
  364. MQTTpublishPacket mqttPac = (MQTTpublishPacket)packet;
  365. sensorLabel = ""+mqttPac.getSensorValue();
  366. if(mqttPac.isBoolean()) {
  367. if(mqttPac.getSensorValue() == 0)
  368. sensorLabel = "false";
  369. else
  370. sensorLabel = "true";
  371. }
  372. sensorLabel = ","+sensorLabel;
  373. }
  374. try {
  375. double dist = classifyInstance(packet_instance, packet);
  376. //System.out.println(packet.getTextualRepresentation()+": "+packet.getLabel() +":"+sensorLabel);
  377. if(dist<=Settings.PRECISION_ERROR) {
  378. if(packet.getLabel()==0) {
  379. tn++;
  380. writer.write(packet.getTextualRepresentation()+",TN"+sensorLabel+"\n");
  381. }
  382. else {
  383. fn++;
  384. writer.write(packet.getTextualRepresentation()+",FN"+sensorLabel+"\n");
  385. //System.out.println(packet.getTextualRepresentation()+",AnomalyNotFound"+sensorLabel);
  386. }
  387. }else {
  388. if(packet.getLabel()==0) {
  389. fp++;
  390. writer.write(packet.getTextualRepresentation()+",FP"+sensorLabel+"\n");
  391. } else {
  392. tp++;
  393. writer.write(packet.getTextualRepresentation()+",TP"+sensorLabel+"\n");
  394. }
  395. }
  396. } catch (Exception e) {
  397. System.out.println(e);
  398. if(packet.getLabel()==0) {
  399. fp++;
  400. try {
  401. writer.write(packet.getTextualRepresentation()+",FP"+sensorLabel+"\n");
  402. } catch (IOException e1) {
  403. // TODO Auto-generated catch block
  404. e1.printStackTrace();
  405. }
  406. } else {
  407. tp++;
  408. try {
  409. writer.write(packet.getTextualRepresentation()+",TP"+sensorLabel+"\n");
  410. } catch (IOException e1) {
  411. // TODO Auto-generated catch block
  412. e1.printStackTrace();
  413. }
  414. }
  415. }
  416. }
  417. }
  418. int n = tp+tn+fp+fn;
  419. if(n!=0) {
  420. System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms] Scenario: " + getCurrentScenario() + scenarioRun);
  421. scenarioRun++;
  422. System.out.println("n: "+n);
  423. System.out.println("TP: "+tp);
  424. System.out.println("FP: "+fp);
  425. System.out.println("TN: "+tn);
  426. System.out.println("FN: "+fn);
  427. System.out.println("TPR: "+(tp/(tp+fn+0.0)));
  428. System.out.println("FPR: "+(fp/(fp+tn+0.0)));
  429. System.out.println("");
  430. }
  431. try {
  432. writer.close();
  433. } catch (IOException e) {
  434. // TODO Auto-generated catch block
  435. e.printStackTrace();
  436. }
  437. }
  438. /**
  439. * Train the model using the given instances
  440. * @param instances training set, which should be learned
  441. */
  442. public abstract void trainModel(Instances instances);
  443. /**
  444. * classifies the given instance
  445. * @param instance instance which should be classified
  446. * @param origin original packet, which was transformed into the instance
  447. * @return distance to next centroid
  448. * @throws Exception if anomaly was detected
  449. */
  450. public abstract double classifyInstance(Instance instance, Packet origin) throws Exception;
  451. /**
  452. * Returns the timestep, after which the classifier should start classifying instead of training.
  453. * @return timestep of the testing begin.
  454. */
  455. public abstract long getClassificationStart();
  456. @Override
  457. public void setMode(boolean testing) {
  458. training = !testing;
  459. if(testing) {
  460. try {
  461. finishDataCollection();
  462. } catch (Exception e) {
  463. System.out.println("Clustering failed");
  464. e.printStackTrace();
  465. }
  466. }
  467. }
  468. @Override
  469. public boolean getMode() {
  470. return !training;
  471. }
  472. /**
  473. * Short String representation of the classifier
  474. * @return
  475. */
  476. public abstract String getAlgoName();
  477. /**
  478. * @return the currentScenario
  479. */
  480. public String getCurrentScenario() {
  481. return currentScenario;
  482. }
  483. /**
  484. * @param currentScenario the currentScenario to set
  485. */
  486. public void setCurrentScenario(String currentScenario) {
  487. this.currentScenario = currentScenario;
  488. this.scenarioRun = 0;
  489. }
  490. }