package fork.lib.math.applied.optim.em;

import fork.lib.math.algebra.elementary.function.v1.FunctionV1;
import fork.lib.math.algebra.elementary.function.v1.distr.DistributionException;
import fork.lib.math.algebra.elementary.function.v1.distr.DistributionFunction;
import fork.lib.math.algebra.elementary.function.v1.distr.NormalDistribution;
import fork.lib.math.algebra.elementary.function.v1.distr.ParameterEstimation;
import fork.lib.math.algebra.elementary.function.v1.polynomial.ConstantFunction;
import fork.lib.math.applied.stat.Distribution;
import java.util.ArrayList;

/* loaded from: input_file:fork/lib/math/applied/optim/em/EM.class */
/**
 * Expectation-maximization (EM) fit of a finite mixture model to a {@link Distribution}.
 *
 * <p>Each iteration is delegated to {@link EMCycle}. After every cycle, the fitted mixture
 * (component functions plus mixing coefficients) and its score from {@link #bic} are appended to
 * {@link #res}. Only {@link NormalDistribution} components are currently supported; see
 * {@link #findStartDistrs()}.
 */
public class EM {
    /** The observed data being fitted. */
    protected Distribution data;
    /** Requested component types, one per mixture component; only their runtime class is inspected. */
    protected DistributionFunction[] types;
    /** Run parameters; a default {@code EMParam} is substituted when {@code null} is supplied. */
    public EMParam par;
    /** One entry per completed EM cycle, in execution order. */
    public ArrayList<EMResultEntry> res;

    /**
     * Creates an EM fitter and, if {@code param.autoStart} is set, immediately runs {@link #start()}.
     *
     * <p>NOTE(review): {@code start()} is overridable and may be invoked from this constructor, so a
     * subclass override would run before the subclass constructor body.
     *
     * @param data the observed data to fit
     * @param types the component types, one per mixture component
     * @param param run parameters, or {@code null} for a default {@code EMParam}
     * @throws Exception if an auto-started fit fails
     */
    public EM(Distribution data, DistributionFunction[] types, EMParam param) throws Exception {
        this.data = data;
        this.types = types;
        this.par = param;
        init();
    }

    /** Fills in default parameters, resets the result list, and optionally auto-starts the fit. */
    private void init() throws Exception {
        if (this.par == null) {
            this.par = new EMParam();
        }
        this.res = new ArrayList<>();
        if (this.par.autoStart) {
            start();
        }
    }

    /**
     * Runs EM cycles until the score converges or the cycle limit is reached.
     *
     * <p>The first cycle starts from {@link #findStartDistrs()} and {@link #findStartCoeffs()}.
     * Each later cycle starts from the previous cycle's output. One {@link EMResultEntry} per cycle
     * is appended to {@link #res}. Existing entries are kept, so calling this method again appends
     * a second, independent run.
     *
     * <p>Iteration stops after the cycle whose zero-based index reaches {@code par.maxCycle} (so at
     * most {@code maxCycle + 1} cycles for a non-negative limit), or as soon as the score differs
     * from the previous cycle's score by less than {@code par.llhThr}.
     *
     * <p>NOTE(review): the score is {@link #bic}, i.e. {@code -2 * lnL}, so {@code llhThr} is
     * effectively a threshold on twice the log-likelihood change. Confirm that this is intended.
     *
     * @throws Exception propagated from start-value estimation, an EM cycle, or scoring
     */
    public void start() throws Exception {
        ArrayList<DistributionFunction> distrs = findStartDistrs();
        ArrayList<Double> coeffs = findStartCoeffs();
        // +infinity guarantees the first cycle never satisfies the convergence test.
        double previousScore = Double.POSITIVE_INFINITY;
        for (int cycle = 0; ; cycle++) {
            EMCycle emCycle = new EMCycle(this.data, distrs, coeffs);
            double score = bic(this.data, emCycle.parEstim.funcsOut, emCycle.parEstim.coeffsOut);
            EMResultEntry entry = new EMResultEntry(emCycle.parEstim.coeffsOut, emCycle.parEstim.funcsOut, score);
            this.res.add(entry);
            if (cycle >= this.par.maxCycle || Math.abs(score - previousScore) < this.par.llhThr) {
                return;
            }
            previousScore = score;
            // The next cycle continues from the entry just recorded.
            distrs = entry.distrs;
            coeffs = entry.coeffs;
        }
    }

    /**
     * Builds one start component per requested type.
     *
     * <p>The data are split with {@code quantileSubsets(types.length)}, and component {@code i} is
     * produced by {@link ParameterEstimation} on subset {@code i}, seeded with a
     * {@code NormalDistribution(0, 1)} template. The estimator presumably fits the template's
     * parameters to the subset; confirm against {@code ParameterEstimation}.
     *
     * @return the start components, in the same order as {@link #types}
     * @throws EMException if any requested type is not a {@link NormalDistribution}
     * @throws DistributionException propagated from splitting the data or estimating parameters
     */
    protected ArrayList<DistributionFunction> findStartDistrs() throws EMException, DistributionException {
        int count = this.types.length;
        ArrayList<Distribution> quantileSubsets = this.data.quantileSubsets(count);
        ArrayList<DistributionFunction> startDistrs = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            if (!(this.types[i] instanceof NormalDistribution)) {
                // Only normal components have a start-value strategy so far.
                throw new EMException();
            }
            startDistrs.add(new ParameterEstimation(quantileSubsets.get(i), new NormalDistribution(0.0d, 1.0d)).funcOut);
        }
        return startDistrs;
    }

    /**
     * Returns equal start mixing coefficients, {@code 1 / types.length} per component.
     *
     * @return one coefficient per entry of {@link #types} (empty if there are no types)
     * @throws EMException declared for subclasses; this implementation does not throw it
     */
    protected ArrayList<Double> findStartCoeffs() throws EMException {
        int count = this.types.length;
        double weight = 1.0d / count;
        ArrayList<Double> coeffs = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            coeffs.add(weight);
        }
        return coeffs;
    }

    /**
     * Evaluates the log-likelihood of {@code data} under the mixture {@code sum_i coeffs[i] * distrs[i]}.
     *
     * <p>The mixture is assembled with {@code FunctionV1.add} and {@code multiply} by a
     * {@link ConstantFunction}, presumably as pointwise operations. It is then passed to
     * {@code data.logLikelihood}.
     *
     * @param data the observed data
     * @param distrs the mixture components
     * @param coeffs the mixing coefficients, one per component
     * @return the log-likelihood reported by {@code data}
     * @throws IllegalArgumentException if {@code distrs} and {@code coeffs} differ in size
     * @throws Exception propagated from building or evaluating the mixture
     */
    public static double logLikelihood(Distribution data, ArrayList<DistributionFunction> distrs, ArrayList<Double> coeffs) throws Exception {
        if (distrs.size() != coeffs.size()) {
            throw new IllegalArgumentException(
                    "component/coefficient count mismatch: " + distrs.size() + " vs " + coeffs.size());
        }
        FunctionV1 mixture = new ConstantFunction(0.0d);
        for (int i = 0; i < distrs.size(); i++) {
            mixture = mixture.add(distrs.get(i).multiply(new ConstantFunction(coeffs.get(i).doubleValue())));
        }
        return data.logLikelihood(mixture);
    }

    /**
     * Scores a fitted mixture for comparison between EM cycles; lower is better.
     *
     * <p>NOTE(review): despite the name, this returns only the deviance {@code -2 * lnL}. The BIC
     * complexity penalty {@code k * ln(n)} is not added. The previous implementation counted
     * {@code k} (as {@code par.npar + 1} per component) but never used it. If every cycle keeps the
     * same components, the missing penalty is the same constant for every entry, so ranking and the
     * convergence test in {@link #start()} are unaffected. The values are, however, not comparable
     * with a textbook BIC. TODO: add the penalty once the sample size of {@code data} is accessible.
     *
     * @param data the observed data
     * @param distrs the mixture components
     * @param coeffs the mixing coefficients, one per component
     * @return {@code -2 * logLikelihood(data, distrs, coeffs)}
     * @throws Exception propagated from {@link #logLikelihood}
     */
    public static double bic(Distribution data, ArrayList<DistributionFunction> distrs, ArrayList<Double> coeffs) throws Exception {
        return (-2.0d) * logLikelihood(data, distrs, coeffs);
    }

    /**
     * Returns the recorded entry with the smallest score; ties go to the earliest entry.
     *
     * @return the best entry, or {@code null} if {@link #res} is empty or no entry scores strictly
     *     below positive infinity (for example, all scores are NaN or +infinity)
     */
    public EMResultEntry getOptimalResultEntry() {
        EMResultEntry best = null;
        double bestScore = Double.POSITIVE_INFINITY;
        for (EMResultEntry entry : this.res) {
            if (entry.bic < bestScore) {
                bestScore = entry.bic;
                best = entry;
            }
        }
        return best;
    }

    /**
     * Demo: fits a four-component normal mixture to a small clustered data set and prints the best
     * entry.
     *
     * <p>NOTE(review): if the default {@code EMParam} has {@code autoStart} enabled, the constructor
     * already runs a fit, and the explicit {@code start()} below appends a second run to {@code res}.
     */
    public static void main(String[] strArr) throws Exception {
        // Rows are the two arguments passed to Distribution.add, presumably (value, weight).
        double[][] points = {
            {0.0d, 8.0d}, {1.0d, 10.0d}, {2.0d, 15.0d},
            {7.0d, 5.0d}, {8.0d, 8.0d}, {9.0d, 10.0d}, {10.0d, 5.0d},
            {20.0d, 5.0d}, {21.0d, 10.0d}, {22.0d, 8.0d},
            {120.0d, 5.0d}, {121.0d, 10.0d}, {122.0d, 8.0d},
        };
        Distribution distribution = new Distribution();
        for (double[] point : points) {
            distribution.add(point[0], point[1]);
        }

        DistributionFunction[] components = new DistributionFunction[4];
        for (int i = 0; i < components.length; i++) {
            components[i] = new NormalDistribution(0.0d, 1.0d);
        }

        EM em = new EM(distribution, components, null);
        em.start();
        EMResultEntry best = em.getOptimalResultEntry();
        if (best == null) {
            System.out.println("no EM result with a finite score");
            return;
        }
        ArrayList<Double> coeffs = best.coeffs;
        ArrayList<DistributionFunction> distrs = best.distrs;
        for (int i = 0; i < coeffs.size(); i++) {
            System.out.println("## " + coeffs.get(i));
            distrs.get(i).printParametersAfter("  ");
        }
        System.out.println(best.bic);
    }
}
