package fork.lib.math.applied.learning.classifier.tree;

import fork.lib.math.algebra.advanced.linearalgebra.Matrix;
import fork.lib.math.algebra.advanced.linearalgebra.Vector;
import fork.lib.math.applied.learning.classifier.Classifier;
import fork.lib.math.applied.stat.FrequencyCount;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;

/* loaded from: input_file:fork/lib/math/applied/learning/classifier/tree/TreeNode.class */
/**
 * Base class for a node of a decision tree that is grown recursively from a data matrix.
 *
 * <p>Rows of {@link #mat} are paired index-by-index with the entries of {@link #outs}; columns of
 * {@link #mat} are the candidate attributes to split on. A node either becomes a leaf carrying a
 * single output value, or picks the column/condition set with the highest information gain and
 * creates one child per split class.
 *
 * <p>Subclasses supply how a column is turned into candidate conditions, how children are
 * instantiated and how the finished tree is exposed as a {@link Classifier}.
 */
public abstract class TreeNode {
    /** Data of this node; rows are addressed by sample index, columns by attribute index. */
    protected Matrix mat;
    /** Output value for each row of {@link #mat}, in the same order. */
    protected Vector outs;
    /** Column indices that have not yet been used for a split on the path to this node. */
    protected HashSet<Integer> unusedinds;
    /** Parent node, or {@code null} for the root. */
    protected TreeNode parent = null;
    /** Child nodes, ordered by ascending split-class key. */
    protected ArrayList<TreeNode> children = new ArrayList<>();
    /** Condition attached to this node through {@link #setCondition(NodeSplitCondition)}. */
    protected NodeSplitCondition cond;
    /** Column index chosen for the split at this node, or {@code -1} if none was chosen. */
    protected int ind = -1;
    /** Whether this node is a leaf. */
    protected boolean ifLeaf = false;
    /** Output value of a leaf; {@link Double#NEGATIVE_INFINITY} until one is assigned. */
    protected double leafVal = Double.NEGATIVE_INFINITY;

    /**
     * Creates a node over clones of the given data and runs {@link #init()}.
     *
     * <p>Note that {@link #init()} is overridable and is invoked from this constructor, so an
     * override runs before the subclass' own fields are initialised.
     *
     * @param matrix data matrix; cloned, the argument itself is not modified here
     * @param vector outputs, one per row of {@code matrix}; cloned
     * @param hashSet column indices still available for splitting, or {@code null} to make all
     *     columns available; the set is stored as-is (not copied)
     * @throws Exception if cloning or {@link #init()} fails
     */
    public TreeNode(Matrix matrix, Vector vector, HashSet<Integer> hashSet) throws Exception {
        this.mat = matrix.clone();
        this.outs = vector.clone();
        this.unusedinds = hashSet;
        init();
    }

    /**
     * Creates a node for which every column of {@code matrix} is available for splitting.
     *
     * @param matrix data matrix; cloned
     * @param vector outputs, one per row of {@code matrix}; cloned
     * @throws Exception if cloning or {@link #init()} fails
     */
    public TreeNode(Matrix matrix, Vector vector) throws Exception {
        this(matrix, vector, null);
    }

    /**
     * Creates an empty node: no matrix, outputs or unused-index set are assigned and
     * {@link #init()} is not run.
     */
    public TreeNode() {
    }

    /**
     * Produces the candidate condition sets for one column.
     *
     * @param vector a column of {@link #mat}
     * @return candidate splits; each inner list is evaluated as one {@code NodeSplit}
     * @throws Exception if the conditions cannot be derived
     */
    protected abstract ArrayList<ArrayList<NodeSplitCondition>> columnVectorToConditions(Vector vector) throws Exception;

    /**
     * Builds a classifier from this tree.
     *
     * @param arrayList labels used by the implementation (presumably attribute names; confirm
     *     against the concrete subclass)
     * @return the classifier
     * @throws Exception if the classifier cannot be built
     */
    public abstract Classifier getClassifier(ArrayList<String> arrayList) throws Exception;

    /**
     * Instantiates a child node of the concrete subclass type.
     *
     * @param matrix rows routed to the child
     * @param vector outputs of those rows
     * @param hashSet column indices the child may still split on
     * @return the new child node
     * @throws Exception if the child cannot be created
     */
    protected abstract TreeNode childNode(Matrix matrix, Vector vector, HashSet<Integer> hashSet) throws Exception;

    /**
     * Prepares the node's data: drops every row whose value distribution has exactly one key
     * (together with the matching entry of {@link #outs}) and, when no unused-index set was
     * supplied, marks all columns as unused.
     *
     * @throws Exception if the matrix or vector operations fail
     */
    protected void init() throws Exception {
        int row = 0;
        // Only advance when nothing was removed: a removal shifts the next row into this index.
        while (row < this.mat.rowNumber()) {
            if (this.mat.get(row).toDistribution().keySize() == 1) {
                this.mat.remove(row);
                this.outs.remove(row);
            } else {
                row++;
            }
        }
        if (this.unusedinds == null) {
            this.unusedinds = new HashSet<>();
            for (int col = 0; col < this.mat.columnNumber(); col++) {
                this.unusedinds.add(Integer.valueOf(col));
            }
        }
    }

    /**
     * Grows the subtree below this node.
     *
     * <p>Does nothing when the node has no data or is already a leaf. When no unused column
     * remains, or no candidate split can be evaluated, the node becomes a leaf with the most
     * frequent output. Otherwise the split with the highest information gain is selected; a
     * non-root node whose best gain is exactly zero becomes a leaf, else one child is created per
     * split class. If all leaves of the resulting subtree share a single value, this node is
     * additionally flagged as a leaf with that value (its children are kept).
     *
     * @throws Exception if split evaluation or child construction fails
     */
    public void computeChildren() throws Exception {
        // A node built with the no-arg constructor has no matrix; treat it like an empty one.
        if (this.mat == null || this.mat.isEmpty()) {
            return;
        }
        if (this.unusedinds.isEmpty()) {
            this.ifLeaf = true;
            this.leafVal = computeLeaf();
            return;
        }
        if (isLeaf()) {
            return;
        }

        double bestGain = Double.NEGATIVE_INFINITY;
        NodeSplit bestSplit = null;
        for (int col = 0; col < this.mat.columnNumber(); col++) {
            if (!this.unusedinds.contains(Integer.valueOf(col))) {
                continue;
            }
            Vector column = this.mat.getColumn(col);
            for (ArrayList<NodeSplitCondition> conditions : columnVectorToConditions(column)) {
                NodeSplit candidate = new NodeSplit(column, conditions, this.outs);
                double gain = candidate.informationGain().doubleValue();
                if (gain > bestGain) {
                    bestGain = gain;
                    this.ind = col;
                    bestSplit = candidate;
                }
            }
        }

        // No split was selectable: there were no candidate conditions for any unused column, or
        // every gain was NaN (NaN never compares greater). Fall back to a majority-vote leaf
        // instead of dereferencing a null split below.
        if (bestSplit == null) {
            this.ifLeaf = true;
            this.leafVal = computeLeaf();
            return;
        }

        if (!isRoot() && bestGain == 0.0d) {
            this.ifLeaf = true;
            this.leafVal = bestSplit.mostFrequentOutput();
            return;
        }

        HashMap<Double, ArrayList<Integer>> rowsByClass = bestSplit.classSplitIndeces();
        ArrayList<Double> classKeys = new ArrayList<>(rowsByClass.keySet());
        // Sorted so that children are always created in the same (ascending key) order.
        Collections.sort(classKeys);
        for (Double key : classKeys) {
            TreeNode child = initChild(rowsByClass.get(key));
            child.setCondition(bestSplit.clasCond.get(key));
            this.children.add(child);
        }

        // Every leaf below predicts the same value, so this node can answer directly.
        HashSet<Double> leafValues = outputSet();
        if (leafValues.size() == 1) {
            this.ifLeaf = true;
            this.leafVal = leafValues.iterator().next().doubleValue();
        }
    }

    /**
     * Collects the distinct leaf values reachable from this node.
     *
     * @return a new set with the values of all leaves in this subtree
     */
    public HashSet<Double> outputSet() {
        HashSet<Double> values = new HashSet<>();
        appendOutputSet(values);
        return values;
    }

    /**
     * Adds the leaf values of this subtree to {@code hashSet}. A leaf contributes its own value
     * and does not descend further; a non-leaf delegates to its children.
     *
     * @param hashSet accumulator that receives the values
     */
    protected void appendOutputSet(HashSet<Double> hashSet) {
        if (isLeaf()) {
            hashSet.add(Double.valueOf(this.leafVal));
            return;
        }
        for (TreeNode child : this.children) {
            child.appendOutputSet(hashSet);
        }
    }

    /**
     * Returns the entropy of this node's outputs as computed by {@code NodeSplit.entropy}.
     *
     * @return entropy of {@link #outs}
     */
    public double entropy() {
        return NodeSplit.entropy(this.outs);
    }

    /**
     * Builds the child that receives the given rows and recursively grows it.
     *
     * <p>The child may split on every column this node could, except the one chosen here
     * ({@link #ind}). A child whose outputs have zero entropy becomes a leaf immediately.
     *
     * @param arrayList indices of the rows of {@link #mat} routed to the child
     * @return the fully grown child, with its parent set to this node
     * @throws Exception if the child cannot be created or grown
     */
    // Raw lists are kept on purpose: the element types accepted by the Matrix and Vector
    // constructors are declared outside this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    protected TreeNode initChild(ArrayList<Integer> arrayList) throws Exception {
        ArrayList childRows = new ArrayList();
        ArrayList childOuts = new ArrayList();
        for (Integer rowIndex : arrayList) {
            childRows.add(this.mat.getRow(rowIndex.intValue()));
            childOuts.add(this.outs.get(rowIndex.intValue()));
        }

        // Copy so that siblings do not share (and mutate) the same set.
        HashSet<Integer> remaining = new HashSet<>(this.unusedinds);
        remaining.remove(Integer.valueOf(this.ind));

        TreeNode child = childNode(new Matrix(childRows), new Vector(childOuts), remaining);
        child.parent = this;
        if (child.entropy() == 0.0d) {
            child.ifLeaf = true;
            child.leafVal = child.computeLeaf();
        } else {
            child.computeChildren();
        }
        return child;
    }

    /**
     * Returns the most frequent value among this node's outputs.
     *
     * @return the majority output value
     * @throws Exception if the frequency count fails
     */
    protected double computeLeaf() throws Exception {
        return ((Double) new FrequencyCount(this.outs).mostFrequentKey()).doubleValue();
    }

    /**
     * Sets the condition attached to this node.
     *
     * @param nodeSplitCondition the condition to store
     */
    public void setCondition(NodeSplitCondition nodeSplitCondition) {
        this.cond = nodeSplitCondition;
    }

    /**
     * Sets the parent of this node.
     *
     * @param treeNode the new parent, or {@code null} to make this node a root
     */
    public void setParent(TreeNode treeNode) {
        this.parent = treeNode;
    }

    /**
     * @return {@code true} if this node has been flagged as a leaf
     */
    public boolean isLeaf() {
        return this.ifLeaf;
    }

    /**
     * @return {@code true} if this node has no parent
     */
    public boolean isRoot() {
        return this.parent == null;
    }

    /**
     * @return the leaf value; {@link Double#NEGATIVE_INFINITY} if none has been assigned
     */
    public double leafValue() {
        return this.leafVal;
    }
}
