package com.github.psambit9791.jdsp.transform;

import com.github.psambit9791.jdsp.misc.UtilMethods;
import java.lang.reflect.Array;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.util.MathArrays;

/* loaded from: classes2.dex */
/**
 * Principal Component Analysis of a multichannel signal, computed from the singular value
 * decomposition of the zero-centred data matrix.
 *
 * <p>Typical use: construct with the signal, call {@link #fit()}, then project with
 * {@link #transform()} (the fitted signal) or {@link #transform(double[][])} (new data with the
 * same number of channels).
 */
public class PCA {
    /** Diagonal matrix of all singular values from the SVD (not truncated to n_components). */
    private double[][] S;
    /** Left singular vectors of the zero-centred signal after sign normalisation (not truncated). */
    private double[][] U;
    /** V transposed (one principal axis per row) after sign normalisation (not truncated). */
    private double[][] V;
    /** Variance along each kept component: {@code s^2 / (n_samples - 1)}. Set by {@link #fit()}. */
    public double[] explained_variance_;
    /** Share of the total variance (summed over all components) carried by each kept component. */
    public double[] explained_variance_ratio_;
    /** Per-channel mean removed before the SVD; reused by {@link #transform(double[][])}. */
    private double[] mean_;
    /** Number of principal components to keep. */
    private int n_components;
    /** Number of samples (rows) in the fitted signal. */
    private int n_samples;
    /** Result of the most recent {@link #transform()} call. */
    private double[][] output;
    /** Input signal, samples x channels. The caller's array is referenced, not copied. */
    private double[][] signal;
    /** Singular values of the kept components; {@code null} until {@link #fit()} has completed. */
    public double[] singular_values_;
    /** Zero-centred signal, samples x channels. */
    private double[][] zm_signal;

    /**
     * Creates a PCA model for a multichannel signal.
     *
     * @param dArr the signal, one row per sample and one column per channel; the array is kept by
     *     reference, so mutating it before {@link #fit()} changes the result
     * @param i the number of principal components to keep, in the inclusive range
     *     {@code [1, channels]}
     * @throws ExceptionInInitializerError if there are fewer samples than channels, or {@code i}
     *     is outside {@code [1, channels]}
     * @throws IllegalArgumentException if the signal is null, has no samples, or its rows do not
     *     all have the same number of channels
     */
    public PCA(double[][] dArr, int i) throws ExceptionInInitializerError, IllegalArgumentException {
        if (dArr == null || dArr.length == 0 || dArr[0] == null) {
            throw new IllegalArgumentException("Signal must contain at least one sample");
        }
        int nChannels = dArr[0].length;
        if (dArr.length < nChannels) {
            throw new ExceptionInInitializerError("Signal length must be more than number of channels");
        }
        if (i > nChannels || i <= 0) {
            throw new ExceptionInInitializerError("n_components must be greater than 0 and less than total channels in signal");
        }
        // Ragged rows would otherwise be silently truncated or crash inside transpose().
        for (double[] row : dArr) {
            if (row == null || row.length != nChannels) {
                throw new IllegalArgumentException("All samples must have the same number of channels");
            }
        }
        this.signal = dArr;
        this.n_samples = dArr.length;
        this.n_components = i;
    }

    /**
     * Makes the SVD output deterministic by fixing the sign of every singular-vector pair.
     *
     * <p>For each column of U, the entry with the largest absolute value is made positive. The
     * matching row of V^T is multiplied by the same sign, so the product U * S * V^T is unchanged.
     *
     * @param dArr U, modified in place
     * @param dArr2 V^T, modified in place; must have no more rows than U has columns
     * @return {@code {U, V^T}} after sign correction
     */
    private double[][][] svdFlip(double[][] dArr, double[][] dArr2) {
        double[][] absU = UtilMethods.absoluteArray(dArr);
        int nCols = dArr[0].length;
        double[] signs = new double[nCols];
        for (int col = 0; col < nCols; col++) {
            double[] column = new double[absU.length];
            for (int row = 0; row < absU.length; row++) {
                column[row] = absU[row][col];
            }
            double sign = Math.signum(dArr[UtilMethods.argmax(column, false)][col]);
            // An all-zero column gives signum 0 (or -0.0). Multiplying by it would wipe out the
            // matching principal axis in V^T, so keep the original orientation instead.
            signs[col] = (sign == 0.0) ? 1.0 : sign;
        }
        for (int row = 0; row < dArr.length; row++) {
            dArr[row] = MathArrays.ebeMultiply(dArr[row], signs);
        }
        // Row k of V^T pairs with column k of U.
        for (int k = 0; k < dArr2.length; k++) {
            double[] axis = dArr2[k];
            for (int c = 0; c < axis.length; c++) {
                axis[c] *= signs[k];
            }
        }
        return new double[][][]{dArr, dArr2};
    }

    /**
     * Fits the model.
     *
     * <p>Steps:
     * <ol>
     *   <li>Removes each channel's mean.</li>
     *   <li>Takes the SVD of the centred samples x channels matrix.</li>
     *   <li>Sign-normalises the singular vectors.</li>
     *   <li>Derives per-component variances.</li>
     * </ol>
     *
     * <p>The public result arrays are truncated to {@code n_components}. The explained variance
     * is {@code s^2 / (n_samples - 1)}; the ratio is taken against the variance summed over all
     * components. Either may be NaN for degenerate input: a single sample, or zero total variance.
     */
    public void fit() {
        // One row per channel, so each channel can be averaged and centred independently.
        double[][] channels = UtilMethods.transpose(this.signal);
        double[] means = new double[channels.length];
        double[][] centredChannels = new double[channels.length][];
        for (int ch = 0; ch < channels.length; ch++) {
            means[ch] = StatUtils.mean(channels[ch]);
            centredChannels[ch] = UtilMethods.zeroCenter(channels[ch]);
        }
        double[][] centred = UtilMethods.transpose(centredChannels);

        SingularValueDecomposition svd =
                new SingularValueDecomposition(MatrixUtils.createRealMatrix(centred));
        double[][] sigma = svd.getS().getData();
        double[][][] flipped = svdFlip(svd.getU().getData(), svd.getVT().getData());
        double[] singularValues = svd.getSingularValues();

        double[] variance = new double[singularValues.length];
        for (int k = 0; k < singularValues.length; k++) {
            variance[k] = (singularValues[k] * singularValues[k]) / (this.n_samples - 1);
        }
        double total = StatUtils.sum(variance);
        double[] ratio = new double[variance.length];
        for (int k = 0; k < variance.length; k++) {
            ratio[k] = variance[k] / total;
        }

        this.mean_ = means;
        this.zm_signal = centred;
        this.U = flipped[0];
        this.S = sigma;
        this.V = flipped[1];
        this.explained_variance_ = UtilMethods.splitByIndex(variance, 0, this.n_components);
        this.explained_variance_ratio_ = UtilMethods.splitByIndex(ratio, 0, this.n_components);
        // Assigned last: a non-null singular_values_ is what marks the model as fitted.
        this.singular_values_ = UtilMethods.splitByIndex(singularValues, 0, this.n_components);
    }

    /**
     * Returns the raw decomposition of the zero-centred signal.
     *
     * <p>The matrices are sign-normalised but not truncated to {@code n_components}. The third
     * element is V <em>transposed</em> (principal axes as rows). The returned arrays are the
     * model's internal state, not copies.
     *
     * @return {@code {U, S, V^T}}
     * @throws ExceptionInInitializerError if {@link #fit()} has not been called
     */
    public double[][][] getUSV() throws ExceptionInInitializerError {
        requireFitted();
        return new double[][][]{this.U, this.S, this.V};
    }

    /**
     * Projects the fitted signal onto the kept principal components.
     *
     * @return a samples x n_components matrix
     * @throws ExceptionInInitializerError if {@link #fit()} has not been called
     */
    public double[][] transform() throws ExceptionInInitializerError {
        requireFitted();
        this.output = UtilMethods.matrixMultiply(this.zm_signal, componentMatrix());
        return this.output;
    }

    /**
     * Projects new data onto the kept principal components.
     *
     * <p>The data is first centred with the per-channel means learned by {@link #fit()}.
     *
     * @param dArr new data, samples x channels, with the same channel count as the fitted signal
     * @return a samples x n_components matrix
     * @throws ExceptionInInitializerError if {@link #fit()} has not been called
     * @throws ArithmeticException if any row's channel count differs from the fitted signal's
     * @throws IllegalArgumentException if {@code dArr} is null or has no samples
     */
    public double[][] transform(double[][] dArr) throws ExceptionInInitializerError, ArithmeticException {
        requireFitted();
        if (dArr == null || dArr.length == 0) {
            throw new IllegalArgumentException("Signal must contain at least one sample");
        }
        int nChannels = this.signal[0].length;
        // Check every row, not just the first: transpose() may silently truncate ragged input.
        for (double[] row : dArr) {
            if (row == null || row.length != nChannels) {
                throw new ArithmeticException("Number of channels has to be same as original signal");
            }
        }
        double[][] channels = UtilMethods.transpose(dArr);
        double[][] centredChannels = new double[channels.length][];
        for (int ch = 0; ch < channels.length; ch++) {
            centredChannels[ch] = UtilMethods.scalarArithmetic(channels[ch], this.mean_[ch], "sub");
        }
        return UtilMethods.matrixMultiply(UtilMethods.transpose(centredChannels), componentMatrix());
    }

    /** Throws if {@link #fit()} has not completed yet. */
    private void requireFitted() {
        if (this.singular_values_ == null) {
            throw new ExceptionInInitializerError("Execute fit() before calling this function");
        }
    }

    /** Returns the first n_components rows of V^T, transposed into a channels x n_components matrix. */
    private double[][] componentMatrix() {
        double[][] components = new double[this.n_components][];
        System.arraycopy(this.V, 0, components, 0, this.n_components);
        return UtilMethods.transpose(components);
    }
}
