package diva.sketch.classification;

import com.jrefinery.chart.ValueAxis;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/* loaded from: input_file:jsky-2.0/lib/diva.jar:diva/sketch/classification/KNNClassifier.class */
/**
 * A trainable classifier that keeps every training example it is given and,
 * for a query feature set, reports the (at most) {@code k} stored positive
 * examples that are closest to the query.
 *
 * <p>Closeness is the sum of absolute differences between corresponding
 * feature values (see {@link #compare(FeatureSet, FeatureSet)}). Negative
 * examples are stored by {@link #train(TrainingSet)} but are not consulted by
 * {@link #classify(FeatureSet)}.
 */
public class KNNClassifier implements TrainableClassifier {
    /** Number of neighbours reported when the caller does not choose one. */
    private static final int DEFAULT_K = 5;

    /** Maximum number of nearest examples reported by {@link #classify(FeatureSet)}. */
    private int _k = DEFAULT_K;

    /**
     * Examples accumulated by {@link #train(TrainingSet)}; {@code null} until
     * the first training call and again after {@link #clear()}.
     */
    private TrainingSet _set;

    // Sample feature vectors used by the demo in main().
    public static double[] f1 = {2.0d, 2.0d, 2.0d, 2.0d, 2.0d};
    public static double[] f2 = {3.0d, 3.0d, 3.0d, 3.0d, 3.0d};
    public static double[] f3 = {4.0d, 4.0d, 4.0d, 4.0d, 4.0d};
    public static double[] f4 = {1.0d, 1.0d, 1.0d, 1.0d, 1.0d};
    public static double[] f5 = {10.0d, 10.0d, 10.0d, 10.0d, 10.0d};
    public static double[] f6 = {20.0d, 25.0d, 40.0d, 45.0d, 100.0d};
    // NOTE(review): the reference to a charting-library constant here looks like a
    // decompiler constant substitution (presumably for 0.0, given f8/f9) - confirm
    // the value and replace with a literal to drop the dependency.
    public static double[] f7 = {ValueAxis.DEFAULT_MINIMUM_AXIS_VALUE, ValueAxis.DEFAULT_MINIMUM_AXIS_VALUE, ValueAxis.DEFAULT_MINIMUM_AXIS_VALUE, ValueAxis.DEFAULT_MINIMUM_AXIS_VALUE, ValueAxis.DEFAULT_MINIMUM_AXIS_VALUE};
    public static double[] f8 = {-1.0d, -1.0d, -1.0d, -1.0d, -1.0d};
    public static double[] f9 = {-2.0d, -2.0d, -2.0d, -2.0d, -2.0d};

    /**
     * A candidate neighbour: the type of a stored example together with its
     * distance from the feature set being classified.
     *
     * <p>The shape of this class (public, non-static, explicit outer-instance
     * constructor argument) is kept exactly as found so that any existing
     * user of it keeps compiling.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:jsky-2.0/lib/diva.jar:diva/sketch/classification/KNNClassifier$Pair.class */
    public class Pair {
        /** Type name of the stored example. */
        public String type;
        /** Distance between the stored example and the query. */
        public double distance;
        private final KNNClassifier this$0;

        public Pair(KNNClassifier kNNClassifier, String str, double d) {
            this.this$0 = kNNClassifier;
            this.type = str;
            this.distance = d;
        }
    }

    /** Creates a classifier that reports up to five nearest examples. */
    public KNNClassifier() {
    }

    /**
     * Creates a classifier that reports up to {@code i} nearest examples.
     *
     * @param i the maximum number of neighbours to report
     */
    public KNNClassifier(int i) {
        this._k = i;
    }

    /**
     * Changes the maximum number of neighbours reported by subsequent calls
     * to {@link #classify(FeatureSet)}.
     *
     * @param i the maximum number of neighbours to report
     */
    public void setK(int i) {
        this._k = i;
    }

    /**
     * Adds every positive and negative example of {@code trainingSet} to the
     * examples already held by this classifier.
     *
     * <p>The examples are copied into a set owned by this classifier. The
     * caller's set is never stored or modified, so training repeatedly (even
     * with the same {@code TrainingSet} instance) never adds to a collection
     * while it is being iterated and never changes the caller's data.
     *
     * @param trainingSet the examples to add; {@code null} adds nothing
     * @throws ClassifierException if an example cannot be added
     */
    @Override // diva.sketch.classification.TrainableClassifier
    public void train(TrainingSet trainingSet) throws ClassifierException {
        if (trainingSet == null) {
            return;
        }
        if (this._set == null) {
            this._set = new TrainingSet();
        }
        for (Iterator<?> typeIter = trainingSet.types(); typeIter.hasNext();) {
            String type = (String) typeIter.next();
            for (Iterator<?> pos = trainingSet.positiveExamples(type); pos.hasNext();) {
                this._set.addPositiveExample(type, (FeatureSet) pos.next());
            }
            for (Iterator<?> neg = trainingSet.negativeExamples(type); neg.hasNext();) {
                this._set.addNegativeExample(type, (FeatureSet) neg.next());
            }
        }
    }

    /**
     * Reports that {@link #train(TrainingSet)} may be called repeatedly to
     * add further examples.
     *
     * @return always {@code true}
     */
    @Override // diva.sketch.classification.TrainableClassifier
    public boolean isIncremental() {
        return true;
    }

    /** Discards every example added so far. */
    @Override // diva.sketch.classification.TrainableClassifier
    public void clear() {
        this._set = null;
    }

    /**
     * Finds the stored positive examples closest to {@code featureSet}.
     *
     * <p>The result holds at most {@code k} entries ordered from nearest to
     * farthest. Each entry pairs the example's type with the negated distance,
     * so a larger value means a closer match. Each stored example contributes
     * at most one entry. If no examples are held (before training or after
     * {@link #clear()}) the result is empty.
     *
     * @param featureSet the features to classify
     * @return the types and negated distances of the nearest stored examples
     * @throws ClassifierException declared by the classifier interface
     */
    @Override // diva.sketch.classification.Classifier
    public Classification classify(FeatureSet featureSet) throws ClassifierException {
        List<Pair> nearest = new ArrayList<Pair>();
        if (this._set != null) {
            for (Iterator<?> typeIter = this._set.types(); typeIter.hasNext();) {
                String type = (String) typeIter.next();
                for (Iterator<?> pos = this._set.positiveExamples(type); pos.hasNext();) {
                    double distance = compare((FeatureSet) pos.next(), featureSet);
                    insertNeighbor(nearest, type, distance);
                }
            }
        }
        String[] resultTypes = new String[nearest.size()];
        double[] resultValues = new double[nearest.size()];
        for (int n = 0; n < resultTypes.length; n++) {
            Pair pair = nearest.get(n);
            resultTypes[n] = pair.type;
            resultValues[n] = -pair.distance;
        }
        return new Classification(resultTypes, resultValues);
    }

    /**
     * Inserts one candidate into {@code nearest}, which is kept sorted by
     * ascending distance and capped at {@code k} entries.
     */
    private void insertNeighbor(List<Pair> nearest, String type, double distance) {
        // Find the first entry that is strictly farther away; candidates that tie
        // with an existing entry are placed after it, so earlier examples win ties.
        int index = 0;
        while (index < nearest.size() && !(distance < nearest.get(index).distance)) {
            index++;
        }
        if (index < nearest.size()) {
            // Insert exactly once, then drop the farthest entry if over capacity.
            nearest.add(index, new Pair(this, type, distance));
            if (nearest.size() > this._k) {
                nearest.remove(nearest.size() - 1);
            }
        } else if (nearest.size() < this._k) {
            // Farther than everything kept so far, but there is still room.
            nearest.add(new Pair(this, type, distance));
        }
    }

    /**
     * Computes the distance between two feature sets as the sum of the
     * absolute differences of their corresponding feature values.
     *
     * <p>The sum is accumulated as a {@code double}, so fractional differences
     * contribute to the result instead of being truncated away.
     *
     * @return the distance, or {@code Double.MAX_VALUE} if the two sets do
     *         not report the same feature count
     */
    private static double compare(FeatureSet featureSet, FeatureSet featureSet2) {
        if (featureSet.getFeatureCount() != featureSet2.getFeatureCount()) {
            return Double.MAX_VALUE;
        }
        double[] features = featureSet.getFeatures();
        double[] features2 = featureSet2.getFeatures();
        double sum = 0.0d;
        for (int n = 0; n < features.length; n++) {
            sum += Math.abs(features[n] - features2[n]);
        }
        return sum;
    }

    /**
     * Small demo: trains a classifier with the sample vectors {@code f1} to
     * {@code f9} under three types and prints the classification of an
     * all-ones feature vector.
     */
    public static void main(String[] strArr) {
        try {
            KNNClassifier kNNClassifier = new KNNClassifier();
            TrainingSet trainingSet = new TrainingSet();
            trainingSet.addPositiveExample("t1", new FeatureSet(f1));
            trainingSet.addPositiveExample("t1", new FeatureSet(f2));
            trainingSet.addPositiveExample("t1", new FeatureSet(f3));
            trainingSet.addPositiveExample("t2", new FeatureSet(f4));
            trainingSet.addPositiveExample("t2", new FeatureSet(f5));
            trainingSet.addPositiveExample("t2", new FeatureSet(f6));
            trainingSet.addPositiveExample("t3", new FeatureSet(f7));
            trainingSet.addPositiveExample("t3", new FeatureSet(f8));
            trainingSet.addPositiveExample("t3", new FeatureSet(f9));
            kNNClassifier.train(trainingSet);
            System.out.println(kNNClassifier.classify(new FeatureSet(new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d})));
        } catch (ClassifierException e) {
            e.printStackTrace();
        }
    }
}
