/**
 * Hub Miner: a hubness-aware machine learning experimentation library.
 * Copyright (C) 2014 Nenad Tomasev. Email: nenad.tomasev at gmail.com
 *
 * This program is free software: you can redistribute it and/or modify it under
 * the terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program. If not, see <http://www.gnu.org/licenses/>.
 */
package learning.supervised.methods.knn;

import learning.supervised.interfaces.AutomaticKFinderInterface;
import learning.supervised.evaluation.ValidateableInterface;
import distances.primary.CombinedMetric;
import data.representation.DataInstance;
import data.representation.DataSet;
import learning.supervised.Category;
import learning.supervised.Classifier;
import data.neighbors.NeighborSetFinder;
import java.io.Serializable;
import learning.supervised.interfaces.DistToPointsQueryUserInterface;
import learning.supervised.interfaces.NeighborPointsQueryUserInterface;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;

/**
 * This class implements the basic k-nearest neighbor classifier.
 *
 * @author Nenad Tomasev <nenad.tomasev at gmail.com>
 */
public class KNN extends Classifier implements AutomaticKFinderInterface,
        DistToPointsQueryUserInterface, NeighborPointsQueryUserInterface,
        Serializable {

    private static final long serialVersionUID = 1L;
    // The training dataset.
    private DataSet trainingData = null;
    // The number of classes in the data.
    private int numClasses = 0;
    // The neighborhood size.
    private int k = 1;
    // The prior class distribution.
    float[] classPriors;

    /**
     * Default constructor.
     */
    public KNN() {
    }

    @Override
    public HashMap<String, String> getParameterNamesAndDescriptions() {
        HashMap<String, String> paramMap = new HashMap<>();
        paramMap.put("k", "Neighborhood size.");
        return paramMap;
    }

    @Override
    public long getVersion() {
        return serialVersionUID;
    }

    @Override
    public String getName() {
        return "kNN";
    }

    /**
     * Initialization.
     *
     * @param k Integer that is the neighborhood size.
     * @param cmet CombinedMetric object for distance calculations.
     */
    public KNN(int k, CombinedMetric cmet) {
        setCombinedMetric(cmet);
        this.k = k;
    }

    /**
     * Initialization.
     *
     * @param dset DataSet object used for model training.
     * @param numClasses Integer that is the number of classes in the data.
     * @param cmet CombinedMetric object for distance calculations.
     * @param k Integer that is the neighborhood size.
     */
    public KNN(DataSet dset, int numClasses, CombinedMetric cmet, int k) {
        trainingData = dset;
        this.numClasses = numClasses;
        setCombinedMetric(cmet);
        this.k = k;
    }

    /**
     * @param numClasses Integer that is the number of classes in the data.
     */
    public void setNumClasses(int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * @return Integer that is the number of classes in the data.
     */
    public int getNumClasses() {
        return numClasses;
    }

    /**
     * @return Integer that is the neighborhood size used in calculations.
     */
    public int getK() {
        return k;
    }

    /**
     * @param k Integer that is the neighborhood size used in calculations.
     */
    public void setK(int k) {
        this.k = k;
    }
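    // ------------------------------------------------------------------
    // Usage sketch (illustrative, not part of the original API surface):
    // the typical train-then-classify flow, assuming a labeled DataSet
    // and a CombinedMetric instance are already available. The variable
    // names below are hypothetical.
    //
    //     DataSet dset = ...; // labeled training data, loaded elsewhere
    //     CombinedMetric cmet = ...; // e.g. a Euclidean-based metric
    //     KNN classifier = new KNN(dset, dset.countCategories(), cmet, 5);
    //     classifier.train();
    //     int predictedLabel = classifier.classify(someTestInstance);
    //     float[] classProbs =
    //             classifier.classifyProbabilistically(someTestInstance);
    // ------------------------------------------------------------------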
    /**
     * Initialization.
     *
     * @param categories Category[] representing the training data.
     * @param cmet CombinedMetric object for distance calculations.
     * @param k Integer that is the neighborhood size.
     */
    public KNN(Category[] categories, CombinedMetric cmet, int k) {
        int totalSize = 0;
        int indexFirstNonEmptyClass = -1;
        for (int cIndex = 0; cIndex < categories.length; cIndex++) {
            totalSize += categories[cIndex].size();
            if (indexFirstNonEmptyClass == -1
                    && categories[cIndex].size() > 0) {
                indexFirstNonEmptyClass = cIndex;
            }
        }
        // Instances are not embedded in the internal data context.
        trainingData = new DataSet();
        trainingData.fAttrNames = categories[indexFirstNonEmptyClass].
                getInstance(0).getEmbeddingDataset().fAttrNames;
        trainingData.iAttrNames = categories[indexFirstNonEmptyClass].
                getInstance(0).getEmbeddingDataset().iAttrNames;
        trainingData.sAttrNames = categories[indexFirstNonEmptyClass].
                getInstance(0).getEmbeddingDataset().sAttrNames;
        trainingData.data = new ArrayList<>(totalSize);
        for (int cIndex = 0; cIndex < categories.length; cIndex++) {
            for (int i = 0; i < categories[cIndex].size(); i++) {
                categories[cIndex].getInstance(i).setCategory(cIndex);
                trainingData.addDataInstance(categories[cIndex].getInstance(i));
            }
        }
        setCombinedMetric(cmet);
        this.k = k;
        numClasses = trainingData.countCategories();
    }

    @Override
    public void setClasses(Category[] categories) {
        int totalSize = 0;
        int indexFirstNonEmptyClass = -1;
        for (int cIndex = 0; cIndex < categories.length; cIndex++) {
            totalSize += categories[cIndex].size();
            if (indexFirstNonEmptyClass == -1
                    && categories[cIndex].size() > 0) {
                indexFirstNonEmptyClass = cIndex;
            }
        }
        // Instances are not embedded in the internal data context.
        trainingData = new DataSet();
        trainingData.fAttrNames = categories[indexFirstNonEmptyClass].
                getInstance(0).getEmbeddingDataset().fAttrNames;
        trainingData.iAttrNames = categories[indexFirstNonEmptyClass].
                getInstance(0).getEmbeddingDataset().iAttrNames;
        trainingData.sAttrNames = categories[indexFirstNonEmptyClass].
                getInstance(0).getEmbeddingDataset().sAttrNames;
        trainingData.data = new ArrayList<>(totalSize);
        for (int cIndex = 0; cIndex < categories.length; cIndex++) {
            for (int i = 0; i < categories[cIndex].size(); i++) {
                categories[cIndex].getInstance(i).setCategory(cIndex);
                trainingData.addDataInstance(categories[cIndex].getInstance(i));
            }
        }
        numClasses = trainingData.countCategories();
    }

    @Override
    public ValidateableInterface copyConfiguration() {
        return new KNN(trainingData, numClasses, getCombinedMetric(), k);
    }

    @Override
    public void findK(int kMin, int kMax) throws Exception {
        numClasses = trainingData.countCategories();
        NeighborSetFinder nsf = new NeighborSetFinder(trainingData,
                getCombinedMetric());
        nsf.calculateDistances();
        nsf.calculateNeighborSets(kMax);
        // The array that holds the accuracy for the entire range of tested
        // neighborhood sizes.
        float[] accuracyArray = new float[kMax - kMin + 1];
        // The current best achieved accuracy.
        float currMaxAcc = -1f;
        // The current optimal neighborhood size.
        int currMaxK = 0;
        int dataSize = trainingData.size();
        // Votes and decisions are updated incrementally, which reduces the
        // computational complexity.
        float[][] currVoteClassCounts = new float[dataSize][numClasses];
        int[] currPredictions = new int[dataSize];
        // The label of the current vote.
        int voteLabel;
        // The k-nearest neighbor sets on the training data.
        int[][] kneighbors = nsf.getKNeighbors();
        // Seed the vote counts with the first kMin - 1 neighbors, so that
        // the incremental updates below start from a consistent state when
        // kMin > 1.
        for (int i = 0; i < dataSize; i++) {
            for (int kIndex = 0; kIndex < kMin - 1; kIndex++) {
                voteLabel = trainingData.getLabelOf(kneighbors[i][kIndex]);
                currVoteClassCounts[i][voteLabel]++;
                if (currVoteClassCounts[i][voteLabel]
                        > currVoteClassCounts[i][currPredictions[i]]) {
                    currPredictions[i] = voteLabel;
                }
            }
        }
        // The current accuracy.
        float currAccuracy;
        for (int kInc = 0; kInc < accuracyArray.length; kInc++) {
            currAccuracy = 0;
            // Find the accuracy of the method on the training data for the
            // given k value.
            for (int i = 0; i < dataSize; i++) {
                voteLabel = trainingData.getLabelOf(
                        kneighbors[i][kMin + kInc - 1]);
                currVoteClassCounts[i][voteLabel]++;
                if (currPredictions[i] != voteLabel) {
                    // Check if the decision needs to be updated.
                    if (currVoteClassCounts[i][voteLabel]
                            > currVoteClassCounts[i][currPredictions[i]]) {
                        currPredictions[i] = voteLabel;
                    }
                }
                if (currPredictions[i] == trainingData.getLabelOf(i)) {
                    currAccuracy++;
                }
            }
            // Normalize the accuracy.
            currAccuracy /= (float) dataSize;
            accuracyArray[kInc] = currAccuracy;
            // Update the best parameter values.
            if (currMaxAcc < currAccuracy) {
                currMaxAcc = currAccuracy;
                currMaxK = kMin + kInc;
            }
        }
        // Set the optimal neighborhood size as the actual one.
        k = currMaxK;
    }
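    // ------------------------------------------------------------------
    // A minimal sketch of automatic neighborhood-size selection: passing
    // k <= 0 to the constructor makes train() call findK(1, 20), but a
    // range can also be searched explicitly. Names below are illustrative.
    //
    //     KNN classifier = new KNN(dset, numClasses, cmet, -1);
    //     classifier.findK(5, 50); // picks the k with the best training
    //                              // accuracy in [5, 50]
    //     System.out.println("Selected k: " + classifier.getK());
    // ------------------------------------------------------------------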
    @Override
    public void train() throws Exception {
        if (k <= 0) {
            // If an invalid neighborhood size was provided, automatically
            // search for the optimal one in the lower k-range.
            findK(1, 20);
        }
        if (trainingData != null) {
            numClasses = Math.max(numClasses, trainingData.countCategories());
            classPriors = trainingData.getClassPriors();
        }
    }

    @Override
    public int classify(DataInstance instance) throws Exception {
        float[] classProbs = classifyProbabilistically(instance);
        float maxProb = 0;
        int maxClass = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            if (classProbs[cIndex] > maxProb) {
                maxProb = classProbs[cIndex];
                maxClass = cIndex;
            }
        }
        return maxClass;
    }

    @Override
    public float[] classifyProbabilistically(DataInstance instance)
            throws Exception {
        CombinedMetric cmet = getCombinedMetric();
        // Calculate the kNN set.
        float[] kDistances = new float[k];
        Arrays.fill(kDistances, Float.MAX_VALUE);
        int[] kNeighbors = new int[k];
        float currDistance;
        int index;
        for (int i = 0; i < trainingData.size(); i++) {
            currDistance = cmet.dist(trainingData.data.get(i), instance);
            if (currDistance < kDistances[k - 1]) {
                // Insertion.
                index = k - 1;
                while (index > 0 && kDistances[index - 1] > currDistance) {
                    kDistances[index] = kDistances[index - 1];
                    kNeighbors[index] = kNeighbors[index - 1];
                    index--;
                }
                kDistances[index] = currDistance;
                kNeighbors[index] = i;
            }
        }
        // Perform the voting.
        float[] classProbEstimates = new float[numClasses];
        for (int kIndex = 0; kIndex < k; kIndex++) {
            classProbEstimates[trainingData.getLabelOf(kNeighbors[kIndex])]++;
        }
        // Normalize.
        float probTotal = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            probTotal += classProbEstimates[cIndex];
        }
        if (probTotal > 0) {
            for (int cIndex = 0; cIndex < numClasses; cIndex++) {
                classProbEstimates[cIndex] /= probTotal;
            }
        } else {
            classProbEstimates = Arrays.copyOf(classPriors, numClasses);
        }
        return classProbEstimates;
    }

    @Override
    public int classify(DataInstance instance, float[] distToTraining)
            throws Exception {
        float[] classProbs = classifyProbabilistically(instance,
                distToTraining);
        float maxProb = 0;
        int maxClassIndex = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            if (classProbs[cIndex] > maxProb) {
                maxProb = classProbs[cIndex];
                maxClassIndex = cIndex;
            }
        }
        return maxClassIndex;
    }
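    // ------------------------------------------------------------------
    // The kNN search in the methods above and below maintains the k
    // smallest distances seen so far in a sorted array and inserts each
    // closer candidate by shifting, an O(n * k) scan overall. A minimal
    // standalone sketch of the same insertion step (illustrative; n and
    // dist are hypothetical):
    //
    //     float[] best = new float[k];
    //     int[] bestIds = new int[k];
    //     Arrays.fill(best, Float.MAX_VALUE);
    //     for (int i = 0; i < n; i++) {
    //         float d = dist[i];
    //         if (d < best[k - 1]) {
    //             int pos = k - 1;
    //             while (pos > 0 && best[pos - 1] > d) {
    //                 best[pos] = best[pos - 1];
    //                 bestIds[pos] = bestIds[pos - 1];
    //                 pos--;
    //             }
    //             best[pos] = d;
    //             bestIds[pos] = i;
    //         }
    //     }
    // ------------------------------------------------------------------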
    @Override
    public float[] classifyProbabilistically(DataInstance instance,
            float[] distToTraining) throws Exception {
        // Calculate the kNN set.
        float[] kDistances = new float[k];
        Arrays.fill(kDistances, Float.MAX_VALUE);
        int[] kNeighbors = new int[k];
        int index;
        for (int i = 0; i < trainingData.size(); i++) {
            if (distToTraining[i] < kDistances[k - 1]) {
                // Insertion.
                index = k - 1;
                while (index > 0 && kDistances[index - 1] > distToTraining[i]) {
                    kDistances[index] = kDistances[index - 1];
                    kNeighbors[index] = kNeighbors[index - 1];
                    index--;
                }
                kDistances[index] = distToTraining[i];
                kNeighbors[index] = i;
            }
        }
        // Perform the voting.
        float[] classProbEstimates = new float[numClasses];
        for (int kIndex = 0; kIndex < k; kIndex++) {
            classProbEstimates[trainingData.getLabelOf(kNeighbors[kIndex])]++;
        }
        // Normalize.
        float probTotal = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            probTotal += classProbEstimates[cIndex];
        }
        if (probTotal > 0) {
            for (int cIndex = 0; cIndex < numClasses; cIndex++) {
                classProbEstimates[cIndex] /= probTotal;
            }
        } else {
            classProbEstimates = Arrays.copyOf(classPriors, numClasses);
        }
        return classProbEstimates;
    }

    /**
     * Classify the point of interest based on the kNN set and the distances to
     * the neighbor points.
     *
     * @param instance DataInstance object that is to be classified.
     * @param kDistances float[] representing the distances to the k-nearest
     * neighbors.
     * @param trNeighbors int[] representing the indexes of the kNN set.
     * @return Integer that is the predicted class affiliation of the point of
     * interest.
     * @throws Exception
     */
    public int classifyWithKDistAndNeighbors(DataInstance instance,
            float[] kDistances, int[] trNeighbors) throws Exception {
        float[] classProbEstimates = new float[numClasses];
        // Perform the voting.
        for (int kIndex = 0; kIndex < k; kIndex++) {
            classProbEstimates[trainingData.getLabelOf(trNeighbors[kIndex])]++;
        }
        // Normalize.
        float probTotal = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            probTotal += classProbEstimates[cIndex];
        }
        if (probTotal > 0) {
            for (int cIndex = 0; cIndex < numClasses; cIndex++) {
                classProbEstimates[cIndex] /= probTotal;
            }
        } else {
            classProbEstimates = Arrays.copyOf(classPriors, numClasses);
        }
        float maxProb = 0;
        int maxClassIndex = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            if (classProbEstimates[cIndex] > maxProb) {
                maxProb = classProbEstimates[cIndex];
                maxClassIndex = cIndex;
            }
        }
        return maxClassIndex;
    }

    /**
     * Classify the point of interest based on the kNN set and the distances to
     * the neighbor points.
     *
     * @param instance DataInstance object that is to be classified.
     * @param kDistances float[] representing the distances to the k-nearest
     * neighbors.
     * @param trNeighbors int[] representing the indexes of the kNN set.
     * @return float[] that is the predicted class distribution of the point of
     * interest.
     * @throws Exception
     */
    public float[] classifyProbabilisticallyWithKDistAndNeighbors(
            DataInstance instance, float[] kDistances, int[] trNeighbors)
            throws Exception {
        float[] classProbEstimates = new float[numClasses];
        // Perform the voting.
        for (int kIndex = 0; kIndex < k; kIndex++) {
            classProbEstimates[trainingData.getLabelOf(trNeighbors[kIndex])]++;
        }
        // Normalize.
        float probTotal = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            probTotal += classProbEstimates[cIndex];
        }
        if (probTotal > 0) {
            for (int cIndex = 0; cIndex < numClasses; cIndex++) {
                classProbEstimates[cIndex] /= probTotal;
            }
        } else {
            classProbEstimates = Arrays.copyOf(classPriors, numClasses);
        }
        return classProbEstimates;
    }
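    // ------------------------------------------------------------------
    // Voting note: the methods above normalize raw neighbor counts into a
    // probability distribution; when no votes are cast (probTotal == 0),
    // they fall back to the class priors estimated during train(). A
    // worked example, assuming 3 classes, k = 5 and neighbor labels
    // {0, 0, 2, 2, 2}:
    //
    //     votes      = [2, 0, 3]
    //     probTotal  = 5
    //     classProbs = [0.4, 0.0, 0.6] -> predicted class 2
    // ------------------------------------------------------------------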
    @Override
    public int classify(DataInstance instance, float[] distToTraining,
            int[] trNeighbors) throws Exception {
        float[] classProbs = classifyProbabilistically(instance,
                distToTraining, trNeighbors);
        float maxProb = 0;
        int maxClassIndex = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            if (classProbs[cIndex] > maxProb) {
                maxProb = classProbs[cIndex];
                maxClassIndex = cIndex;
            }
        }
        return maxClassIndex;
    }

    @Override
    public float[] classifyProbabilistically(DataInstance instance,
            float[] distToTraining, int[] trNeighbors) throws Exception {
        float[] classProbEstimates = new float[numClasses];
        for (int kIndex = 0; kIndex < k; kIndex++) {
            classProbEstimates[trainingData.getLabelOf(trNeighbors[kIndex])]++;
        }
        float probTotal = 0;
        for (int cIndex = 0; cIndex < numClasses; cIndex++) {
            probTotal += classProbEstimates[cIndex];
        }
        if (probTotal > 0) {
            for (int cIndex = 0; cIndex < numClasses; cIndex++) {
                classProbEstimates[cIndex] /= probTotal;
            }
        } else {
            classProbEstimates = Arrays.copyOf(classPriors, numClasses);
        }
        return classProbEstimates;
    }
}
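// ----------------------------------------------------------------------
// Evaluation-style usage sketch (illustrative): the three-argument
// classify/classifyProbabilistically overloads accept precomputed
// distances and neighbor indexes, so repeated queries against the same
// training set can skip the O(n) distance scan. How the precomputed
// arrays are obtained is left open here and depends on the caller.
//
//     float[] distToTraining = ...; // distances from the test point to
//                                   // each training instance
//     int[] trNeighbors = ...;      // indexes of its k nearest neighbors
//     int label = classifier.classify(testInstance, distToTraining,
//             trNeighbors);
// ----------------------------------------------------------------------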