KMeansClustering.java 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package classifier;
  2. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  3. import weka.clusterers.SimpleKMeans;
  4. import weka.core.Instance;
  5. import weka.core.Instances;
  6. import weka.core.SelectedTag;
  7. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.simpleImplementation.BasicPacketClassifier;
  8. /**
  9. * Unsupervised Example: K Means Clustering
  10. *
  11. * @author Andreas T. Meyer-Berg
  12. */
  13. public class KMeansClustering extends BasicPacketClassifier {
  14. /**
  15. * Clusterer
  16. */
  17. private SimpleKMeans clusterer;
  18. /**
  19. * Number of Clusters
  20. */
  21. //17 works fine
  22. //34 found value anomalies
  23. protected int NUMBER_OF_CLUSTERS = 34;
  24. protected double[] stdv = new double[NUMBER_OF_CLUSTERS];
  25. /**
  26. * Initializes the k means clusterer
  27. */
  28. public KMeansClustering() {
  29. super();
  30. clusterer = new SimpleKMeans();
  31. clusterer.setSeed(42);
  32. //clusterer.setDisplayStdDevs(true);
  33. clusterer.setInitializationMethod(new SelectedTag(SimpleKMeans.FARTHEST_FIRST,SimpleKMeans.TAGS_SELECTION));
  34. //clusterer.setCanopyPeriodicPruningRate(100);
  35. //clusterer.setCanopyT1(0.001);
  36. //clusterer.setCanopyT2(0.1);
  37. try {
  38. clusterer.setNumClusters(this.NUMBER_OF_CLUSTERS);
  39. } catch (Exception e) {
  40. System.out.println("Error while building cluster");
  41. e.printStackTrace();
  42. }
  43. }
  44. @Override
  45. public void trainModel(Instances instances) {
  46. try {
  47. clusterer.buildClusterer(instances);
  48. double[] sumOfSquares = new double[NUMBER_OF_CLUSTERS];
  49. for(Instance i: instances) {
  50. /**
  51. * Id of the closest cluster centroid
  52. */
  53. int x = clusterer.clusterInstance(i);
  54. /**
  55. * centroid instance
  56. */
  57. Instance center = clusterer.getClusterCentroids().get(x);
  58. /**
  59. * Distance
  60. */
  61. double dist = clusterer.getDistanceFunction().distance(center, i);
  62. sumOfSquares[x] += dist*dist;
  63. }
  64. /**
  65. * Calculate Standard Deviations
  66. */
  67. for(int i = 0; i<NUMBER_OF_CLUSTERS; i++)
  68. this.stdv[i] = Math.sqrt(sumOfSquares[i]);
  69. } catch (Exception e) {
  70. System.out.println("Failed while training the classifier");
  71. e.printStackTrace();
  72. }
  73. }
  74. private boolean test = true;
  75. @Override
  76. public double classifyInstance(Instance instance, Packet origin) throws Exception {
  77. /**
  78. * Id of the closest cluster centroid
  79. */
  80. int x = clusterer.clusterInstance(instance);
  81. /**
  82. * centroid instance
  83. */
  84. Instance center = clusterer.getClusterCentroids().get(x);
  85. double dist = clusterer.getDistanceFunction().distance(center, instance);
  86. if(test && dist<stdv[x] && origin.getLabel()!=0) {
  87. test = false;
  88. System.out.println("Analysis of: "+origin.getTextualRepresentation());
  89. System.out.println("Classified as: "+x+" Dist: "+dist+" Stdv: "+stdv[x]);
  90. for(int i=0; i<NUMBER_OF_CLUSTERS; i++) {
  91. Instance centroid = clusterer.getClusterCentroids().get(i);
  92. if(centroid == null)continue;
  93. double d = clusterer.getDistanceFunction().distance(centroid, instance);
  94. System.out.println("Cluster: "+i+" Dist: "+d+" Stdv: "+stdv[i]);
  95. }
  96. test = false;
  97. System.out.println("");
  98. }
  99. if(dist < stdv[x])
  100. return 0;
  101. else
  102. return Double.MAX_VALUE;
  103. }
  104. @Override
  105. public long getClassificationStart() {
  106. return 3600000;
  107. }
  108. @Override
  109. public String getAlgoName() {
  110. return "KNN";
  111. }
  112. }