BasicPacketClassifier.java 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. package classifier;
  2. import java.util.ArrayList;
  3. import java.util.HashMap;
  4. import java.util.HashSet;
  5. import java.util.Iterator;
  6. import java.util.LinkedList;
  7. import java.util.Map.Entry;
  8. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Link;
  9. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  10. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.PacketSniffer;
  11. import weka.core.Attribute;
  12. import weka.core.DenseInstance;
  13. import weka.core.Instance;
  14. import weka.core.Instances;
  15. /**
  16. * Unsupervised Classifier Basis, which contains methods for transforming {@link Packet}s into {@link Instance}s.
  17. *
  18. * @author Andreas T. Meyer-Berg
  19. */
  20. public abstract class BasicPacketClassifier implements PacketSniffer {
  21. /**
  22. * True, if instances should be used for training
  23. */
  24. protected boolean training = true;
  25. /**
  26. * Attributes which should be taken into account
  27. */
  28. protected ArrayList<Attribute> atts = new ArrayList<Attribute>();
  29. /**
  30. * Collected Packets
  31. */
  32. protected Instances dataset;
  33. /**
  34. * CollectedPackets
  35. */
  36. protected HashMap<Link, LinkedList<Packet>> collectedPackets = new HashMap<Link, LinkedList<Packet>>();
  37. /**
  38. * HashMap for calculating transmission delay
  39. */
  40. protected HashMap<Link, LinkedList<Packet>> lastPackets = new HashMap<Link, LinkedList<Packet>>();
  41. /**
  42. * Map for the different Link names
  43. */
  44. protected HashSet<String> link_mappings = new HashSet<String>();
  45. /**
  46. * Map for the difference source device names
  47. */
  48. protected HashSet<String> source_mappings = new HashSet<String>();
  49. /**
  50. * Map for the different destination device names
  51. */
  52. protected HashSet<String> destination_mappings = new HashSet<String>();
  53. /**
  54. * Map for the protocol names
  55. */
  56. protected HashSet<String> protocol_mappings = new HashSet<String>();
  57. /**
  58. * Number of packets which are used to calculate the current transmission speed
  59. */
  60. protected int NUMBER_OF_PACKETS = 200;
  61. /**
  62. * Initializes the different maps
  63. */
  64. public BasicPacketClassifier() {
  65. // Initialize Attribute list
  66. source_mappings.add("unknown");
  67. link_mappings.add("unknown");
  68. destination_mappings.add("unknown");
  69. protocol_mappings.add("unknown");
  70. }
  71. @Override
  72. public void processPackets(HashMap<Link, LinkedList<Packet>> packets) {
  73. if(training)
  74. try {
  75. training(packets);
  76. } catch (Exception e) {
  77. e.printStackTrace();
  78. }
  79. else
  80. classify(packets);
  81. }
  82. /**
  83. * Estimates the current Packets per second (depending on the last 100 packets of the link)
  84. * @param link Link which should be checked
  85. * @param packet Packet which should investigated
  86. * @return estimated number of packets per second
  87. */
  88. protected double getEstimatedPacketsPerSecond(Link link, Packet packet) {
  89. /**
  90. * Packets used to calculated the packets per second
  91. */
  92. LinkedList<Packet> list = lastPackets.get(link);
  93. if(list == null) {
  94. /**
  95. * Add list if not present
  96. */
  97. list = new LinkedList<Packet>();
  98. lastPackets.put(link, list);
  99. }
  100. if(list.isEmpty()) {
  101. list.addLast(packet);
  102. // Default 1 packet per second
  103. return 1.0;
  104. }
  105. if(list.size() == NUMBER_OF_PACKETS){
  106. list.removeFirst();
  107. }
  108. list.addLast(packet);
  109. /**
  110. * elapsed time in milliseconds since last packet
  111. */
  112. long elapsed_time = packet.getTimestamp()-list.getFirst().getTimestamp()/list.size();
  113. if(elapsed_time<=0)
  114. return Double.POSITIVE_INFINITY;
  115. /**
  116. * Return number of packets per second
  117. */
  118. return 1000.0/elapsed_time;
  119. }
  120. /**
  121. * Returns the instance representation of the given packet and link
  122. * @param link link the packet was sent on
  123. * @param packet packet which should be transformed
  124. * @param dataset distribution the packet is part of
  125. * @return instance representation
  126. */
  127. protected Instance packet2Instance(Link link, Packet packet, Instances dataset) {
  128. /**
  129. * Instance for the given Packet
  130. */
  131. DenseInstance instance = new DenseInstance(dataset.numAttributes());
  132. instance.setDataset(dataset);
  133. // link
  134. instance.setValue(0, stringToNominal(link_mappings, link.getName()));
  135. // source
  136. if(packet.getSource()==null) {
  137. instance.setValue(1, "unknown");
  138. instance.setValue(2, Double.NEGATIVE_INFINITY);
  139. }else if(packet.getSource().getOwner()==null){
  140. instance.setValue(1, "unknown");
  141. instance.setValue(2, packet.getSource().getPortNumber());
  142. }else {
  143. instance.setValue(1, stringToNominal(source_mappings, packet.getSource().getOwner().getName()));
  144. instance.setValue(2, packet.getSource().getPortNumber());
  145. }
  146. // Destination
  147. if(packet.getDestination()==null) {
  148. instance.setValue(3, "unknown");
  149. instance.setValue(4, Double.NEGATIVE_INFINITY);
  150. }else if(packet.getDestination().getOwner()==null){
  151. instance.setValue(3, "unknown");
  152. instance.setValue(4, packet.getDestination().getPortNumber());
  153. }else {
  154. instance.setValue(3, stringToNominal(destination_mappings, packet.getDestination().getOwner().getName()));
  155. instance.setValue(4, packet.getDestination().getPortNumber());
  156. }
  157. // Protocol name
  158. instance.setValue(5, stringToNominal(protocol_mappings, packet.getProtocolName()));
  159. // Packets per second
  160. instance.setValue(6, getEstimatedPacketsPerSecond(link, packet));
  161. return instance;
  162. }
  163. /**
  164. * Inserts the
  165. * @param map
  166. * @param nominal
  167. */
  168. protected void insertNominalIntoMap(HashSet<String> map, String nominal) {
  169. if(map == null || nominal == null)
  170. return;
  171. map.add(nominal);
  172. }
  173. /**
  174. * Transforms the String into an Number
  175. * @param map
  176. * @param s
  177. * @return
  178. */
  179. protected String stringToNominal(HashSet<String> map, String s) {
  180. return map.contains(s)?s:"unknown";
  181. }
  182. /**
  183. * Train the clusterer by collecting the packets
  184. *
  185. * @param packets packets to be learned
  186. */
  187. protected void training(HashMap<Link, LinkedList<Packet>> packets) {
  188. for(Entry<Link, LinkedList<Packet>> e:packets.entrySet()) {
  189. Link l = e.getKey();
  190. // TODO: ERROR ????????
  191. LinkedList<Packet> p = collectedPackets.get(l);
  192. if(p == null) {
  193. collectedPackets.put(l, new LinkedList<Packet>(e.getValue()));
  194. } else
  195. p.addAll(e.getValue());
  196. insertNominalIntoMap(link_mappings, l.getName());
  197. for(Packet pac: e.getValue()) {
  198. if(pac == null || pac.getSource()==null ||pac.getDestination() == null || pac.getSource().getOwner() == null || pac.getDestination().getOwner() == null)
  199. continue;
  200. insertNominalIntoMap(destination_mappings, pac.getSource().getOwner().getName());
  201. insertNominalIntoMap(destination_mappings, pac.getDestination().getOwner().getName());
  202. insertNominalIntoMap(source_mappings, pac.getSource().getOwner().getName());
  203. insertNominalIntoMap(source_mappings, pac.getDestination().getOwner().getName());
  204. insertNominalIntoMap(protocol_mappings, pac.getProtocolName());
  205. }
  206. //TODO: Add packet/Link/Names etc. to mappings
  207. }
  208. }
  209. /**
  210. * Finishes the collection and trains the clusterer on the collected packets
  211. *
  212. * @throws Exception
  213. */
  214. protected void finishDataCollection() throws Exception{
  215. /**
  216. printHashSet("Link-Name", link_mappings);
  217. printHashSet("Source-Device", source_mappings);
  218. printHashSet("Destination-Port", destination_mappings);
  219. printHashSet("Protocol-name", protocol_mappings);
  220. */
  221. atts.add(new Attribute("Link-Name", new LinkedList<String>(link_mappings)));//TODO:??
  222. atts.add(new Attribute("Source-Device", new LinkedList<String>(source_mappings)));
  223. atts.add(new Attribute("Source-Port-number", false));
  224. atts.add(new Attribute("Destination-Device", new LinkedList<String>(destination_mappings)));
  225. atts.add(new Attribute("Destination-Port-number", false));
  226. Attribute pn = new Attribute("Protocol-name", new LinkedList<String>(protocol_mappings));
  227. //pn.setWeight(10);
  228. atts.add(pn);
  229. Attribute pps = new Attribute("Packets-per-second", false);
  230. //pps.setWeight(20);
  231. atts.add(pps);
  232. //atts.add(new Attribute("Anomaly", false));
  233. /*
  234. atts = new ArrayList<Attribute>();
  235. atts.add(new Attribute("LN", new LinkedList<String>(link_mappings)));//TODO:??
  236. atts.add(new Attribute("SD", new LinkedList<String>(source_mappings)));
  237. atts.add(new Attribute("SPN", false));
  238. atts.add(new Attribute("DD", new LinkedList<String>(destination_mappings)));
  239. atts.add(new Attribute("DPN", false));
  240. atts.add(new Attribute("PN", new LinkedList<String>(protocol_mappings)));
  241. atts.add(new Attribute("PPS", false));
  242. atts.add(new Attribute("A", false));*/
  243. dataset = new Instances("Packets", atts, 100000);
  244. //dataset.setClassIndex(7);
  245. /**
  246. * Add Instances to dataset
  247. */
  248. for (Iterator<Entry<Link, LinkedList<Packet>>> it = collectedPackets.entrySet().iterator(); it.hasNext();) {
  249. Entry<Link, LinkedList<Packet>> entry = it.next();
  250. /**
  251. * Link the packet was captured on
  252. */
  253. Link l = entry.getKey();
  254. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  255. /**
  256. * Packets to be added to the dataset
  257. */
  258. Packet packet = (Packet) itPacket.next();
  259. dataset.add(packet2Instance(l, packet, dataset));
  260. }
  261. }
  262. trainModel(dataset);
  263. }
  264. private void printHashSet(String name, HashSet<String> toPrint) {
  265. System.out.println(name+":");
  266. for (Iterator<String> iterator = toPrint.iterator(); iterator.hasNext();) {
  267. String string = (String) iterator.next();
  268. System.out.print(string);
  269. if(iterator.hasNext())
  270. System.out.print(", ");
  271. }
  272. System.out.println();
  273. }
  274. /**
  275. * Try to classify the given packets and detect anomalies
  276. * @param packets packets to be classified
  277. */
  278. protected void classify(HashMap<Link, LinkedList<Packet>> packets) {
  279. int tp = 0;
  280. int fp = 0;
  281. int tn = 0;
  282. int fn = 0;
  283. long start = Long.MAX_VALUE;
  284. long end = Long.MIN_VALUE;
  285. for (Iterator<Entry<Link, LinkedList<Packet>>> it = packets.entrySet().iterator(); it.hasNext();) {
  286. /**
  287. * Link & its packets
  288. */
  289. Entry<Link, LinkedList<Packet>> entry = it.next();
  290. /**
  291. * Link the packets were captured on
  292. */
  293. Link l = entry.getKey();
  294. for (Iterator<Packet> itPacket = entry.getValue().iterator(); itPacket.hasNext();) {
  295. /**
  296. * Packet which should be checked
  297. */
  298. Packet packet = (Packet) itPacket.next();
  299. start = Math.min(start, packet.getTimestamp());
  300. end = Math.max(end, packet.getTimestamp());
  301. /**
  302. * Instance Representation
  303. */
  304. Instance packet_instance = packet2Instance(l, packet, dataset);
  305. if(packet_instance == null)continue;
  306. try {
  307. double dist = classifyInstance(packet_instance, packet);
  308. if(dist<=1.0) {
  309. if(packet.getLabel()==0)
  310. tn++;
  311. else
  312. fn++;
  313. }else {
  314. if(packet.getLabel()==0)
  315. fp++;
  316. else
  317. tp++;
  318. }
  319. } catch (Exception e) {
  320. if(packet.getLabel()==0)
  321. fp++;
  322. else
  323. tp++;
  324. }
  325. }
  326. }
  327. int n = tp+tn+fp+fn;
  328. if(n!=0) {
  329. System.out.println(getAlgoName()+" Performance: ["+start+"ms, "+end+"ms]");
  330. System.out.println("n: "+n);
  331. System.out.println("TP: "+tp);
  332. System.out.println("FP: "+fp);
  333. System.out.println("TN: "+tn);
  334. System.out.println("FN: "+fn);
  335. System.out.println("TPR: "+(tp/(tp+fn+0.0)));
  336. System.out.println("FPR: "+(fp/(fp+tn+0.0)));
  337. System.out.println("");
  338. }
  339. }
  340. /**
  341. * Train the model using the given instances
  342. * @param instances training set, which should be learned
  343. */
  344. public abstract void trainModel(Instances instances);
  345. /**
  346. * classifies the given instance
  347. * @param instance instance which should be classified
  348. * @param origin original packet, which was transformed into the instance
  349. * @return distance to next centroid
  350. * @throws Exception if anomaly was detected
  351. */
  352. public abstract double classifyInstance(Instance instance, Packet origin) throws Exception;
  353. /**
  354. * Returns the timestep, after which the classifier should start classifying instead of training.
  355. * @return timestep of the testing begin.
  356. */
  357. public abstract long getClassificationStart();
  358. @Override
  359. public void setMode(boolean testing) {
  360. training = !testing;
  361. if(testing) {
  362. try {
  363. finishDataCollection();
  364. } catch (Exception e) {
  365. System.out.println("Clustering failed");
  366. e.printStackTrace();
  367. }
  368. }
  369. }
  370. @Override
  371. public boolean getMode() {
  372. return !training;
  373. }
  374. /**
  375. * Short String representation of the classifier
  376. * @return
  377. */
  378. public abstract String getAlgoName();
  379. }