package JKernelMachines.fr.lip6.classifier;

import JKernelMachines.fr.lip6.kernel.typed.DoubleLinear;
import JKernelMachines.fr.lip6.type.TrainingSample;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.xpath.XPath;

/* loaded from: input_file:JKernelMachines/fr/lip6/classifier/DoublePegasosSVM.class */
/**
 * Linear SVM on dense {@code double[]} vectors, trained with a mini-batch Pegasos-style
 * stochastic sub-gradient solver.
 *
 * <p>The decision value of a vector {@code x} is {@code kernel(w, x) - b}, where the inner
 * product is evaluated through the {@link DoubleLinear} kernel. At iteration {@code t} (0-based)
 * the solver:
 * <ol>
 *   <li>draws {@code k} distinct training samples uniformly at random,</li>
 *   <li>discards those whose margin {@code label * (kernel(w, x) - b)} is greater than 1,</li>
 *   <li>takes a step of size {@code eta = 1 / (lambda * (t + t0))}:
 *       {@code w <- (1 - eta*lambda) w + (eta/k) * sum(label * x)} over the remaining samples
 *       and, when the bias is enabled,
 *       {@code b <- (1 - eta*lambda) b - (eta/k) * sum(label)},</li>
 *   <li>rescales {@code w} (and {@code b}) by {@code min(1, (1/sqrt(lambda)) / ||w||)}.</li>
 * </ol>
 *
 * <p>Labels are presumably +1 / -1; this class does not check them. No synchronization is
 * performed: training overwrites the model fields in place.
 */
public class DoublePegasosSVM implements Classifier<double[]>, Serializable {
    private static final long serialVersionUID = 5289136605543751554L;

    // NOTE: field names and types are part of the serialized form; do not rename them.

    /** Training set of the last {@code train} call (an owned copy). */
    private List<TrainingSample<double[]>> tList;
    /** Weight vector; {@code null} until trained or set via {@link #setW(double[])}. */
    private double[] w;
    /** Kernel used for every inner product {@code <w, x>} and for {@code ||w||^2}. */
    private DoubleLinear kernel = new DoubleLinear();
    /** Bias; the decision value is {@code kernel(w, x) - b}. */
    private double b = 0.0d;
    /** Number of stochastic iterations. */
    int T = 100000;
    /** Configured mini-batch size (clamped to the training-set size per run). */
    int k = 10;
    /** Regularization weight; overwritten at training time when {@link #setC(double)} was used. */
    double lambda = 0.001d;
    /** Offset in the step size {@code 1 / (lambda * (t + t0))}. */
    double t0 = 100.0d;
    /** Whether the bias term {@code b} is learned (otherwise it is kept at 0). */
    boolean bias = true;
    /** SVM-style C; only used when {@link #hasC} is true. */
    double C = 1.0d;
    /** True once {@link #setC(double)} has been called. */
    boolean hasC = false;
    /** Messages with a level at or below this value are printed to stderr. */
    private int VERBOSITY_LEVEL = 0;

    /**
     * Trains the model from scratch on a copy of {@code list}.
     *
     * <p>The weights and bias are reset to zero before training. If {@link #setC(double)} has
     * been called, {@code lambda} is replaced by {@code 1 / (C * list.size())}.
     *
     * @param list training samples, all with vectors at least as long as the first one's
     * @throws IllegalArgumentException if {@code list} is null or empty
     * @throws IllegalStateException if the configured mini-batch size {@code k} is below 1
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public void train(List<TrainingSample<double[]>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("cannot train on a null or empty sample list");
        }
        if (this.k < 1) {
            throw new IllegalStateException("mini-batch size k must be >= 1, got " + this.k);
        }
        Random random = new Random(System.currentTimeMillis());

        // Own copy: guarantees O(1) random access (the caller may pass a LinkedList) and keeps
        // train(TrainingSample) from appending to the caller's list.
        this.tList = new ArrayList<TrainingSample<double[]>>(list);
        int n = this.tList.size();
        // Effective batch size for this run only. The configured k is left untouched so that a
        // later call on a larger set (e.g. incremental training) is not stuck with a clamped k.
        int batch = Math.min(this.k, n);
        int dim = this.tList.get(0).sample.length;

        this.w = new double[dim];
        this.b = 0.0d;
        if (this.hasC) {
            this.lambda = 1.0d / (this.C * n);
        }

        eprintln(1, "begin training");
        long start = System.currentTimeMillis();
        for (int t = 0; t < this.T; t++) {
            List<Integer> active = sampleDistinctIndices(random, n, batch);
            // Keep only samples with margin <= 1; only they contribute a hinge sub-gradient.
            // w and b are not modified during this pass, so all margins use the same model.
            for (Iterator<Integer> it = active.iterator(); it.hasNext(); ) {
                TrainingSample<double[]> s = this.tList.get(it.next().intValue());
                if ((this.kernel.valueOf(this.w, s.sample) - this.b) * s.label > 1.0d) {
                    it.remove();
                }
            }

            double eta = 1.0d / (this.lambda * (t + this.t0));
            double shrink = 1.0d - eta * this.lambda;
            // Divide by the drawn batch size, not the number of violators (mini-batch average).
            double step = eta / batch;

            double[] next = new double[dim];
            for (int i = 0; i < dim; i++) {
                next[i] = shrink * this.w[i];
            }
            double labelSum = 0.0d;
            for (Integer idx : active) {
                TrainingSample<double[]> s = this.tList.get(idx.intValue());
                double[] x = s.sample;
                for (int i = 0; i < dim; i++) {
                    // Zero entries would add nothing; skipping them is cheaper on sparse data.
                    if (x[i] != 0.0d) {
                        next[i] += step * s.label * x[i];
                    }
                }
                labelSum += s.label;
            }

            // Rescale so that ||w|| <= 1/sqrt(lambda). A zero vector gives an infinite ratio,
            // which the clamp turns into 1.
            double scale = (1.0d / Math.sqrt(this.lambda)) / Math.sqrt(this.kernel.valueOf(next, next));
            if (scale > 1.0d) {
                scale = 1.0d;
            }
            for (int i = 0; i < dim; i++) {
                next[i] *= scale;
            }
            this.w = next;
            this.b = this.bias ? scale * (shrink * this.b - step * labelSum) : 0.0d;

            // Guarded: formatting w every iteration is O(dim) work even when nothing is printed.
            if (this.VERBOSITY_LEVEL >= 4) {
                eprintln(4, "w : " + Arrays.toString(this.w) + " b : " + this.b);
            }
            if (this.T > 20 && t % (this.T / 20) == 0) {
                eprint(2, ".");
            }
        }
        eprintln(2, "");
        eprintln(1, "done in " + (System.currentTimeMillis() - start) + " ms");
        if (this.VERBOSITY_LEVEL >= 3) {
            eprintln(3, "w : " + Arrays.toString(this.w) + " b : " + this.b);
        }
    }

    /**
     * Draws {@code count} distinct indices uniformly from {@code [0, n)} by rejection sampling,
     * in draw order. Requires {@code count <= n}.
     */
    private static List<Integer> sampleDistinctIndices(Random random, int n, int count) {
        List<Integer> picked = new ArrayList<Integer>(count);
        while (picked.size() < count) {
            Integer candidate = Integer.valueOf(random.nextInt(n));
            if (!picked.contains(candidate)) {
                picked.add(candidate);
            }
        }
        return picked;
    }

    /**
     * Appends one sample to the stored training set and retrains from scratch on the whole set.
     *
     * @param trainingSample sample to add
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public void train(TrainingSample<double[]> trainingSample) {
        if (this.tList == null) {
            this.tList = new ArrayList<TrainingSample<double[]>>();
        }
        this.tList.add(trainingSample);
        train(this.tList);
    }

    /**
     * Returns the decision value {@code kernel(w, x) - b}.
     *
     * @param sample vector to score
     * @return signed decision value
     * @throws IllegalStateException if no weight vector is available yet
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public double valueOf(double[] sample) {
        if (this.w == null) {
            throw new IllegalStateException("classifier has not been trained");
        }
        return this.kernel.valueOf(this.w, sample) - this.b;
    }

    /** @return the number of stochastic iterations */
    public int getT() {
        return this.T;
    }

    /** @param iterations number of stochastic iterations performed by {@code train} */
    public void setT(int iterations) {
        this.T = iterations;
    }

    /** @return the configured mini-batch size */
    public int getK() {
        return this.k;
    }

    /**
     * @param batchSize mini-batch size; values larger than the training set are clamped for the
     *     duration of a {@code train} call, values below 1 make {@code train} throw
     */
    public void setK(int batchSize) {
        this.k = batchSize;
    }

    /** @return the regularization weight lambda */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * @param lambda regularization weight; ignored (overwritten at training time) once
     *     {@link #setC(double)} has been called
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** @return the weight vector (the internal array, not a copy), or null if untrained */
    public double[] getW() {
        return this.w;
    }

    /** @param weights weight vector to use; stored by reference */
    public void setW(double[] weights) {
        this.w = weights;
    }

    /** @return the bias {@code b} of the decision value {@code kernel(w, x) - b} */
    public double getB() {
        return this.b;
    }

    /** @param bias value of {@code b} */
    public void setB(double bias) {
        this.b = bias;
    }

    /** @return whether the bias term is learned during training */
    public boolean isBias() {
        return this.bias;
    }

    /** @param useBias whether to learn the bias term (if false, {@code b} is kept at 0) */
    public void setBias(boolean useBias) {
        this.bias = useBias;
    }

    /** @param level messages with a level at or below this value are printed to stderr */
    public void setVerbosityLevel(int level) {
        this.VERBOSITY_LEVEL = level;
    }

    /** Prints {@code message} to stderr without a newline if {@code level <= verbosity}. */
    public void eprint(int level, String message) {
        if (this.VERBOSITY_LEVEL >= level) {
            System.err.print(message);
        }
    }

    /** Prints {@code message} to stderr followed by a newline if {@code level <= verbosity}. */
    public void eprintln(int level, String message) {
        if (this.VERBOSITY_LEVEL >= level) {
            System.err.println(message);
        }
    }

    /** @return the step-size offset t0 */
    public double getT0() {
        return this.t0;
    }

    /** @param t0 step-size offset in {@code 1 / (lambda * (t + t0))} */
    public void setT0(double t0) {
        this.t0 = t0;
    }

    /**
     * Switches to C-parameterization: every subsequent {@code train} call sets
     * {@code lambda = 1 / (C * trainingSetSize)}.
     *
     * @param c the C constant
     */
    public void setC(double c) {
        this.hasC = true;
        this.C = c;
    }
}
