package JKernelMachines.fr.lip6.classifier;

import JKernelMachines.fr.lip6.kernel.Kernel;
import JKernelMachines.fr.lip6.kernel.SimpleCacheKernel;
import JKernelMachines.fr.lip6.kernel.adaptative.ThreadedProductKernel;
import JKernelMachines.fr.lip6.type.TrainingSample;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.List;
import org.apache.xpath.XPath;

/* loaded from: input_file:JKernelMachines/fr/lip6/classifier/GradPKL.class */
/**
 * Multiple-kernel classifier that combines its kernels in a weighted {@link ThreadedProductKernel}
 * around an {@link SMOSVM} and learns the kernel weights iteratively.
 *
 * <p>Each iteration retrains the SVM, computes a per-kernel gradient direction, applies the
 * multiplicative update {@code w_k <- w_k * (1 + grad_k)}, drops weights below
 * {@code num_cleaning}, and rescales the weights to unit {@code p_norm}-norm. Iteration stops once
 * {@code 1 - newObjective / oldObjective} falls below {@code stopGap}.
 *
 * <p>NOTE(review): the gradient uses {@code log(K_k) * K}, which suggests the product kernel raises
 * each kernel to its weight. Confirm this against {@code ThreadedProductKernel}.
 */
public class GradPKL<T> implements Classifier<T> {
    SMOSVM<T> svm;
    // Training stops when the relative objective improvement drops below this value.
    double stopGap = 0.001d;
    double eps_regul = 0.001d;
    // Updated kernel weights below this threshold are set to exactly zero.
    double num_cleaning = 1.0E-7d;
    // Exponent of the norm used to normalize the kernel weights.
    double p_norm = 1.0d;
    double C = 100000.0d;
    private int VERBOSITY_LEVEL = 0;
    ArrayList<Kernel<T>> listOfKernels = new ArrayList<>();
    ArrayList<Double> listOfKernelWeights = new ArrayList<>();
    ArrayList<TrainingSample<T>> listOfExamples = new ArrayList<>();
    ArrayList<Double> listOfExampleWeights = new ArrayList<>();

    /**
     * Registers a kernel to be combined; its weight starts at 1.0 until {@link #train(List)} runs.
     *
     * @param kernel the kernel to add
     */
    public void addKernel(Kernel<T> kernel) {
        this.listOfKernels.add(kernel);
        this.listOfKernelWeights.add(Double.valueOf(1.0d));
    }

    /**
     * Online (single-sample) training is not supported; this method intentionally does nothing.
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public void train(TrainingSample<T> trainingSample) {
    }

    /**
     * Learns the kernel weights and the SVM on {@code list}.
     *
     * <p>On return, {@link #getKernelWeights()} holds the learned weights and
     * {@link #getExampleWeights()} holds the SVM alphas for the examples in {@code list}.
     *
     * @param list the training samples
     * @throws IllegalStateException if a weight-update step produces an invalid objective ratio
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        long startMillis = System.currentTimeMillis();
        eprintln(2, "training on " + this.listOfKernels.size() + " kernels and " + list.size() + " examples");

        // Cache every kernel on the training set and start from uniform, p-normalized weights.
        ArrayList<SimpleCacheKernel<T>> cachedKernels = new ArrayList<>();
        ArrayList<Double> weights = new ArrayList<>();
        double initialWeight = Math.pow(1.0d / this.listOfKernels.size(), 1.0d / this.p_norm);
        for (Kernel<T> kernel : this.listOfKernels) {
            SimpleCacheKernel<T> cached = new SimpleCacheKernel<>(kernel, list);
            cached.setName(kernel.toString());
            // NOTE(review): this in-place scaling only affects later computations if getKernelMatrix
            // returns the cache's internal array rather than a copy -- confirm in SimpleCacheKernel.
            normalizeByTrace(cached.getKernelMatrix(list));
            cachedKernels.add(cached);
            weights.add(Double.valueOf(initialWeight));
            eprintln(3, "kernel : " + cached + " weight : " + weights.get(weights.size() - 1));
        }

        this.svm = new SMOSVM<>(buildProductKernel(cachedKernels, weights));
        this.svm.setC(this.C);
        this.svm.setVerbosityLevel(this.VERBOSITY_LEVEL);
        this.svm.train(list);

        // NOTE(review): there is no iteration cap; termination relies on the gap reaching stopGap.
        double gap;
        do {
            eprintln(3, "weights : " + weights);
            this.svm.setKernel(buildProductKernel(cachedKernels, weights));
            this.svm.retrain();
            double ratio = performPKLStep(computeSumAlpha(), computeGradBeta(cachedKernels, weights, list),
                    cachedKernels, weights, list);
            // A negative ratio signals a failed normalization. NaN also means the step is unusable,
            // and it would otherwise silently end the loop with NaN weights.
            if (Double.isNaN(ratio) || ratio < 0.0d) {
                eprintln(1, "Error, performPKLStep returned wrong value");
                throw new IllegalStateException("performPKLStep returned an invalid objective ratio: " + ratio);
            }
            gap = 1.0d - ratio;
            double normSum = 0.0d;
            for (double w : weights) {
                normSum += Math.pow(w, this.p_norm);
            }
            eprintln(1, "objective_gap : " + gap + " norm : " + Math.pow(normSum, (-1.0d) / this.p_norm));
        } while (gap >= this.stopGap);

        this.listOfKernelWeights.clear();
        this.listOfKernelWeights.addAll(weights);

        // Final SVM built on the original (uncached) kernels so that it can evaluate unseen samples.
        // NOTE(review): these kernels are not trace-normalized, unlike the ones the weights were
        // learned on -- confirm this is intended.
        this.svm.setKernel(buildProductKernel(this.listOfKernels, this.listOfKernelWeights));
        this.svm.retrain();

        // Reset both example lists so they stay aligned across repeated train() calls.
        this.listOfExamples.clear();
        this.listOfExamples.addAll(list);
        this.listOfExampleWeights.clear();
        for (double alpha : this.svm.getAlphas()) {
            this.listOfExampleWeights.add(Double.valueOf(alpha));
        }
        eprintln(1, "MKL trained in " + (System.currentTimeMillis() - startMillis) + " milis.");
    }

    /**
     * Scales a kernel matrix in place so that its trace equals its dimension.
     *
     * <p>The scaled upper triangle is mirrored into the lower triangle. Matrices whose trace is
     * not positive (or is NaN) are left untouched, because dividing by that trace would yield
     * NaN or Infinity entries.
     */
    private static void normalizeByTrace(double[][] kernelMatrix) {
        int n = kernelMatrix.length;
        double trace = 0.0d;
        for (int i = 0; i < n; i++) {
            trace += kernelMatrix[i][i];
        }
        if (Double.isNaN(trace) || trace <= 0.0d) {
            return;
        }
        double scale = n / trace;
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                kernelMatrix[i][j] = kernelMatrix[i][j] * scale;
                kernelMatrix[j][i] = kernelMatrix[i][j];
            }
        }
    }

    /**
     * Builds a product kernel that adds {@code kernels.get(i)} with weight {@code weights.get(i)}.
     */
    private ThreadedProductKernel buildProductKernel(List<? extends Kernel<T>> kernels, List<Double> weights) {
        ThreadedProductKernel product = new ThreadedProductKernel();
        for (int i = 0; i < kernels.size(); i++) {
            product.addKernel(kernels.get(i), weights.get(i).doubleValue());
        }
        return product;
    }

    /**
     * Computes {@code sumAlpha - 0.5 * sum_{i,j} y_i y_j a_i a_j K_ij}.
     *
     * <p>The sum runs over the upper triangle: the diagonal has coefficient 0.5 and off-diagonal
     * entries have coefficient 1. Zero kernel entries are skipped.
     */
    private double computeDualObjective(double sumAlpha, double[][] kernelMatrix, double[] alphas,
            List<TrainingSample<T>> list) {
        double objective = sumAlpha;
        for (int i = 0; i < kernelMatrix.length; i++) {
            int yi = list.get(i).label;
            for (int j = i; j < kernelMatrix.length; j++) {
                if (kernelMatrix[i][j] != 0.0d) {
                    objective += (i == j ? -0.5d : -1.0d) * yi * list.get(j).label * alphas[i] * alphas[j] * kernelMatrix[i][j];
                }
            }
        }
        return objective;
    }

    /**
     * Performs one multiplicative weight update, then retrains the SVM on the new weights.
     *
     * <p>{@code weights} is updated in place to the new normalized values, except when
     * normalization fails.
     *
     * @param sumAlpha sum of |alpha| for the SVM state before the update
     * @param gradBeta per-kernel gradient direction
     * @param kernels  the cached kernels
     * @param weights  current kernel weights
     * @param list     training samples
     * @return {@code newObjective / oldObjective}, or -1 if the weight normalization is invalid
     */
    private double performPKLStep(double sumAlpha, double[] gradBeta, ArrayList<SimpleCacheKernel<T>> kernels,
            ArrayList<Double> weights, List<TrainingSample<T>> list) {
        eprint(2, ".");
        double[][] oldMatrix = buildProductKernel(kernels, weights).getKernelMatrix(list);
        double oldObjective = computeDualObjective(sumAlpha, oldMatrix, this.svm.getAlphas(), list);
        eprintln(3, "old weights : " + weights);
        eprintln(3, "oldObjective : " + oldObjective + " sumAlpha : " + sumAlpha);

        // Multiplicative step w_k * (1 + grad_k), dropping weights that become negligible.
        double[] newWeights = new double[gradBeta.length];
        for (int k = 0; k < gradBeta.length; k++) {
            double w = weights.get(k).doubleValue();
            newWeights[k] = w + (gradBeta[k] * w);
            if (newWeights[k] < this.num_cleaning) {
                newWeights[k] = 0.0d;
            }
        }

        double powSum = 0.0d;
        for (double w : newWeights) {
            powSum += Math.pow(w, this.p_norm);
        }
        double invNorm = Math.pow(powSum, (-1.0d) / this.p_norm);
        // pow() never goes negative here. The failure modes are all weights cleaned to zero
        // (invNorm = Infinity, which would turn the weights into NaN) and NaN input.
        if (Double.isNaN(invNorm) || Double.isInfinite(invNorm) || invNorm <= 0.0d) {
            eprintln(1, "Error normalization, invalid norm : " + invNorm);
            return -1.0d;
        }
        for (int k = 0; k < newWeights.length; k++) {
            weights.set(k, Double.valueOf(newWeights[k] * invNorm));
        }

        ThreadedProductKernel newKernel = buildProductKernel(kernels, weights);
        double[][] newMatrix = newKernel.getKernelMatrix(list);
        this.svm.setKernel(newKernel);
        this.svm.retrain();
        double newObjective = computeDualObjective(computeSumAlpha(), newMatrix, this.svm.getAlphas(), list);
        eprintln(3, "new weights : " + weights);
        eprintln(3, "new objective : " + newObjective + " sumAlpha : " + sumAlpha);
        return newObjective / oldObjective;
    }

    /**
     * Computes, for each cached kernel k,
     * {@code sum_{i<=j} c_ij y_i y_j a_i a_j log(K_k[i][j]) K[i][j]}.
     *
     * <p>Here K is the current product-kernel matrix, and {@code c_ij} is 0.5 on the diagonal and
     * 1 elsewhere. Entries where {@code K_k} is zero are skipped.
     */
    private double[] computeGradBeta(ArrayList<SimpleCacheKernel<T>> arrayList, ArrayList<Double> arrayList2,
            List<TrainingSample<T>> list) {
        double[] grad = new double[arrayList.size()];
        double[][] productMatrix = buildProductKernel(arrayList, arrayList2).getKernelMatrix(list);
        // The SVM is not modified inside this method, so the alphas are fetched once.
        double[] alphas = this.svm.getAlphas();
        for (int k = 0; k < arrayList.size(); k++) {
            double[][] kernelMatrix = arrayList.get(k).getKernelMatrix(list);
            for (int i = 0; i < kernelMatrix.length; i++) {
                int yi = list.get(i).label;
                for (int j = i; j < kernelMatrix.length; j++) {
                    if (kernelMatrix[i][j] != 0.0d) {
                        grad[k] = grad[k] + ((i == j ? 0.5d : 1.0d) * yi * list.get(j).label * alphas[i] * alphas[j] * Math.log(kernelMatrix[i][j]) * productMatrix[i][j]);
                    }
                }
            }
        }
        eprintln(3, "gradDir : " + Arrays.toString(grad));
        return grad;
    }

    /** Returns the sum of absolute values of the current SVM alphas. */
    private double computeSumAlpha() {
        double sum = 0.0d;
        for (double alpha : this.svm.getAlphas()) {
            sum += Math.abs(alpha);
        }
        return sum;
    }

    /**
     * Evaluates the trained SVM on {@code t}. {@link #train(List)} must have been called first.
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public double valueOf(T t) {
        return this.svm.valueOf(t);
    }

    public void setVerbosityLevel(int i) {
        this.VERBOSITY_LEVEL = i;
    }

    /** Prints {@code str} to stderr when the verbosity level is at least {@code i}. */
    public void eprint(int i, String str) {
        if (this.VERBOSITY_LEVEL >= i) {
            System.err.print(str);
        }
    }

    /** Prints {@code str} and a newline to stderr when the verbosity level is at least {@code i}. */
    public void eprintln(int i, String str) {
        if (this.VERBOSITY_LEVEL >= i) {
            System.err.println(str);
        }
    }

    public double getC() {
        return this.C;
    }

    public void setC(double d) {
        this.C = d;
    }

    /** Sets the exponent p of the norm used to normalize the kernel weights. */
    public void setMKLNorm(double d) {
        this.p_norm = d;
    }

    /** Sets the relative objective improvement below which training stops. */
    public void setStopGap(double d) {
        this.stopGap = d;
    }

    /** Returns the SVM alphas from the last training run (the live internal list). */
    public ArrayList<Double> getExampleWeights() {
        return this.listOfExampleWeights;
    }

    /** Returns the kernel weights, index-aligned with the added kernels (the live internal list). */
    public ArrayList<Double> getKernelWeights() {
        return this.listOfKernelWeights;
    }

    /** Returns a fresh map from each added kernel to its current weight. */
    public Hashtable<Kernel<T>, Double> getWeights() {
        Hashtable<Kernel<T>, Double> hashtable = new Hashtable<>();
        for (int i = 0; i < this.listOfKernels.size(); i++) {
            hashtable.put(this.listOfKernels.get(i), this.listOfKernelWeights.get(i));
        }
        return hashtable;
    }
}
