package JKernelMachines.fr.lip6.classifier.transductive;

import JKernelMachines.fr.lip6.classifier.DoublePegasosSVM;
import JKernelMachines.fr.lip6.type.TrainingSample;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.apache.xpath.XPath;

/* loaded from: input_file:JKernelMachines/fr/lip6/classifier/transductive/S3VMLightPegasos.class */
/**
 * Transductive classifier that wraps {@link DoublePegasosSVM} and iteratively
 * guesses labels for unlabeled samples.
 *
 * <p>Training proceeds as follows:
 * <ol>
 *   <li>an SVM is trained on the labeled samples only;</li>
 *   <li>the {@code numplus} unlabeled samples with the highest SVM output are
 *       labeled {@code +1}, all others {@code -1};</li>
 *   <li>an SVM is trained on labeled + unlabeled samples; while a pair of
 *       unlabeled samples with opposite labels, both with positive error
 *       {@code 1 - y*f(x)}, and a summed error above 2 exists, their labels
 *       are swapped and the SVM is retrained;</li>
 *   <li>the two costs C- and C+ are doubled (capped at
 *       {@code 1 / (trainSize * lambda)}) and step 3 is repeated until both
 *       reach the cap.</li>
 * </ol>
 *
 * <p>NOTE(review): in this class the costs C- / C+ are only logged and used as
 * the stopping criterion of the outer loop; they are never handed to the
 * underlying SVM. Likewise {@code bias} is stored and exposed through its
 * accessors but not forwarded to the SVM here.
 */
public class S3VMLightPegasos implements TransductiveClassifier<double[]> {
    /** Starting value of the C- cost; C+ starts at this value scaled by the class ratio. */
    private static final double INITIAL_COST = 1.0E-5d;

    /** Labeled samples (copy of the list given to {@link #train(List, List)}). */
    ArrayList<TrainingSample<double[]>> train;
    /** Unlabeled samples, re-wrapped so that their guessed labels can be mutated freely. */
    ArrayList<TrainingSample<double[]>> test;
    /** Most recently trained SVM; used by {@link #valueOf(double[])}. */
    DoublePegasosSVM svm;
    /** Number of unlabeled samples that are initially assigned the label +1. */
    int numplus = 0;
    int T = 100000;
    int k = 10;
    double lambda = 0.001d;
    double t0 = 100.0d;
    boolean bias = true;
    private int verbosityLevel = 0;

    /**
     * Trains on labeled and unlabeled data.
     *
     * @param list  labeled samples; the list is copied, the samples are shared
     * @param list2 unlabeled samples; their labels are ignored and each sample
     *              is re-wrapped with label 0, so the caller's objects are not
     *              modified
     */
    @Override
    public void train(List<TrainingSample<double[]>> list, List<TrainingSample<double[]>> list2) {
        this.train = new ArrayList<>(list);
        this.test = new ArrayList<>(list2.size());
        for (TrainingSample<double[]> unlabeled : list2) {
            this.test.add(new TrainingSample<>(unlabeled.sample, 0));
        }
        train();
    }

    /** Runs the full procedure described in the class documentation. */
    private void train() {
        eprintln(2, "training on " + this.train.size() + " train data and " + this.test.size() + " test data");
        eprint(3, "first training ");
        retrain(this.train);
        eprintln(3, " done.");
        eprintln(3, "affecting 1 to the " + this.numplus + " highest output");
        assignInitialLabels();

        double maxCost = 1.0d / (this.train.size() * this.lambda);
        double costMinus = INITIAL_COST;
        double costPlus = initialPositiveCost(maxCost);
        while (costMinus < maxCost || costPlus < maxCost) {
            // Labels of the test samples are mutated in place, so this list
            // always reflects the current guesses.
            ArrayList<TrainingSample<double[]>> labeledAndGuessed =
                    new ArrayList<>(this.train.size() + this.test.size());
            labeledAndGuessed.addAll(this.train);
            labeledAndGuessed.addAll(this.test);

            eprint(3, "full training ");
            retrain(labeledAndGuessed);
            eprintln(3, "done.");
            while (swapOneMisplacedPair()) {
                eprintln(3, "re-training");
                retrain(labeledAndGuessed);
            }
            eprintln(3, "increasing C+ : " + costPlus + " and C- : " + costMinus);
            costMinus = Math.min(2.0d * costMinus, maxCost);
            costPlus = Math.min(2.0d * costPlus, maxCost);
        }
        eprintln(2, "training done");
    }

    /** Replaces {@link #svm} with a fresh SVM configured from this object's parameters and trains it. */
    private void retrain(ArrayList<TrainingSample<double[]>> data) {
        this.svm = new DoublePegasosSVM();
        this.svm.setLambda(this.lambda);
        this.svm.setK(this.k);
        this.svm.setT(this.T);
        this.svm.setT0(this.t0);
        this.svm.train(data);
    }

    /** Labels the {@code numplus} highest-scoring unlabeled samples +1 and the rest -1. */
    private void assignInitialLabels() {
        // Score every sample once instead of re-evaluating the SVM inside the comparator.
        double[] scores = new double[this.test.size()];
        for (int i = 0; i < scores.length; i++) {
            scores[i] = this.svm.valueOf(this.test.get(i).sample);
        }
        List<Integer> ranking = descendingOrder(scores);
        eprintln(4, "sorted size : " + ranking.size() + " test size : " + this.test.size());
        int rank = 0;
        for (int index : ranking) {
            this.test.get(index).label = rank < this.numplus ? 1 : -1;
            rank++;
        }
    }

    /**
     * Returns the starting C+ cost.
     *
     * <p>When one of the two guessed classes is empty ({@code numplus <= 0} or
     * {@code numplus >= test.size()}) the ratio is 0, negative, infinite or
     * NaN, and doubling it would never reach {@code maxCost}; in that case C+
     * is treated as already saturated so that the outer loop terminates.
     */
    private double initialPositiveCost(double maxCost) {
        int unlabeled = this.test.size();
        if (this.numplus <= 0 || this.numplus >= unlabeled) {
            return maxCost;
        }
        return (INITIAL_COST * this.numplus) / (unlabeled - this.numplus);
    }

    /**
     * Looks for the first pair of unlabeled samples (in decreasing error
     * order) whose labels should be exchanged, and exchanges them.
     *
     * @return {@code true} if a pair was swapped, {@code false} if none qualifies
     */
    private boolean swapOneMisplacedPair() {
        int count = this.test.size();
        // Errors are cached per index, so they do not depend on how
        // TrainingSample implements equals/hashCode.
        double[] errors = new double[count];
        for (int i = 0; i < count; i++) {
            TrainingSample<double[]> current = this.test.get(i);
            errors[i] = 1.0d - (current.label * this.svm.valueOf(current.sample));
        }
        eprintln(3, "Error cache done.");
        List<Integer> order = descendingOrder(errors);
        eprintln(3, "sorting done, checking couple");

        for (int a = 0; a < count; a++) {
            int first = order.get(a);
            if (errors[first] <= 0.0d) {
                // Errors are sorted in decreasing order: nothing further can qualify.
                break;
            }
            for (int b = a + 1; b < count; b++) {
                int second = order.get(b);
                if (errors[second] <= 0.0d) {
                    break;
                }
                if (examine(this.test.get(first), errors[first], this.test.get(second), errors[second])) {
                    eprintln(3, "couple found !");
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Swaps the labels of two samples when they carry opposite labels and
     * their summed error exceeds 2. Callers only pass samples whose individual
     * error is not {@code <= 0}.
     *
     * @return {@code true} if the labels were swapped
     */
    private boolean examine(TrainingSample<double[]> first, double firstError,
            TrainingSample<double[]> second, double secondError) {
        if (first.label * second.label > 0) {
            return false;
        }
        eprintln(4, "y1 : " + first.label + " err1 : " + firstError + " y2 : " + second.label + " err2 : " + secondError);
        if (firstError + secondError <= 2.0d) {
            return false;
        }
        int swapped = first.label;
        first.label = second.label;
        second.label = swapped;
        return true;
    }

    /**
     * Returns the indices of {@code keys} ordered by decreasing key.
     *
     * <p>Equal keys are ordered by decreasing index. This keeps the comparator
     * a genuine total order while reproducing the order the previous
     * TreeSet-based implementation produced for ties.
     */
    private static List<Integer> descendingOrder(final double[] keys) {
        List<Integer> order = new ArrayList<>(keys.length);
        for (int i = 0; i < keys.length; i++) {
            order.add(i);
        }
        Collections.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                int byKey = Double.compare(keys[right], keys[left]);
                return byKey != 0 ? byKey : right.compareTo(left);
            }
        });
        return order;
    }

    /**
     * Returns the output of the most recently trained SVM for the given sample.
     *
     * @param dArr sample to evaluate
     */
    @Override
    public double valueOf(double[] dArr) {
        return this.svm.valueOf(dArr);
    }

    /** @return the value passed to {@code DoublePegasosSVM.setT} on each training */
    public int getT() {
        return this.T;
    }

    public void setT(int i) {
        this.T = i;
    }

    /** @return the value passed to {@code DoublePegasosSVM.setK} on each training */
    public int getK() {
        return this.k;
    }

    public void setK(int i) {
        this.k = i;
    }

    /** @return the value passed to {@code DoublePegasosSVM.setLambda}; also bounds the cost loop */
    public double getLambda() {
        return this.lambda;
    }

    public void setLambda(double d) {
        this.lambda = d;
    }

    /** @return the value passed to {@code DoublePegasosSVM.setT0} on each training */
    public double getT0() {
        return this.t0;
    }

    public void setT0(double d) {
        this.t0 = d;
    }

    public boolean isBias() {
        return this.bias;
    }

    public void setBias(boolean z) {
        this.bias = z;
    }

    /** @return number of unlabeled samples initially labeled +1 */
    public int getNumplus() {
        return this.numplus;
    }

    public void setNumplus(int i) {
        this.numplus = i;
    }

    /** Sets the threshold used by {@link #eprint} and {@link #eprintln}; 0 silences all messages of level 1 and above. */
    public void setVerbosityLevel(int i) {
        this.verbosityLevel = i;
    }

    /** Prints {@code str} to standard error when the verbosity level is at least {@code i}. */
    public void eprint(int i, String str) {
        if (this.verbosityLevel >= i) {
            System.err.print(str);
        }
    }

    /** Prints {@code str} and a line break to standard error when the verbosity level is at least {@code i}. */
    public void eprintln(int i, String str) {
        if (this.verbosityLevel >= i) {
            System.err.println(str);
        }
    }
}
