package JKernelMachines.fr.lip6.classifier;

import JKernelMachines.fr.lip6.kernel.Kernel;
import JKernelMachines.fr.lip6.type.TrainingSample;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.logging.Logger;
import org.apache.xpath.XPath;

/* loaded from: input_file:JKernelMachines/fr/lip6/classifier/EnhancedSMOSVM.class */
/**
 * Binary kernel SVM trained with a Sequential Minimal Optimization (SMO) solver.
 *
 * <p>Labels must be exactly {@code -1} or {@code +1}. The decision value returned by
 * {@link #valueOf(Object)} is {@code sum_i alpha_i * y_i * K(x_i, x) - b}.
 *
 * <p>During training the solver keeps two thresholds ({@code m_bUp}, {@code m_bLow}) and
 * partitions the training indices by multiplier value and label:
 * <ul>
 *   <li>{@code m_I0}: {@code 0 < alpha < C}</li>
 *   <li>{@code m_I1}: {@code y = +1}, {@code alpha = 0}</li>
 *   <li>{@code m_I2}: {@code y = -1}, {@code alpha = C}</li>
 *   <li>{@code m_I3}: {@code y = +1}, {@code alpha = C}</li>
 *   <li>{@code m_I4}: {@code y = -1}, {@code alpha = 0}</li>
 * </ul>
 * NOTE(review): this two-threshold / five-set scheme looks like Keerthi et al.'s modification
 * of Platt's SMO (the same structure WEKA's SMO uses) — confirm before citing it.
 *
 * <p>No synchronization is performed; training mutates instance state.
 */
public class EnhancedSMOSVM<T> implements Classifier<T>, Serializable {
    private static final long serialVersionUID = -3409002064617582128L;

    /** {@code alpha[i] * label[i]}, filled in once training has finished. */
    private double[] alphay;
    /** Lagrange multipliers, one per training sample ({@code null} before training). */
    private double[] alpha;
    /** Training samples used by the last training run. */
    private ArrayList<TrainingSample<T>> ts;
    /** Number of training samples, i.e. {@code ts.size()}. */
    int size;
    /** Kernel used for the training Gram matrix and for prediction. */
    private Kernel<T> kernel;
    /** Gram matrix of the training samples, as returned by {@code kernel.getKernelMatrix}. */
    private double[][] kcache;
    /** Bias subtracted from the kernel expansion in {@link #valueOf(Object)}. */
    private double b;
    /** Lower/upper thresholds used by the optimality check, and the indices attaining them. */
    protected double m_bLow;
    protected double m_bUp;
    protected int m_iLow;
    protected int m_iUp;
    /** Raw samples extracted from {@link #ts} for the current training run. */
    protected ArrayList<T> m_data;
    // The next three fields are only cleared by this class, never read; kept for
    // subclass/serialization compatibility.
    protected double[] m_weights;
    protected double[] m_sparseWeights;
    protected int[] m_sparseIndices;
    /** Labels encoded as {@code -1.0} / {@code +1.0}. */
    protected double[] m_class;
    /** Error cache: {@code f(x_i) + b - y_i}; kept current for members of {@code m_I0}. */
    protected double[] m_errors;
    protected TreeSet<Integer> m_I0;
    protected TreeSet<Integer> m_I1;
    protected TreeSet<Integer> m_I2;
    protected TreeSet<Integer> m_I3;
    protected TreeSet<Integer> m_I4;
    /** Indices whose multiplier is strictly positive. */
    protected TreeSet<Integer> m_supportVectors;
    private static final Logger logger = Logger.getLogger(EnhancedSMOSVM.class.toString());
    /** Relative threshold for snapping multipliers onto the box bounds (== 1000 * Double.MIN_VALUE). */
    protected static double m_Del = 4.94E-321d;
    /** Box constraint: every multiplier lies in {@code [0, C]}. */
    private double C = 100000.0d;
    /** Numerical tolerance used when comparing multipliers / objective values. */
    private double eps = 1.0E-12d;
    /** KKT violation tolerance. */
    private double tolerance = 1.0E-10d;
    // Not used by this class; kept for subclass/serialization compatibility.
    protected double m_sumOfWeights = 0.0d;

    /**
     * Creates an untrained SVM.
     *
     * @param kernel kernel used for the training Gram matrix and for prediction
     */
    public EnhancedSMOSVM(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Adds one sample to the training set and retrains from scratch on the whole set.
     *
     * @param trainingSample sample to append; its label must be {@code -1} or {@code +1}
     */
    @Override
    public void train(TrainingSample<T> trainingSample) {
        if (ts == null) {
            ts = new ArrayList<>();
        }
        ts.add(trainingSample);
        // alpha is not resized here: train() reallocates it, and it may still be null if the
        // classifier was never trained (resizing it used to throw a NullPointerException).
        size = ts.size();
        train();
    }

    /**
     * Replaces the training set with a copy of {@code list} and trains on it.
     *
     * @param list training samples; labels must be {@code -1} or {@code +1}
     */
    @Override
    public void train(List<TrainingSample<T>> list) {
        ts = new ArrayList<>(list);
        size = ts.size();
        train();
    }

    /**
     * Runs SMO on {@link #ts}, replacing any previous solution.
     *
     * <p>Special cases:
     * <ul>
     *   <li>All labels {@code +1}: {@code b = -1}, so predictions are {@code +1}. {@code alpha} is empty.</li>
     *   <li>All labels {@code -1}: {@code b = +1}, so predictions are {@code -1}. {@code alpha} is empty.</li>
     *   <li>Empty training set: {@code alpha} stays {@code null}.</li>
     *   <li>A label other than {@code -1}/{@code +1}: the problem is logged and training stops.</li>
     * </ul>
     */
    private void train() {
        m_bUp = -1.0d;
        m_bLow = 1.0d;
        b = 0.0d;
        alpha = null;
        m_data = null;
        m_weights = null;
        m_errors = null;
        m_I0 = null;
        m_I1 = null;
        m_I2 = null;
        m_I3 = null;
        m_I4 = null;
        m_sparseWeights = null;
        m_sparseIndices = null;

        // Encode the labels and remember one index per class; they seed m_iUp / m_iLow.
        m_class = new double[size];
        m_iUp = -1;
        m_iLow = -1;
        for (int i = 0; i < m_class.length; i++) {
            if (ts.get(i).label == -1) {
                m_class[i] = -1.0d;
                m_iLow = i;
            } else if (ts.get(i).label == 1) {
                m_class[i] = 1.0d;
                m_iUp = i;
            } else {
                logger.severe("this should never happen !!!");
                return;
            }
        }

        // Degenerate problems: zero or one class present, nothing to optimise.
        if (m_iUp == -1 || m_iLow == -1) {
            if (m_iUp != -1) {
                b = -1.0d; // only positives: valueOf() == 0 - (-1) == +1
            } else if (m_iLow != -1) {
                b = 1.0d; // only negatives: valueOf() == 0 - 1 == -1
            } else {
                m_class = null; // empty training set
                return;
            }
            m_supportVectors = new TreeSet<>();
            alpha = new double[0];
            alphay = new double[0];
            m_class = new double[0];
            return;
        }

        m_data = new ArrayList<>(size);
        for (TrainingSample<T> sample : ts) {
            m_data.add(sample.sample);
        }
        alpha = new double[size];
        m_supportVectors = new TreeSet<>();
        m_I0 = new TreeSet<>();
        m_I1 = new TreeSet<>();
        m_I2 = new TreeSet<>();
        m_I3 = new TreeSet<>();
        m_I4 = new TreeSet<>();

        logger.info("Building kernel cache");
        kcache = kernel.getKernelMatrix(ts);
        logger.info("Kernel cache built");

        // Seed errors for the two initial threshold indices (all multipliers start at 0).
        m_errors = new double[size];
        m_errors[m_iLow] = 1.0d;
        m_errors[m_iUp] = -1.0d;
        // With alpha == 0, positives belong to I1 and negatives to I4.
        for (int i = 0; i < m_class.length; i++) {
            if (m_class[i] == 1.0d) {
                m_I1.add(i);
            } else {
                m_I4.add(i);
            }
        }

        int iterations = optimize();

        b = (m_bLow + m_bUp) / 2.0d;
        // Free training-only state.
        m_errors = null;
        m_I4 = null;
        m_I3 = null;
        m_I2 = null;
        m_I1 = null;
        m_I0 = null;
        alphay = new double[alpha.length];
        for (int i = 0; i < alpha.length; i++) {
            alphay[i] = alpha[i] * ts.get(i).label;
        }
        logger.info("Training done in " + iterations + " iterations.");
    }

    /**
     * Main SMO loop.
     *
     * <p>Alternates full passes over all samples with passes restricted to non-bound
     * multipliers ({@code 0 < alpha < C}). It stops when a full pass changes nothing.
     *
     * @return number of passes performed
     */
    private int optimize() {
        int numChanged = 0;
        boolean examineAll = true;
        int iterations = 0;
        while (numChanged > 0 || examineAll) {
            numChanged = 0;
            if (examineAll) {
                for (int i = 0; i < alpha.length; i++) {
                    if (examineExample(i)) {
                        numChanged++;
                    }
                }
            } else {
                for (int i = 0; i < alpha.length; i++) {
                    if (alpha[i] > 0.0d && alpha[i] < C) {
                        if (examineExample(i)) {
                            numChanged++;
                        }
                        // Thresholds already agree within tolerance: end this pass early.
                        if (m_bUp > m_bLow - (2.0d * tolerance)) {
                            numChanged = 0;
                            break;
                        }
                    }
                }
            }
            if (examineAll) {
                examineAll = false;
            } else if (numChanged == 0) {
                examineAll = true;
            }
            iterations++;
            if (iterations % 100 == 0) {
                logger.info("iteration : " + iterations);
            }
        }
        return iterations;
    }

    /**
     * Checks sample {@code i} for a KKT violation and, if one is found, optimises it jointly
     * with a partner chosen from the current thresholds.
     *
     * @param i index of the sample to examine
     * @return {@code true} if the multipliers were changed
     */
    protected boolean examineExample(int i) {
        double y = m_class[i];
        double error;
        if (m_I0.contains(i)) {
            // The error cache is kept current for non-bound samples.
            error = m_errors[i];
        } else {
            error = (valueOf(m_data.get(i)) + b) - y;
            m_errors[i] = error;
            if ((m_I1.contains(i) || m_I2.contains(i)) && error < m_bUp) {
                m_bUp = error;
                m_iUp = i;
            } else if ((m_I3.contains(i) || m_I4.contains(i)) && error > m_bLow) {
                m_bLow = error;
                m_iLow = i;
            }
        }

        int partner = -1;
        boolean optimal = true;
        if ((m_I0.contains(i) || m_I1.contains(i) || m_I2.contains(i))
                && m_bLow - error > 2.0d * tolerance) {
            optimal = false;
            partner = m_iLow;
        }
        if ((m_I0.contains(i) || m_I3.contains(i) || m_I4.contains(i))
                && error - m_bUp > 2.0d * tolerance) {
            optimal = false;
            partner = m_iUp;
        }
        if (optimal) {
            return false;
        }
        // For a non-bound sample, pick the partner giving the larger violation.
        if (m_I0.contains(i)) {
            partner = m_bLow - error > error - m_bUp ? m_iLow : m_iUp;
        }
        if (partner == -1) {
            logger.severe("This should never happen!");
            return false;
        }
        return takeStep(partner, i, error);
    }

    /**
     * Jointly optimises the multipliers of samples {@code i1} and {@code i2}.
     *
     * <p>On success it also updates the index sets, the error cache and the thresholds.
     *
     * @param i1 index of the first sample
     * @param i2 index of the second sample
     * @param error2 current error {@code f(x_i2) + b - y_i2}
     * @return {@code true} if progress was made
     */
    protected boolean takeStep(int i1, int i2, double error2) {
        double c1 = C;
        double c2 = C;
        if (i1 == i2) {
            return false;
        }
        double alph1 = alpha[i1];
        double alph2 = alpha[i2];
        double y1 = m_class[i1];
        double y2 = m_class[i2];
        double error1 = m_errors[i1];
        double s = y1 * y2;

        // Feasible segment [lo, hi] for the new alpha2.
        double lo;
        double hi;
        if (y1 != y2) {
            lo = Math.max(0.0d, alph2 - alph1);
            hi = Math.min(c2, (c1 + alph2) - alph1);
        } else {
            lo = Math.max(0.0d, (alph1 + alph2) - c1);
            hi = Math.min(c2, alph1 + alph2);
        }
        if (lo >= hi) {
            return false;
        }

        double k11 = kcache[i1][i1];
        double k12 = kcache[i1][i2];
        double k22 = kcache[i2][i2];
        double eta = ((2.0d * k12) - k11) - k22;

        double a2;
        if (eta < 0.0d) {
            // Unconstrained optimum along the constraint line, clipped to [lo, hi].
            a2 = alph2 - ((y2 * (error1 - error2)) / eta);
            if (a2 < lo) {
                a2 = lo;
            } else if (a2 > hi) {
                a2 = hi;
            }
        } else {
            // Non-negative curvature: compare the objective at both segment ends.
            double f1 = valueOf(m_data.get(i1));
            double f2 = valueOf(m_data.get(i2));
            double v1 = ((f1 + b) - ((y1 * alph1) * k11)) - ((y2 * alph2) * k12);
            double v2 = ((f2 + b) - ((y1 * alph1) * k12)) - ((y2 * alph2) * k22);
            double gamma = alph1 + (s * alph2);
            double objLo = objectiveAt(lo, gamma, s, y1, y2, k11, k12, k22, v1, v2);
            double objHi = objectiveAt(hi, gamma, s, y1, y2, k11, k12, k22, v1, v2);
            if (objLo > objHi + eps) {
                a2 = lo;
            } else if (objLo < objHi - eps) {
                a2 = hi;
            } else {
                a2 = alph2;
            }
        }
        if (Math.abs(a2 - alph2) < eps * (a2 + alph2 + eps)) {
            return false;
        }

        // Snap values within m_Del (relative) of a bound onto that bound.
        if (a2 > c2 - (m_Del * c2)) {
            a2 = c2;
        } else if (a2 <= m_Del * c2) {
            a2 = 0.0d;
        }
        double a1 = alph1 + (s * (alph2 - a2));
        if (a1 > c1 - (m_Del * c1)) {
            a1 = c1;
        } else if (a1 <= m_Del * c1) {
            a1 = 0.0d;
        }

        updateIndexSets(i1, y1, a1, c1);
        updateIndexSets(i2, y2, a2, c2);

        // Propagate the multiplier changes into the cached errors.
        double delta1 = y1 * (a1 - alph1);
        double delta2 = y2 * (a2 - alph2);
        for (int j : m_I0) {
            if (j != i1 && j != i2) {
                m_errors[j] = m_errors[j] + (delta1 * kcache[i1][j]) + (delta2 * kcache[i2][j]);
            }
        }
        m_errors[i1] = m_errors[i1] + (delta1 * k11) + (delta2 * k12);
        m_errors[i2] = m_errors[i2] + (delta1 * k12) + (delta2 * k22);
        alpha[i1] = a1;
        alpha[i2] = a2;

        // Recompute the thresholds from the non-bound samples plus i1 and i2.
        m_bLow = -Double.MAX_VALUE;
        m_bUp = Double.MAX_VALUE;
        m_iLow = -1;
        m_iUp = -1;
        for (int j : m_I0) {
            double e = m_errors[j];
            if (e < m_bUp) {
                m_bUp = e;
                m_iUp = j;
            }
            if (e > m_bLow) {
                m_bLow = e;
                m_iLow = j;
            }
        }
        includeBoundSampleInThresholds(i1);
        includeBoundSampleInThresholds(i2);

        if (m_iLow != -1 && m_iUp != -1) {
            return true;
        }
        logger.severe("This should never happen!");
        return false;
    }

    /**
     * Evaluates the dual objective along the two-variable constraint line.
     *
     * <p>The evaluation point is {@code alpha2 = a2}, with {@code alpha1 = gamma - s * a2}.
     */
    private static double objectiveAt(double a2, double gamma, double s, double y1, double y2,
            double k11, double k12, double k22, double v1, double v2) {
        double a1 = gamma - (s * a2);
        return (((((a1 + a2) - (((0.5d * k11) * a1) * a1)) - (((0.5d * k22) * a2) * a2))
                - (((s * k12) * a1) * a2)) - ((y1 * a1) * v1)) - ((y2 * a2) * v2);
    }

    /**
     * Moves {@code idx} into the index sets matching its new multiplier.
     *
     * @param idx sample index
     * @param y encoded label of the sample
     * @param a new multiplier
     * @param c box bound for the sample
     */
    private void updateIndexSets(int idx, double y, double a, double c) {
        Integer key = idx;
        if (a > 0.0d) {
            m_supportVectors.add(key);
        } else {
            m_supportVectors.remove(key);
        }
        if (a <= 0.0d || a >= c) {
            m_I0.remove(key);
        } else {
            m_I0.add(key);
        }
        if (y == 1.0d && a == 0.0d) {
            m_I1.add(key);
        } else {
            m_I1.remove(key);
        }
        if (y == -1.0d && a == c) {
            m_I2.add(key);
        } else {
            m_I2.remove(key);
        }
        if (y == 1.0d && a == c) {
            m_I3.add(key);
        } else {
            m_I3.remove(key);
        }
        if (y == -1.0d && a == 0.0d) {
            m_I4.add(key);
        } else {
            m_I4.remove(key);
        }
    }

    /**
     * Lets a sample outside {@code m_I0} tighten the thresholds.
     *
     * <p>Members of I3/I4 can raise {@code m_bLow}. All other non-I0 samples can lower {@code m_bUp}.
     */
    private void includeBoundSampleInThresholds(int idx) {
        if (m_I0.contains(idx)) {
            return;
        }
        if (m_I3.contains(idx) || m_I4.contains(idx)) {
            if (m_errors[idx] > m_bLow) {
                m_bLow = m_errors[idx];
                m_iLow = idx;
            }
        } else if (m_errors[idx] < m_bUp) {
            m_bUp = m_errors[idx];
            m_iUp = idx;
        }
    }

    /**
     * Computes the decision value {@code sum_i alpha_i * y_i * K(x_i, t) - b}.
     *
     * <p>The sum runs over {@code alpha}, not {@code size}. After single-class training
     * {@code alpha} is empty, and before training (or after rejected labels) it is
     * {@code null}. In both cases only {@code -b} is returned instead of indexing past the array.
     *
     * @param t sample to evaluate
     * @return the signed decision value
     */
    @Override
    public double valueOf(T t) {
        int n = (alpha == null) ? 0 : alpha.length;
        double sum = 0.0d;
        for (int i = 0; i < n; i++) {
            TrainingSample<T> sample = ts.get(i);
            sum += alpha[i] * sample.label * kernel.valueOf(sample.sample, t);
        }
        return sum - b;
    }

    /**
     * Returns the Lagrange multipliers of the last training run.
     *
     * <p>This is the internal array, not a copy. It is {@code null} before training.
     *
     * @return the multipliers, one per training sample
     */
    public double[] getAlphas() {
        return alpha;
    }
}
