package JKernelMachines.fr.lip6.classifier;

import JKernelMachines.fr.lip6.kernel.Kernel;
import JKernelMachines.fr.lip6.kernel.SimpleCacheKernel;
import JKernelMachines.fr.lip6.kernel.adaptative.ThreadedProductKernel;
import JKernelMachines.fr.lip6.threading.ThreadedMatrixOperator;
import JKernelMachines.fr.lip6.type.TrainingSample;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.List;
import org.apache.xpath.XPath;

/* loaded from: input_file:JKernelMachines/fr/lip6/classifier/SOGradPKL.class */
/**
 * Multiple kernel learning (MKL) classifier that combines several kernels into a
 * {@link ThreadedProductKernel} and learns one weight per kernel.
 *
 * <p>Training alternates between solving an {@link SMOSVM} on the current combined kernel and
 * updating the kernel weights with a step that uses, for each weight, a first and a second
 * derivative of the objective (see {@code performPKLStep}). Those derivatives are built from
 * {@code log(k_m(x_i, x_j))}, which presumes the product kernel computes
 * {@code prod_m k_m^{beta_m}} — TODO confirm against {@code ThreadedProductKernel}.
 *
 * <p>Weights start at {@code (1/M)^(1/p)} for {@code M} kernels (a vector of unit {@code p}-norm)
 * and are renormalised to unit {@code p}-norm after every update, where {@code p} is set by
 * {@link #setMKLNorm(double)}.
 *
 * <p>Instances hold mutable training state and use no synchronization.
 *
 * @param <T> type of the input samples
 */
public class SOGradPKL<T> implements Classifier<T> {
    /** SVM trained on the current combined kernel; {@code null} until a {@code train} call. */
    SMOSVM<T> svm;
    /** Step size of the current weight update: reset to 1 per outer step, divided by 8 on rejection. */
    double d_lambda;
    /** Objective value reached by the last completed outer step. */
    double oldObjective;
    /** Outer iterations stop once {@code 1 - newObjective / oldObjective} falls below this value. */
    double stopGap = 1.0E-5d;
    /**
     * Numerical tolerance: candidate weights and second derivatives below it are set to 0, it is
     * the slack used when comparing objectives, and it bounds the smallest step size tried.
     */
    double num_cleaning = 1.0E-8d;
    /** The {@code p} of the {@code p}-norm constraint on the kernel weights. */
    double p_norm = 1.0d;
    /** SVM regularisation parameter forwarded to the underlying {@link SMOSVM}. */
    double C = 1.0d;
    /** When true (and caching is on), cached training kernel matrices are scaled to trace = number of examples. */
    boolean traceNorm = false;
    /** {@code lambda[i][j] = -0.5 * a_i * y_i * a_j * y_j * K[i][j]} for the current alphas and combined kernel. */
    double[][] lambda_matrix = null;
    /** Whether training kernels are wrapped in a {@link SimpleCacheKernel}. */
    boolean cache = true;
    private int VERBOSITY_LEVEL = 0;
    /** Kernels registered through {@link #addKernel(Kernel)}. */
    List<Kernel<T>> listOfKernels = new ArrayList<>();
    /** Weight of each kernel, index-aligned with {@code listOfKernels}. */
    List<Double> listOfKernelWeights = new ArrayList<>();
    /** Examples of the last training set; {@link #train(TrainingSample)} appends to it. */
    List<TrainingSample<T>> listOfExamples = new ArrayList<>();
    /** SVM alphas from the last training run (presumably one per training example, in order). */
    List<Double> listOfExampleWeights = new ArrayList<>();

    /**
     * Registers a kernel to be combined, with an initial weight of 1.
     *
     * @param kernel the kernel to add
     */
    public void addKernel(Kernel<T> kernel) {
        this.listOfKernels.add(kernel);
        this.listOfKernelWeights.add(Double.valueOf(1.0d));
    }

    /**
     * Adds one example (if not already present) to the stored training set and retrains on the
     * whole set.
     *
     * @param trainingSample the example to add
     * @throws IllegalStateException if the weight optimisation fails (see {@link #train(List)})
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public void train(TrainingSample<T> trainingSample) {
        if (this.listOfExamples == null) {
            this.listOfExamples = new ArrayList<>();
        }
        if (!this.listOfExamples.contains(trainingSample)) {
            this.listOfExamples.add(trainingSample);
        }
        train(this.listOfExamples);
    }

    /**
     * Learns the kernel weights and the SVM on {@code list}.
     *
     * <p>After this call, {@link #getKernelWeights()} holds the learned weights,
     * {@link #getExampleWeights()} holds the SVM alphas, and the stored training set is a copy of
     * {@code list}. Storing a copy, rather than appending {@code list} to the stored set, avoids
     * doubling the stored set when {@code list} is that same set (as with
     * {@link #train(TrainingSample)}).
     *
     * @param list training examples
     * @throws IllegalStateException if a weight update step reports an invalid objective ratio
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        long startMillis = System.currentTimeMillis();
        eprintln(2, "training on " + this.listOfKernels.size() + " kernels and " + list.size() + " examples");

        // Kernels used during optimisation (possibly cached / trace-normalised) and their weights.
        ArrayList<Kernel<T>> trainingKernels = new ArrayList<>();
        ArrayList<Double> weights = new ArrayList<>();
        for (int i = 0; i < this.listOfKernels.size(); i++) {
            trainingKernels.add(prepareTrainingKernel(this.listOfKernels.get(i), list));
            // (1/M)^(1/p): uniform weights with unit p-norm.
            weights.add(Double.valueOf(Math.pow(1.0d / this.listOfKernels.size(), 1.0d / this.p_norm)));
            eprintln(3, "+ kernel : " + trainingKernels.get(i) + " weight : " + weights.get(i));
        }

        ThreadedProductKernel initialKernel = buildProductKernel(trainingKernels, toPrimitive(weights));
        this.svm = new SMOSVM<>(initialKernel);
        this.svm.setC(this.C);
        this.svm.setVerbosityLevel(this.VERBOSITY_LEVEL - 1);
        eprintln(3, "+ training svm");
        this.svm.train(list);
        double[] alphas = this.svm.getAlphas();
        updateLambdaMatrix(alphas, initialKernel, list);
        this.oldObjective = computeObj(alphas);
        eprintln(3, "+ initial weights : " + weights);

        double gap;
        do {
            double ratio = performPKLStep(trainingKernels, weights, list);
            if (ratio < 0.0d) {
                eprintln(1, "Error, performPKLStep return wrong value");
                // Library code must not terminate the JVM; report the failure to the caller.
                throw new IllegalStateException("performPKLStep returned an invalid objective ratio: " + ratio);
            }
            gap = 1.0d - ratio;
            eprintln(1, "+ objective_gap : " + ((float) gap));
            eprintln(1, "+");
        } while (gap >= this.stopGap);

        this.listOfKernelWeights.clear();
        this.listOfKernelWeights.addAll(weights);
        // The final model is built on the registered (uncached) kernels, presumably because the
        // caches only cover the training samples.
        this.svm.setKernel(buildProductKernel(this.listOfKernels, toPrimitive(this.listOfKernelWeights)));
        eprintln(3, "+ retraining svm");
        this.svm.retrain();

        this.listOfExamples = new ArrayList<>(list);
        this.listOfExampleWeights.clear();
        for (double alpha : this.svm.getAlphas()) {
            this.listOfExampleWeights.add(Double.valueOf(alpha));
        }
        eprintln(1, "PKL trained in " + (System.currentTimeMillis() - startMillis) + " milis.");
    }

    /**
     * Returns the kernel to use during optimisation. When caching is on, this is a
     * {@link SimpleCacheKernel} over {@code list} (optionally trace-normalised); otherwise it is
     * {@code kernel} itself.
     */
    private Kernel<T> prepareTrainingKernel(Kernel<T> kernel, List<TrainingSample<T>> list) {
        if (!this.cache) {
            eprintln(3, "+ cache is not set, skipping cache");
            return kernel;
        }
        eprintln(3, "+ cache is set, computing cache");
        Kernel<T> cached = new SimpleCacheKernel<>(kernel, list);
        cached.setName(kernel.toString());
        // Always requested (even without trace normalisation), as the original code did.
        double[][] matrix = cached.getKernelMatrix(list);
        if (this.traceNorm) {
            // NOTE(review): scales the returned matrix in place; this only affects later lookups
            // if SimpleCacheKernel returns its internal storage — confirm.
            traceNormalize(matrix);
        }
        return cached;
    }

    /** Scales a symmetric matrix in place so that its trace equals its dimension. */
    private static void traceNormalize(double[][] matrix) {
        int n = matrix.length;
        double trace = 0.0d;
        for (int i = 0; i < n; i++) {
            trace += matrix[i][i];
        }
        double scale = n / trace;
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                matrix[i][j] = matrix[i][j] * scale;
                matrix[j][i] = matrix[i][j];
            }
        }
    }

    /** Builds a product kernel with {@code kernels.get(i)} weighted by {@code weights[i]}. */
    private ThreadedProductKernel buildProductKernel(List<Kernel<T>> kernels, double[] weights) {
        ThreadedProductKernel product = new ThreadedProductKernel();
        for (int i = 0; i < kernels.size(); i++) {
            product.addKernel(kernels.get(i), weights[i]);
        }
        return product;
    }

    /** Unboxes a list of weights into a fresh array. */
    private static double[] toPrimitive(List<Double> values) {
        double[] out = new double[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i).doubleValue();
        }
        return out;
    }

    /**
     * Returns the {@code p}-norm of {@code w}: {@code sum |w_i|} when {@code p == 1}, otherwise
     * {@code (sum |w_i|^p)^(1/p)}.
     */
    private double weightNorm(double[] w) {
        double sum = 0.0d;
        if (this.p_norm == 1.0d) {
            for (double v : w) {
                sum += Math.abs(v);
            }
            return sum;
        }
        for (double v : w) {
            sum += Math.pow(Math.abs(v), this.p_norm);
        }
        // Exponent is +1/p. A negative exponent here would multiply the weights by their norm.
        return Math.pow(sum, 1.0d / this.p_norm);
    }

    /**
     * Performs one outer weight update with a backtracking step size.
     *
     * <p>Candidate weights are {@code w_i * (1 - d_lambda * grad_i / secondGrad_i)}. Values below
     * {@code num_cleaning} are zeroed, and the result is normalised to unit {@code p}-norm. The SVM is
     * then retrained. A candidate is accepted as soon as the objective drops below
     * {@code oldObjective + num_cleaning}. Otherwise the step is divided by 8, until it falls
     * to {@code num_cleaning} or below.
     *
     * @param kernels training kernels
     * @param weights current weights; overwritten with the accepted candidate, if any
     * @param list training examples
     * @return {@code newObjective / oldObjective}, or {@code -1} if normalisation fails
     */
    private double performPKLStep(ArrayList<Kernel<T>> kernels, ArrayList<Double> weights, List<TrainingSample<T>> list) {
        eprintln(3, "+++ old weights : " + weights);
        eprintln(3, "+++ oldObjective : " + this.oldObjective + " sumAlpha : " + computeSumAlpha());
        double[] grad = gradBeta(kernels, list);
        double[] secondGrad = secondGradBeta(kernels, list);
        double[] candidate = new double[grad.length];
        double newObjective;
        this.d_lambda = 1.0d;
        while (true) {
            for (int i = 0; i < grad.length; i++) {
                // NOTE(review): an entry whose second derivative is 0 is never assigned. It keeps
                // the 0 it was initialised with, so that kernel is dropped — confirm this is intended.
                if (secondGrad[i] != 0.0d) {
                    candidate[i] = weights.get(i).doubleValue() * (1.0d - ((this.d_lambda * grad[i]) / secondGrad[i]));
                }
                if (candidate[i] < this.num_cleaning) {
                    candidate[i] = 0.0d;
                }
            }

            double norm = weightNorm(candidate);
            // A zero (or NaN) norm would turn every weight into NaN on division.
            if (candidate.length > 0 && !(norm > 0.0d)) {
                eprintln(1, "Error normalization, norm is not positive : " + norm);
                return -1.0d;
            }
            eprintln(3, "+++ norm : " + norm);
            for (int i = 0; i < candidate.length; i++) {
                candidate[i] /= norm;
            }

            ThreadedProductKernel kernel = buildProductKernel(kernels, candidate);
            this.svm.setKernel(kernel);
            eprintln(3, "+ retraining svm");
            this.svm.retrain();
            double[] alphas = this.svm.getAlphas();
            updateLambdaMatrix(alphas, kernel, list);
            newObjective = computeObj(alphas);

            // Written as a strict "<" so a NaN objective is rejected (and the step keeps shrinking)
            // instead of being accepted forever.
            if (newObjective < this.oldObjective + this.num_cleaning) {
                for (int i = 0; i < weights.size(); i++) {
                    weights.set(i, Double.valueOf(candidate[i]));
                }
                eprintln(3, "+++ new weights : " + weights);
                break;
            }
            if (this.d_lambda <= this.num_cleaning) {
                this.d_lambda = 0.0d;
                eprint(3, "+++ d_lambda is zero, stopping.");
                eprintln(2, "");
                break;
            }
            this.d_lambda /= 8.0d;
            eprint(2, "+");
            eprintln(3, "++ new objective (" + ((float) newObjective) + ") did not decrease (" + ((float) this.oldObjective) + "), reducing step : " + this.d_lambda);
        }
        eprintln(2, "+ objective : " + ((float) newObjective) + "\t+\t sumAlpha : " + ((float) computeSumAlpha()));
        double ratio = newObjective / this.oldObjective;
        this.oldObjective = newObjective;
        return ratio;
    }

    /**
     * First derivative per kernel: {@code sum_{i<=j} -log(k_m[i][j]) * lambda[i][j]}, skipping
     * zero kernel entries.
     *
     * <p>NOTE(review): only the upper triangle is summed, and off-diagonal terms are not doubled
     * as they are in {@code computeObj}. Confirm the intended scaling.
     */
    private double[] gradBeta(ArrayList<Kernel<T>> kernels, List<TrainingSample<T>> list) {
        double[] grad = new double[kernels.size()];
        for (int m = 0; m < kernels.size(); m++) {
            double[][] km = kernels.get(m).getKernelMatrix(list);
            for (int i = 0; i < km.length; i++) {
                for (int j = i; j < km.length; j++) {
                    if (km[i][j] != 0.0d) {
                        grad[m] += (-Math.log(km[i][j])) * this.lambda_matrix[i][j];
                    }
                }
            }
        }
        eprintln(4, "++++++ gradDir : " + Arrays.toString(grad));
        return grad;
    }

    /**
     * Second derivative per kernel: {@code sum_{i<=j} log(k_m[i][j])^2 * lambda[i][j]}, skipping
     * zero kernel entries. Results below {@code num_cleaning} (including negative ones) are set
     * to 0.
     */
    private double[] secondGradBeta(ArrayList<Kernel<T>> kernels, List<TrainingSample<T>> list) {
        double[] second = new double[kernels.size()];
        for (int m = 0; m < kernels.size(); m++) {
            double[][] km = kernels.get(m).getKernelMatrix(list);
            for (int i = 0; i < km.length; i++) {
                for (int j = i; j < km.length; j++) {
                    if (km[i][j] != 0.0d) {
                        double log = Math.log(km[i][j]);
                        second[m] += log * log * this.lambda_matrix[i][j];
                    }
                }
            }
        }
        for (int m = 0; m < second.length; m++) {
            if (second[m] < this.num_cleaning) {
                second[m] = 0.0d;
            }
        }
        eprintln(4, "++++++ secondGradDir : " + Arrays.toString(second));
        return second;
    }

    /** Returns {@code sum |alpha_i|} over the current SVM alphas. */
    private double computeSumAlpha() {
        double sum = 0.0d;
        for (double alpha : this.svm.getAlphas()) {
            sum += Math.abs(alpha);
        }
        return sum;
    }

    /**
     * Returns {@code sum_i alpha_i + sum_i lambda[i][i] + 2 * sum_{i<j} lambda[i][j]}. For a
     * symmetric kernel matrix this equals
     * {@code sum_i a_i - 1/2 sum_{i,j} a_i a_j y_i y_j K[i][j]}, the usual SVM dual objective.
     *
     * @param alphas current SVM alphas
     */
    private double computeObj(double[] alphas) {
        double obj = 0.0d;
        for (double alpha : alphas) {
            obj += alpha;
        }
        for (int i = 0; i < this.lambda_matrix.length; i++) {
            for (int j = i; j < this.lambda_matrix.length; j++) {
                double l = this.lambda_matrix[i][j];
                if (l != 0.0d) {
                    obj = (i != j) ? obj + (2.0d * l) : obj + l;
                }
            }
        }
        return obj;
    }

    /**
     * Recomputes {@code lambda_matrix[i][j] = -0.5 * a_i * y_i * y_j * a_j * K[i][j]} for every
     * {@code i, j}, using {@code kernel}'s matrix on {@code list}.
     */
    private void updateLambdaMatrix(final double[] alphas, Kernel<T> kernel, final List<TrainingSample<T>> list) {
        final double[][] kernelMatrix = kernel.getKernelMatrix(list);
        this.lambda_matrix = new double[kernelMatrix.length][kernelMatrix.length];
        eprintln(3, "+ update lambda");
        this.lambda_matrix = new ThreadedMatrixOperator() {
            @Override // JKernelMachines.fr.lip6.threading.ThreadedMatrixOperator
            public void doLine(int index, double[] row) {
                double rowFactor = (-0.5d) * alphas[index] * list.get(index).label;
                // Every column is filled, including j == 0.
                for (int j = 0; j < row.length; j++) {
                    row[j] = rowFactor * list.get(j).label * alphas[j] * kernelMatrix[index][j];
                }
            }
        }.getMatrix(this.lambda_matrix);
    }

    /**
     * Returns the decision value of the trained SVM for {@code t}.
     * Throws {@link NullPointerException} if called before training.
     */
    @Override // JKernelMachines.fr.lip6.classifier.Classifier
    public double valueOf(T t) {
        return this.svm.valueOf(t);
    }

    /** Sets the verbosity level (messages at a level above it are not printed; the SVM gets level - 1). */
    public void setVerbosityLevel(int i) {
        this.VERBOSITY_LEVEL = i;
    }

    /** @return whether training kernels are cached */
    public boolean isCache() {
        return this.cache;
    }

    /** @param z whether training kernels should be wrapped in a {@link SimpleCacheKernel} */
    public void setCache(boolean z) {
        this.cache = z;
    }

    /** Prints {@code str} to standard error (no newline) if the verbosity level is at least {@code i}. */
    public void eprint(int i, String str) {
        if (this.VERBOSITY_LEVEL >= i) {
            System.err.print(str);
        }
    }

    /** Prints {@code str} and a newline to standard error if the verbosity level is at least {@code i}. */
    public void eprintln(int i, String str) {
        if (this.VERBOSITY_LEVEL >= i) {
            System.err.println(str);
        }
    }

    /** @return the SVM regularisation parameter */
    public double getC() {
        return this.C;
    }

    /** @param d the SVM regularisation parameter (default 1) */
    public void setC(double d) {
        this.C = d;
    }

    /** @param d the {@code p} of the {@code p}-norm constraint on kernel weights (default 1) */
    public void setMKLNorm(double d) {
        this.p_norm = d;
    }

    /** @param d outer iterations stop when {@code 1 - newObjective/oldObjective} drops below it (default 1e-5) */
    public void setStopGap(double d) {
        this.stopGap = d;
    }

    /** @param z whether cached training kernel matrices are scaled to trace = number of examples */
    public void setTraceNorm(boolean z) {
        this.traceNorm = z;
    }

    /** @return the numerical tolerance */
    public double getNum_cleaning() {
        return this.num_cleaning;
    }

    /** @param d the numerical tolerance (default 1e-8) */
    public void setNum_cleaning(double d) {
        this.num_cleaning = d;
    }

    /** @return the live list of SVM alphas from the last training run */
    public List<Double> getExampleWeights() {
        return this.listOfExampleWeights;
    }

    /** @return the live list of kernel weights, index-aligned with the registered kernels */
    public List<Double> getKernelWeights() {
        return this.listOfKernelWeights;
    }

    /** @return a new table mapping each registered kernel to its current weight */
    public Hashtable<Kernel<T>, Double> getWeights() {
        Hashtable<Kernel<T>, Double> table = new Hashtable<>();
        for (int i = 0; i < this.listOfKernels.size(); i++) {
            table.put(this.listOfKernels.get(i), this.listOfKernelWeights.get(i));
        }
        return table;
    }
}
