package fork.lib.math.applied.learning.classifier.neural;

import fork.lib.math.algebra.advanced.linearalgebra.Vector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/**
 * A feed-forward neural network represented as an ordered list of {@link NeuronLayer}s, where
 * each layer's output is fed as input to the next layer.
 *
 * <p>A training step (see {@link #main}) is: {@link #evaluate(Vector)} to run the forward pass,
 * {@link #computeError(Vector)} to record the target and per-output error, then
 * {@link #backTrack()} to compute derivatives and update weights. The most recent input, output,
 * target and error are kept as mutable, unsynchronized state on the instance, so a network should
 * not be shared between threads.
 */
public class NeuronNetwork extends ArrayList<NeuronLayer> {
    /** Learning rate used until {@link #setLearningRate(double)} is called. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    /** Output of the most recent {@link #evaluate(Vector)}, or {@code null}. */
    protected Vector value;
    /** Per-output errors from the most recent {@link #computeError(Vector)}, or {@code null}. */
    protected Vector err;
    /** Input of the most recent {@link #evaluate(Vector)}, or {@code null}. */
    protected Vector input;
    /** Target passed to the most recent {@link #computeError(Vector)}, or {@code null}. */
    protected Vector target;
    /** Step size passed to every neuron's weight update in {@link #backTrack()}. */
    protected double rate = DEFAULT_LEARNING_RATE;

    /**
     * Creates a network from the given layers, in input-to-output order.
     *
     * @param list the layers; they are copied into this network
     * @throws Exception declared for compatibility with existing callers
     */
    public NeuronNetwork(List<NeuronLayer> list) throws Exception {
        // ArrayList's copy constructor avoids calling the overridable addAll() from a constructor.
        super(list);
    }

    /**
     * Creates a network from the given layers, in input-to-output order.
     *
     * @param neuronLayerArr the layers
     * @throws Exception declared for compatibility with existing callers
     */
    public NeuronNetwork(NeuronLayer... neuronLayerArr) throws Exception {
        this(Arrays.asList(neuronLayerArr));
    }

    /**
     * Creates a network with one new layer per entry of {@code list}.
     *
     * <p>Each layer is built as {@code new NeuronLayer(previousWidth, layerWidth)}, where
     * {@code previousWidth} starts at {@code i} and then becomes the previous entry of
     * {@code list}. Presumably these are (input width, neuron count); confirm against
     * {@link NeuronLayer}.
     *
     * @param i the width of the network input
     * @param list the width of each layer, in input-to-output order
     * @throws Exception if a layer cannot be constructed
     */
    public NeuronNetwork(int i, List<Integer> list) throws Exception {
        int previousWidth = i;
        for (Integer width : list) {
            int layerWidth = width.intValue();
            // super.add: avoid dispatching to an overridable method during construction.
            super.add(new NeuronLayer(previousWidth, layerWidth));
            previousWidth = layerWidth;
        }
    }

    /**
     * Runs a forward pass: {@code vector} goes into the first layer and each layer's output
     * goes into the next. The input and the result are remembered for {@link #backTrack()} and
     * {@link #getValue()}.
     *
     * @param vector the network input
     * @return the last layer's output, or {@code null} if the network has no layers
     * @throws Exception propagated from {@link NeuronLayer} evaluation
     */
    public Vector evaluate(Vector vector) throws Exception {
        this.input = vector;
        this.value = null;
        Vector current = vector;
        for (NeuronLayer layer : this) {
            current = layer.evaluate(current);
        }
        this.value = isEmpty() ? null : current;
        return this.value;
    }

    /**
     * Records {@code vector} as the target for the last forward pass and computes the per-output
     * error {@code 0.5 * (output - target)^2}.
     *
     * @param vector the expected output; must have as many entries as the last output
     * @return the per-output errors (also kept for {@link #getError()})
     * @throws IllegalStateException if {@link #evaluate(Vector)} has not produced an output yet
     * @throws IllegalArgumentException if {@code vector} and the output differ in length
     * @throws Exception declared for compatibility with existing callers
     */
    public Vector computeError(Vector vector) throws Exception {
        if (this.value == null) {
            throw new IllegalStateException("evaluate() must produce an output before computeError()");
        }
        if (vector.size() != this.value.size()) {
            throw new IllegalArgumentException(
                    "target has " + vector.size() + " entries but the network output has " + this.value.size());
        }
        this.target = vector;
        this.err = new Vector();
        for (int i = 0; i < vector.size(); i++) {
            double diff = this.value.get(i).doubleValue() - vector.get(i).doubleValue();
            this.err.add(Double.valueOf(0.5d * Math.pow(diff, 2.0d)));
        }
        return this.err;
    }

    /**
     * Sets the step size passed to each neuron's weight update.
     *
     * @param d the new learning rate
     */
    public void setLearningRate(double d) {
        this.rate = d;
    }

    /**
     * Performs one backpropagation step using the last input and the last target.
     *
     * <p>First, derivatives are computed from the output layer backwards. The output layer uses
     * the target values; each hidden layer is given the layer after it. Then the weights of every
     * layer are updated front-to-back. Layer 0 receives the network input, and later layers
     * receive the previous layer's activated values.
     *
     * @throws IllegalStateException if {@link #evaluate(Vector)} and {@link #computeError(Vector)}
     *     have not both been called
     * @throws Exception propagated from the neuron derivative/update routines
     */
    public void backTrack() throws Exception {
        if (this.input == null || this.target == null) {
            throw new IllegalStateException("evaluate() and computeError() must be called before backTrack()");
        }
        int lastLayer = size() - 1;
        // All derivatives are computed before any weight changes. Hidden-layer derivatives are
        // given the next layer, which presumably must still hold its pre-update weights.
        for (int layerIndex = lastLayer; layerIndex >= 0; layerIndex--) {
            NeuronLayer layer = get(layerIndex);
            for (int n = 0; n < layer.size(); n++) {
                if (layerIndex == lastLayer) {
                    layer.get(n).computeDerivative(this.target.get(n).doubleValue());
                } else {
                    layer.get(n).computeDerivative(get(layerIndex + 1), n);
                }
            }
        }
        for (int layerIndex = 0; layerIndex < size(); layerIndex++) {
            NeuronLayer layer = get(layerIndex);
            Vector layerInput = layerIndex == 0 ? this.input : get(layerIndex - 1).getActivatedValues();
            for (int n = 0; n < layer.size(); n++) {
                layer.get(n).updateWeights(layerInput, this.rate);
            }
        }
    }

    /**
     * Returns the output of the most recent forward pass.
     *
     * @return the last output, or {@code null} if none has been computed
     */
    public Vector getValue() {
        return this.value;
    }

    /**
     * Returns the total error: the sum of the per-output errors from the last
     * {@link #computeError(Vector)}.
     *
     * @return the summed error
     * @throws IllegalStateException if {@link #computeError(Vector)} has not been called
     */
    public double getError() {
        if (this.err == null) {
            throw new IllegalStateException("computeError() must be called before getError()");
        }
        double total = 0.0d;
        for (int i = 0; i < this.err.size(); i++) {
            total += this.err.get(i).doubleValue();
        }
        return total;
    }

    /**
     * Returns the input of the most recent forward pass.
     *
     * @return the last input, or {@code null} if {@link #evaluate(Vector)} has not been called
     */
    public Vector getInput() {
        return this.input;
    }

    /**
     * Returns the step size passed to each neuron's weight update.
     *
     * @return the learning rate
     */
    public double getLearningRate() {
        return this.rate;
    }

    /** Prints a blank line, then every layer's neurons and their biases, to standard output. */
    public void print() {
        System.out.println();
        for (int i = 0; i < size(); i++) {
            NeuronLayer layer = get(i);
            System.out.println("layer " + i + ":");
            for (int n = 0; n < layer.size(); n++) {
                System.out.println("   " + layer.get(n) + "  " + layer.get(n).bias);
            }
        }
    }

    /**
     * Demo: trains a 2-3-4-1 network on the XOR truth table and prints the results.
     *
     * @param strArr ignored
     * @throws Exception propagated from network construction or training
     */
    public static void main(String[] strArr) throws Exception {
        NeuronNetwork network = new NeuronNetwork(2, Arrays.asList(3, 4, 1));
        network.setLearningRate(5.0d);
        for (int i = 0; i < network.size(); i++) {
            NeuronLayer layer = network.get(i);
            System.out.println("layer " + i + ":");
            for (int n = 0; n < layer.size(); n++) {
                System.out.println("   " + layer.get(n));
            }
        }
        // XOR: output is 1 exactly when the two inputs differ.
        Vector[] inputs = {new Vector(0.0d, 0.0d), new Vector(0.0d, 1.0d), new Vector(1.0d, 0.0d), new Vector(1.0d, 1.0d)};
        Vector[] targets = {new Vector(0.0d), new Vector(1.0d), new Vector(1.0d), new Vector(0.0d)};
        for (int epoch = 0; epoch < 10000; epoch++) {
            for (int k = 0; k < inputs.length; k++) {
                network.evaluate(inputs[k]);
                network.computeError(targets[k]);
                network.backTrack();
            }
        }
        network.print();
        System.out.println();
        for (Vector sample : inputs) {
            System.out.println(sample + "   " + network.evaluate(sample));
        }
    }
}
