KMeansClustering.java 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package classifier;
  2. import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
  3. import weka.clusterers.SimpleKMeans;
  4. import weka.core.Instance;
  5. import weka.core.Instances;
  6. /**
  7. * Unsupervised Example: K Means Clustering
  8. *
  9. * @author Andreas T. Meyer-Berg
  10. */
  11. public class KMeansClustering extends BasicPacketClassifier {
  12. /**
  13. * Clusterer
  14. */
  15. private SimpleKMeans clusterer;
  16. /**
  17. * Number of Clusters
  18. */
  19. protected int NUMBER_OF_CLUSTERS = 16;
  20. protected double[] stdv = new double[NUMBER_OF_CLUSTERS];
  21. /**
  22. * Initializes the k means clusterer
  23. */
  24. public KMeansClustering() {
  25. super();
  26. clusterer = new SimpleKMeans();
  27. clusterer.setSeed(42);/*
  28. clusterer.setCanopyPeriodicPruningRate(100);
  29. clusterer.setCanopyT1(0.5);
  30. clusterer.setCanopyT2(1.0);*/
  31. try {
  32. clusterer.setNumClusters(this.NUMBER_OF_CLUSTERS);
  33. } catch (Exception e) {
  34. System.out.println("Error while building cluster");
  35. e.printStackTrace();
  36. }
  37. }
  38. @Override
  39. public void trainModel(Instances instances) {
  40. try {
  41. clusterer.buildClusterer(instances);
  42. double[] sumOfSquares = new double[NUMBER_OF_CLUSTERS];
  43. for(Instance i: instances) {
  44. /**
  45. * Id of the closest cluster centroid
  46. */
  47. int x = clusterer.clusterInstance(i);
  48. /**
  49. * centroid instance
  50. */
  51. Instance center = clusterer.getClusterCentroids().get(x);
  52. /**
  53. * Distance
  54. */
  55. double dist = clusterer.getDistanceFunction().distance(center, i);
  56. sumOfSquares[x] += dist*dist;
  57. }
  58. /**
  59. * Calculate Standard Deviations
  60. */
  61. for(int i = 0; i<NUMBER_OF_CLUSTERS; i++)
  62. this.stdv[i] = Math.sqrt(sumOfSquares[i]);
  63. } catch (Exception e) {
  64. System.out.println("Failed while training the classifier");
  65. e.printStackTrace();
  66. }
  67. }
  68. @Override
  69. public double classifyInstance(Instance instance, Packet origin) throws Exception {
  70. /**
  71. * Id of the closest cluster centroid
  72. */
  73. int x = clusterer.clusterInstance(instance);
  74. /**
  75. * centroid instance
  76. */
  77. Instance center = clusterer.getClusterCentroids().get(x);
  78. double dist = clusterer.getDistanceFunction().distance(center, instance);
  79. if(dist < stdv[x])
  80. return 0;
  81. else
  82. return Double.MAX_VALUE;
  83. }
  84. @Override
  85. public long getClassificationStart() {
  86. return 3600000;
  87. }
  88. @Override
  89. public String getAlgoName() {
  90. return "KNN";
  91. }
  92. }