SWCKMeansClustering.java 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package de.tu_darmstadt.tk.SmartHomeNetworkSim.evaluation;
  2. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  3. import weka.clusterers.SimpleKMeans;
  4. import weka.core.Instance;
  5. import weka.core.Instances;
  6. import weka.core.SelectedTag;
  7. /**
  8. * Unsupervised Example: K Means Clustering
  9. *
  10. * @author Andreas T. Meyer-Berg
  11. */
  12. public class SWCKMeansClustering extends BasicPacketClassifier {
  13. /**
  14. * Clusterer
  15. */
  16. private SimpleKMeans clusterer;
  17. /**
  18. * Number of Clusters
  19. */
  20. //17 works fine
  21. //34 found value anomalies
  22. protected int NUMBER_OF_CLUSTERS = 8;
  23. protected double[] stdv = new double[NUMBER_OF_CLUSTERS];
  24. /**
  25. * Initializes the k means clusterer
  26. */
  27. public SWCKMeansClustering() {
  28. super();
  29. clusterer = new SimpleKMeans();
  30. clusterer.setSeed(42);
  31. //clusterer.setDisplayStdDevs(true);
  32. clusterer.setInitializationMethod(new SelectedTag(SimpleKMeans.FARTHEST_FIRST,SimpleKMeans.TAGS_SELECTION));
  33. //clusterer.setCanopyPeriodicPruningRate(100);
  34. //clusterer.setCanopyT1(0.001);
  35. //clusterer.setCanopyT2(0.1);
  36. try {
  37. clusterer.setNumClusters(this.NUMBER_OF_CLUSTERS);
  38. } catch (Exception e) {
  39. System.out.println("Error while building cluster");
  40. e.printStackTrace();
  41. }
  42. }
  43. @Override
  44. public void trainModel(Instances instances) {
  45. try {
  46. clusterer.buildClusterer(instances);
  47. double[] sumOfSquares = new double[NUMBER_OF_CLUSTERS];
  48. for(Instance i: instances) {
  49. /**
  50. * Id of the closest cluster centroid
  51. */
  52. int x = clusterer.clusterInstance(i);
  53. /**
  54. * centroid instance
  55. */
  56. Instance center = clusterer.getClusterCentroids().get(x);
  57. /**
  58. * Distance
  59. */
  60. double dist = clusterer.getDistanceFunction().distance(center, i);
  61. sumOfSquares[x] += dist*dist;
  62. }
  63. /**
  64. * Calculate Standard Deviations
  65. */
  66. for(int i = 0; i<NUMBER_OF_CLUSTERS; i++)
  67. this.stdv[i] = Math.sqrt(sumOfSquares[i]);
  68. } catch (Exception e) {
  69. System.out.println("Failed while training the classifier");
  70. e.printStackTrace();
  71. }
  72. }
  73. private boolean test = true;
  74. @Override
  75. public double classifyInstance(Instance instance, Packet origin) throws Exception {
  76. /**
  77. * Id of the closest cluster centroid
  78. */
  79. int x = clusterer.clusterInstance(instance);
  80. /**
  81. * centroid instance
  82. */
  83. Instance center = clusterer.getClusterCentroids().get(x);
  84. double dist = clusterer.getDistanceFunction().distance(center, instance);
  85. if(test && dist<stdv[x] && origin.getLabel()!=0) {
  86. test = false;
  87. System.out.println("Analysis of: "+origin.getTextualRepresentation());
  88. System.out.println("Classified as: "+x+" Dist: "+dist+" Stdv: "+stdv[x]);
  89. for(int i=0; i<NUMBER_OF_CLUSTERS; i++) {
  90. Instance centroid = clusterer.getClusterCentroids().get(i);
  91. if(centroid == null)continue;
  92. double d = clusterer.getDistanceFunction().distance(centroid, instance);
  93. System.out.println("Cluster: "+i+" Dist: "+d+" Stdv: "+stdv[i]);
  94. }
  95. test = false;
  96. System.out.println("");
  97. }
  98. if(dist < stdv[x])
  99. return 0;
  100. else
  101. return Double.MAX_VALUE;
  102. }
  103. @Override
  104. public long getClassificationStart() {
  105. return 3600000;
  106. }
  107. @Override
  108. public String getAlgoName() {
  109. return "KNN";
  110. }
  111. }