package fork.lib.math.applied.learning.classifier.neural;

import fork.lib.math.algebra.advanced.linearalgebra.Vector;
import fork.lib.math.algebra.elementary.function.v1.FunctionV1;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:fork/lib/math/applied/learning/classifier/neural/Neuron.class */
/**
 * A single neuron: a weight vector (inherited from {@link Vector}) plus a bias term and an
 * activation function, with helpers used during backpropagation.
 *
 * <p>{@link #evaluate(Vector)} and both {@code computeDerivative} overloads store intermediate
 * results in instance fields without synchronization, so one instance should not be used from
 * several threads concurrently without external locking.
 */
public class Neuron extends Vector {
    /** Sentinel meaning "no bias supplied"; {@link #init()} replaces it with the default. */
    private static final double UNSET_BIAS = Double.NEGATIVE_INFINITY;

    /** Activation function applied to the weighted sum plus bias. */
    protected FunctionV1 func;
    /** Bias added to the inner product before activation. */
    protected double bias;
    /** Pre-activation value (inner product + bias) from the last {@link #evaluate(Vector)}. */
    protected double valBefore;
    /** Post-activation output from the last {@link #evaluate(Vector)}. */
    protected double val;
    /** Error term from the last {@code computeDerivative} call. */
    protected double deriv;

    /**
     * Creates a neuron with the given weights, bias and activation function.
     *
     * @param list       initial weights
     * @param d          bias; {@code null} selects the default bias (1.0) via {@link #init()}
     * @param functionV1 activation function; {@code null} selects
     *                   {@link #defaultActivationFunction()} via {@link #init()}
     * @throws Exception propagated from the {@link Vector} constructor or {@link #init()}
     */
    public Neuron(List<Double> list, Double d, FunctionV1 functionV1) throws Exception {
        super(list);
        this.func = functionV1;
        // Previously d.doubleValue() threw NPE for null; keep the sentinel instead so init()
        // applies the default, matching how a null activation function is handled.
        this.bias = (d == null) ? UNSET_BIAS : d.doubleValue();
        init();
    }

    /**
     * Creates a neuron with the given weights, a bias of 0.0 and the default activation.
     *
     * @param dArr initial weights
     * @throws Exception propagated from the main constructor
     */
    public Neuron(Double... dArr) throws Exception {
        this(Arrays.asList(dArr), Double.valueOf(0.0d), null);
    }

    /**
     * Creates a neuron whose initial weights come from {@code Vector.sequence(0.0, 1.0, i)}, with a
     * bias of 0.0 and the default activation.
     *
     * <p>NOTE(review): presumably this yields {@code i} weights starting at 0.0 with step 1.0 —
     * confirm against {@code Vector.sequence}.
     *
     * @param i argument forwarded to {@code Vector.sequence} (presumably the number of weights)
     * @throws Exception propagated from {@code Vector.sequence} or the main constructor
     */
    public Neuron(int i) throws Exception {
        this(Vector.sequence(Double.valueOf(0.0d), Double.valueOf(1.0d), Integer.valueOf(i)), Double.valueOf(0.0d), null);
    }

    /**
     * Fills in defaults for any settings the constructor left unset: a bias of 1.0 when the bias
     * is still the unset sentinel, and {@link #defaultActivationFunction()} when no activation
     * function was given.
     *
     * <p>Called from the constructor; overriding implementations run before subclass fields are
     * initialized.
     *
     * @throws Exception propagated from {@link #defaultActivationFunction()}
     */
    protected void init() throws Exception {
        if (this.bias == UNSET_BIAS) {
            this.bias = 1.0d;
        }
        if (this.func == null) {
            this.func = defaultActivationFunction();
        }
    }

    /**
     * Returns a new logistic sigmoid activation, {@code 1 / (1 + e^-x)}.
     *
     * @return the logistic sigmoid as a {@link FunctionV1}
     * @throws Exception declared for compatibility with callers; nothing here throws it directly
     */
    public static FunctionV1 defaultActivationFunction() throws Exception {
        return new FunctionV1() {
            @Override
            public double getY(double d) {
                return 1.0d / (1.0d + Math.exp(-d));
            }
        };
    }

    /**
     * Computes this neuron's output for an input vector and caches both the pre- and
     * post-activation values.
     *
     * @param vector input vector (presumably the same length as the weight vector)
     * @return {@code func(weights · vector + bias)}
     * @throws Exception propagated from {@code innerProduct}
     */
    public double evaluate(Vector vector) throws Exception {
        this.valBefore = innerProduct(vector) + this.bias;
        this.val = this.func.getY(this.valBefore);
        return this.val;
    }

    /** @return the pre-activation value from the last {@link #evaluate(Vector)} call */
    public double getInactivatedValue() {
        return this.valBefore;
    }

    /** @return the post-activation output from the last {@link #evaluate(Vector)} call */
    public double getActivatedValue() {
        return this.val;
    }

    /** @return the error term from the last {@code computeDerivative} call */
    public double getDerivative() {
        return this.deriv;
    }

    /**
     * Computes and stores the error term for an output neuron:
     * {@code (val - d) * val * (1 - val)}.
     *
     * <p>NOTE(review): {@code val * (1 - val)} is the logistic-sigmoid derivative written in
     * terms of its output. It is only correct when {@link #func} is the default sigmoid; a custom
     * activation function will get an incorrect gradient here.
     *
     * @param d presumably the target output for this neuron
     * @return the stored error term
     */
    public double computeDerivative(double d) {
        this.deriv = (this.val - d) * this.val * (1.0d - this.val);
        return this.deriv;
    }

    /**
     * Computes and stores the error term for a hidden neuron by back-propagating from the
     * following layer: the sum over each neuron {@code n} in {@code neuronLayer} of
     * {@code n.weight[i] * n.getDerivative() * val * (1 - val)}.
     *
     * <p>The neurons in {@code neuronLayer} must already have their derivatives computed. The
     * same sigmoid-only caveat as {@link #computeDerivative(double)} applies.
     *
     * @param neuronLayer the next layer (the one this neuron feeds into)
     * @param i           index of the weight, in each next-layer neuron, that connects to this
     *                    neuron (presumably this neuron's position in its own layer)
     * @return the stored error term
     */
    public double computeDerivative(NeuronLayer neuronLayer, int i) {
        final double out = this.val;
        double sum = 0.0d;
        Iterator<Neuron> it = neuronLayer.iterator();
        while (it.hasNext()) {
            Neuron next = it.next();
            // Same left-to-right evaluation order as before so results stay bit-identical.
            sum += next.get(i).doubleValue() * next.getDerivative() * out * (1.0d - out);
        }
        this.deriv = sum;
        return this.deriv;
    }

    /**
     * Applies a gradient-descent step using the stored error term:
     * {@code w[k] -= d * deriv * vector[k]} for every weight, and {@code bias -= d * deriv}.
     *
     * @param vector the input this neuron was evaluated on; must have at least {@code size()}
     *               elements
     * @param d      step size (presumably the learning rate)
     */
    public void updateWeights(Vector vector, double d) {
        // Same product as the original (d * deriv), computed once instead of per weight.
        final double step = d * this.deriv;
        for (int k = 0; k < size(); k++) {
            set(k, Double.valueOf(get(k).doubleValue() - (step * vector.get(k).doubleValue())));
        }
        this.bias -= step;
    }

    /**
     * Entry point that delegates to {@code NeuronNetwork.main}.
     *
     * @param strArr command-line arguments, forwarded unchanged
     * @throws Exception propagated from {@code NeuronNetwork.main}
     */
    public static void main(String[] strArr) throws Exception {
        NeuronNetwork.main(strArr);
    }
}
