package fork.lib.math.applied.learning.classifier.tree;

import fork.lib.math.algebra.advanced.linearalgebra.Vector;
import fork.lib.math.applied.stat.FrequencyCount;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:fork/lib/math/applied/learning/classifier/tree/NodeSplit.class */
/**
 * Evaluates one candidate split of a decision-tree node.
 *
 * <p>Each input value in {@code vec} is assigned to the class (branch) of the first
 * {@link NodeSplitCondition}, in constructor-list order, that it satisfies. The class label is
 * the condition's index in that list, or {@code -1.0} when no condition is satisfied. The split
 * is scored by the information gain of {@code outs} (entropy in natural-log units) with respect
 * to those classes.
 *
 * <p>Note: the constructor calls the overridable {@link #init()}; a subclass overriding it must
 * not rely on its own fields having been initialized yet.
 */
public class NodeSplit {
    /** Input values being split, one per sample. */
    protected Vector vec;
    /** Output (target) values, one per sample, paired by index with the classes of {@code vec}. */
    protected Vector outs;
    /** Class label assigned to each sample of {@code vec}; set by {@link #init()}. */
    protected Vector clas;
    /** Entropy of {@code outs} before splitting. */
    protected double entropy;
    /** Information gain of this split; set by {@link #init()}. */
    protected double ig;
    /** Condition -> class label (the condition's index in the constructor's list). */
    protected HashMap<NodeSplitCondition, Double> condClas = new HashMap<>();
    /** Class label -> condition; the inverse mapping of {@code condClas}. */
    protected HashMap<Double, NodeSplitCondition> clasCond = new HashMap<>();
    /** Not populated by this class. NOTE(review): presumably reserved for subclasses — verify. */
    protected HashMap<Double, Double> clasOut = new HashMap<>();
    /** Class label -> outputs of the samples assigned to that class; filled by {@link #init()}. */
    protected HashMap<Double, Vector> clasVec = new HashMap<>();

    /**
     * Builds the split and immediately scores it via {@link #init()}.
     *
     * @param vector  input values to split
     * @param list    candidate conditions; the condition at index {@code i} defines class {@code i}
     * @param vector2 output values, paired by index with {@code vector}
     * @throws Exception propagated from {@link #init()}
     */
    public NodeSplit(Vector vector, List<NodeSplitCondition> list, Vector vector2) throws Exception {
        this.vec = vector;
        this.outs = vector2;
        for (int i = 0; i < list.size(); i++) {
            Double label = Double.valueOf(i);
            NodeSplitCondition condition = list.get(i);
            this.condClas.put(condition, label);
            this.clasCond.put(label, condition);
        }
        init();
    }

    /**
     * Computes the pre-split entropy, the per-sample classes and the information gain.
     *
     * @throws Exception propagated from {@link #computeClasses()}
     */
    protected void init() throws Exception {
        this.entropy = entropy(this.outs);
        this.clas = computeClasses();
        this.ig = informationGain(this.clas, this.outs);
    }

    /**
     * Assigns each value of {@code vec} to the class of the first satisfied condition.
     *
     * <p>Conditions are tried in the order of the constructor's list (via {@code clasCond}), so
     * overlapping conditions resolve deterministically rather than by {@code HashMap} order.
     *
     * @return one class label per input value; {@code -1.0} where no condition is satisfied
     * @throws Exception propagated from {@link NodeSplitCondition#satisfy}
     */
    protected Vector computeClasses() throws Exception {
        Vector classes = new Vector();
        int conditionCount = this.clasCond.size();
        for (int i = 0; i < this.vec.size(); i++) {
            double value = this.vec.get(i).doubleValue();
            double label = -1.0d;
            for (int c = 0; c < conditionCount; c++) {
                NodeSplitCondition condition = this.clasCond.get(Double.valueOf(c));
                if (condition.satisfy(value)) {
                    label = this.condClas.get(condition).doubleValue();
                    break;
                }
            }
            classes.add(Double.valueOf(label));
        }
        return classes;
    }

    /**
     * Returns the information gain computed for this split.
     *
     * @return the information gain (boxed)
     */
    public Double informationGain() {
        return Double.valueOf(this.ig);
    }

    /**
     * Partitions {@code vector2} by the class labels in {@code vector} into {@code clasVec} and
     * returns {@code entropy - sum(|part| / |vector2| * entropy(part))}.
     *
     * @param vector  class label per sample
     * @param vector2 output value per sample, paired by index with {@code vector}
     * @return the information gain
     */
    private double informationGain(Vector vector, Vector vector2) {
        // Rebuild from scratch so a repeated init() does not double-count samples.
        this.clasVec.clear();
        for (int i = 0; i < vector.size(); i++) {
            Double label = vector.get(i);
            Vector part = this.clasVec.get(label);
            if (part == null) {
                part = new Vector();
                this.clasVec.put(label, part);
            }
            part.add(vector2.get(i));
        }
        double total = vector2.size();
        double gain = this.entropy;
        for (Vector part : this.clasVec.values()) {
            gain -= entropy(part) * part.size() / total;
        }
        return gain;
    }

    /**
     * Computes the Shannon entropy (natural log) of the empirical distribution of values.
     *
     * @param vector the values; an empty vector yields {@code 0.0}
     * @return {@code sum(p * ln(1/p))} over the distinct values
     */
    public static double entropy(Vector vector) {
        FrequencyCount frequencyCount = new FrequencyCount();
        int n = vector.size();
        for (int i = 0; i < n; i++) {
            frequencyCount.add(vector.get(i));
        }
        double sum = 0.0d;
        Iterator<?> keys = frequencyCount.keys().iterator();
        while (keys.hasNext()) {
            // Widen before dividing so the proportion is never truncated by integer division.
            double proportion = (double) frequencyCount.getCount((Double) keys.next()) / n;
            sum += entropyForOutput(proportion);
        }
        return sum;
    }

    /**
     * Returns one term {@code p * ln(1/p)} of the entropy sum.
     *
     * @param d the proportion {@code p} of one distinct value
     * @return the term, or {@code 0.0} when {@code p} is exactly 0 (avoids {@code 0 * ln(inf)},
     *     which is NaN) or exactly 1
     */
    public static double entropyForOutput(double d) {
        if (d == 0.0d || d == 1.0d) {
            return 0.0d;
        }
        return d * Math.log(1.0d / d);
    }

    /**
     * Groups sample indices by their assigned class label.
     *
     * @return class label -> ascending list of indices into {@code vec} with that label
     * @throws Exception declared for subclass compatibility; not thrown by this implementation
     */
    public HashMap<Double, ArrayList<Integer>> classSplitIndeces() throws Exception {
        HashMap<Double, ArrayList<Integer>> indicesByClass = new HashMap<>();
        for (int i = 0; i < this.clas.size(); i++) {
            Double label = Double.valueOf(this.clas.get(i).doubleValue());
            ArrayList<Integer> indices = indicesByClass.get(label);
            if (indices == null) {
                indices = new ArrayList<>();
                indicesByClass.put(label, indices);
            }
            indices.add(Integer.valueOf(i));
        }
        return indicesByClass;
    }

    /**
     * Returns the largest of the per-class most frequent outputs, as reported by
     * {@code toDistribution().mostFrequentKey()} on each partition in {@code clasVec}.
     *
     * <p>NOTE(review): this is not the overall mode of {@code outs}, and because the running
     * maximum starts at {@code -1.0}, outputs below {@code -1.0} are never returned — confirm
     * this is intended.
     *
     * @return the maximum per-class mode, or {@code -1.0} if there are no partitions
     */
    public double mostFrequentOutput() {
        double best = -1.0d;
        for (Vector part : this.clasVec.values()) {
            double mode = part.toDistribution().mostFrequentKey().doubleValue();
            if (mode > best) {
                best = mode;
            }
        }
        return best;
    }
}
