123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- package classifier;
- import de.tu_darmstadt.tk.SmartHomeNetworkSim.core.Packet;
- import weka.clusterers.SimpleKMeans;
- import weka.core.Instance;
- import weka.core.Instances;
- import weka.core.SelectedTag;
- public class KMeansClustering extends BasicPacketClassifier {
-
- private SimpleKMeans clusterer;
-
-
-
-
- protected int NUMBER_OF_CLUSTERS = 34;
- protected double[] stdv = new double[NUMBER_OF_CLUSTERS];
-
- public KMeansClustering() {
- super();
- clusterer = new SimpleKMeans();
- clusterer.setSeed(42);
-
- clusterer.setInitializationMethod(new SelectedTag(SimpleKMeans.FARTHEST_FIRST,SimpleKMeans.TAGS_SELECTION));
-
-
-
- try {
- clusterer.setNumClusters(this.NUMBER_OF_CLUSTERS);
- } catch (Exception e) {
- System.out.println("Error while building cluster");
- e.printStackTrace();
- }
- }
- @Override
- public void trainModel(Instances instances) {
- try {
- clusterer.buildClusterer(instances);
- double[] sumOfSquares = new double[NUMBER_OF_CLUSTERS];
- for(Instance i: instances) {
-
- int x = clusterer.clusterInstance(i);
-
- Instance center = clusterer.getClusterCentroids().get(x);
-
- double dist = clusterer.getDistanceFunction().distance(center, i);
- sumOfSquares[x] += dist*dist;
- }
-
- for(int i = 0; i<NUMBER_OF_CLUSTERS; i++)
- this.stdv[i] = Math.sqrt(sumOfSquares[i]);
- } catch (Exception e) {
- System.out.println("Failed while training the classifier");
- e.printStackTrace();
- }
- }
- private boolean test = true;
- @Override
- public double classifyInstance(Instance instance, Packet origin) throws Exception {
-
- int x = clusterer.clusterInstance(instance);
-
- Instance center = clusterer.getClusterCentroids().get(x);
-
- double dist = clusterer.getDistanceFunction().distance(center, instance);
- if(test && dist<stdv[x] && origin.getLabel()!=0) {
- test = false;
- System.out.println("Analysis of: "+origin.getTextualRepresentation());
- System.out.println("Classified as: "+x+" Dist: "+dist+" Stdv: "+stdv[x]);
- for(int i=0; i<NUMBER_OF_CLUSTERS; i++) {
- Instance centroid = clusterer.getClusterCentroids().get(i);
- if(centroid == null)continue;
- double d = clusterer.getDistanceFunction().distance(centroid, instance);
-
- System.out.println("Cluster: "+i+" Dist: "+d+" Stdv: "+stdv[i]);
- }
- test = false;
- System.out.println("");
- }
- if(dist < stdv[x])
- return 0;
- else
- return Double.MAX_VALUE;
-
- }
- @Override
- public long getClassificationStart() {
- return 3600000;
- }
- @Override
- public String getAlgoName() {
- return "KNN";
- }
- }
|