package javastat.regression.glm;

import JSci.maths.statistics.NormalDistribution;
import Jama.Matrix;
import java.util.Hashtable;
import javastat.StatisticalAnalysis;
import javastat.util.Argument;
import javastat.util.BasicStatistics;
import javastat.util.DataManager;
import javastat.util.GLMDataManager;
import javastat.util.Output;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:javastat/regression/glm/GLMTemplate.class */
/**
 * Template for generalized linear model fits.
 *
 * <p>The coefficient estimates are obtained by an iteratively reweighted least squares
 * loop (see {@link #coefficients(Object[])}); concrete subclasses supply the model
 * specific pieces through {@link #weights(double[], double[][])} and
 * {@link #means(double[], double[][])}.
 *
 * <p>Data layout: {@code covariate} is indexed {@code [variable][observation]}, i.e. every
 * row holds one covariate and must be as long as the response vector.
 *
 * <p>Most methods store their result both in the matching public field and in the
 * inherited {@code output} table. The instance is mutated by nearly every call, and
 * {@link #coefficients(Object[])} additionally overwrites the static
 * {@code BasicStatistics.convergenceCriterion}.
 */
public abstract class GLMTemplate extends StatisticalAnalysis {
    /** Observed responses, one entry per observation. */
    public double[] response;
    /** Covariates, indexed {@code [variable][observation]}. */
    public double[][] covariate;
    /** Current / final coefficient estimates. */
    public double[] coefficients;
    /** Standard errors: square roots of the diagonal of {@link #variance}. */
    public double[] coefficientSE;
    /** Level of significance used for the confidence intervals. */
    public double alpha;
    /** Confidence limits per coefficient: {@code [i][0]} lower, {@code [i][1]} upper. */
    public double[][] confidenceInterval;
    /** Coefficient divided by its standard error, per coefficient. */
    public double[] testStatistic;
    /** Two-sided normal p-values of {@link #testStatistic}. */
    public double[] pValue;
    /** Linear predictors (transposed covariates times coefficients). */
    public double[] linearPredictors;
    /** Correlation matrix derived from {@link #variance}. */
    public double[][] correlation;
    /** Inverse of X W X', used as the coefficient covariance matrix. */
    public double[][] variance;
    /** Not assigned in this class; presumably filled in by subclasses. */
    public double[] fittedValues;
    /** Not assigned in this class; presumably filled in by subclasses. */
    public double[] devianceResiduals;
    /** Variance-scaled response residuals. */
    public double[] pearsonResiduals;
    /** Response minus fitted mean. */
    public double[] responseResiduals;
    /** Weight matrix returned by {@link #weights(double[], double[][])}. */
    public double[][] weights;
    /** Fitted means. */
    public double[] means;
    /** Values of the variance function, see {@link #responseVariance(double[], ExponentialFamily)}. */
    public double[] responseVariance;
    /** Not assigned in this class; presumably filled in by subclasses. */
    public double[][] devianceTable;
    /** Not assigned in this class; presumably filled in by subclasses. */
    public LinkFunction link;
    private double[][] xwx;
    private double[][] inversedWeights;
    private double error;
    private Matrix responseMatrix;
    private Matrix covariateMatrix;
    private Matrix coefficientMatrix;
    private Matrix linearPredictorMatrix;
    private Matrix weightMatrix;
    private Matrix zMatrix;
    private Matrix xwxMatrix;
    private Matrix updatedCoefficientMatrix;
    private NormalDistribution normalDistribution;
    private double zAlpha;

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the estimated coefficients
     */
    protected double[] coefficients(Hashtable hashtable, Object[] objArr) {
        return coefficients(objArr);
    }

    /**
     * Estimates the coefficients by iteratively reweighted least squares.
     *
     * <p>Each pass builds the working response {@code z = eta + W^-1 (y - mu)} and solves
     * {@code (X W X') b = X W z}. Iteration stops once the squared Frobenius norm of the
     * coefficient change is no larger than the convergence criterion, which this method
     * resets to {@code 1e-6}. There is no iteration cap.
     *
     * @param objArr {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the estimated coefficients (also stored under {@code Output.COEFFICIENTS})
     * @throws IllegalArgumentException if the response length differs from the length of the
     *                                  first covariate row
     */
    public double[] coefficients(Object[] objArr) {
        this.response = (double[]) objArr[0];
        this.covariate = (double[][]) objArr[1];
        // NOTE: this overwrites a static shared by every user of BasicStatistics.
        BasicStatistics.convergenceCriterion = new double[]{1.0E-6d};
        new DataManager().checkDimension(this.covariate);
        if (this.response.length != this.covariate[0].length) {
            throw new IllegalArgumentException("The response vector and rows of the covariate matrix must have the same length.");
        }
        this.coefficients = new GLMDataManager().setInitialEstimate(this.covariate);
        this.covariateMatrix = new Matrix(this.covariate);
        this.responseMatrix = new Matrix(this.response, this.response.length);
        this.error = 1.0d;
        while (this.error > BasicStatistics.convergenceCriterion[0]) {
            this.coefficientMatrix = new Matrix(this.coefficients, this.coefficients.length);
            this.weights = weights(this.coefficients, this.covariate);
            this.inversedWeights = new double[this.weights.length][this.weights.length];
            this.linearPredictors = linearPredictors(this.covariate, this.coefficients);
            this.means = means(this.coefficients, this.covariate);
            this.weightMatrix = new Matrix(this.weights);
            // Only the diagonal is inverted; a zero weight maps to zero instead of infinity.
            for (int i = 0; i < this.weights.length; i++) {
                double weight = this.weights[i][i];
                this.inversedWeights[i][i] = (weight == 0.0d) ? 0.0d : 1.0d / weight;
            }
            this.linearPredictorMatrix = this.covariateMatrix.transpose().times(this.coefficientMatrix);
            Matrix residualMatrix = this.responseMatrix.minus(new Matrix(this.means, this.means.length));
            this.zMatrix = this.linearPredictorMatrix.plus(new Matrix(this.inversedWeights).times(residualMatrix));
            this.xwxMatrix = this.covariateMatrix.times(this.weightMatrix).times(this.covariateMatrix.transpose());
            if (Math.abs(this.xwxMatrix.det()) <= 1.0E-8d) {
                // (Near-)singular system: add 0.1 to the diagonal so that it can be inverted.
                this.xwx = this.xwxMatrix.getArray();
                for (int i = 0; i < this.xwx.length; i++) {
                    this.xwx[i][i] += 0.1d;
                }
                this.xwxMatrix = new Matrix(this.xwx);
            }
            this.updatedCoefficientMatrix = this.xwxMatrix.inverse()
                    .times(this.covariateMatrix.times(this.weightMatrix))
                    .times(this.zMatrix);
            this.coefficients = this.updatedCoefficientMatrix.getColumnPackedCopy();
            this.error = Math.pow(this.updatedCoefficientMatrix.minus(this.coefficientMatrix).normF(), 2.0d);
        }
        this.output.put(Output.COEFFICIENTS, this.coefficients);
        return this.coefficients;
    }

    /**
     * Table-driven entry point that accepts either argument order.
     *
     * @param hashtable unused
     * @param objArr    either {response, covariate} (the model is fitted first) or
     *                  {covariate, coefficients} (plain matrix product)
     * @return the linear predictors
     */
    public double[] linearPredictors(Hashtable hashtable, Object[] objArr) {
        if (objArr[1] instanceof double[][]) {
            return linearPredictors((double[]) objArr[0], (double[][]) objArr[1]);
        }
        return linearPredictors((double[][]) objArr[0], (double[]) objArr[1]);
    }

    /**
     * Fits the model and returns the linear predictors at the estimated coefficients.
     *
     * @param dArr  response vector
     * @param dArr2 covariates, {@code [variable][observation]}
     * @return the linear predictors (also stored under {@code Output.LINEAR_PREDICTORS})
     */
    public double[] linearPredictors(double[] dArr, double[][] dArr2) {
        this.coefficients = coefficients(new Object[]{dArr, dArr2});
        this.coefficientMatrix = new Matrix(this.coefficients, this.coefficients.length);
        this.linearPredictors = this.covariateMatrix.transpose().times(this.coefficientMatrix).getColumnPackedCopy();
        this.output.put(Output.LINEAR_PREDICTORS, this.linearPredictors);
        return this.linearPredictors;
    }

    /**
     * Computes the linear predictors for given coefficients without fitting anything.
     * Unlike the fitting overload, the result is not written to {@code output}.
     *
     * @param dArr  covariates, {@code [variable][observation]}
     * @param dArr2 coefficients, one per covariate row
     * @return transposed covariates times coefficients
     */
    public double[] linearPredictors(double[][] dArr, double[] dArr2) {
        Matrix coefficientColumn = new Matrix(dArr2, dArr2.length);
        this.linearPredictors = new Matrix(dArr).transpose().times(coefficientColumn).getColumnPackedCopy();
        return this.linearPredictors;
    }

    /**
     * Computes normal-approximation confidence intervals for the coefficients.
     *
     * @param hashtable must map {@code Argument.ALPHA} to a {@code Double} in (0, 1]
     * @param objArr    {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return one {lower, upper} pair per coefficient (also stored under
     *         {@code Output.CONFIDENCE_INTERVAL})
     * @throws IllegalArgumentException if alpha is NaN or outside (0, 1]
     */
    public double[][] confidenceInterval(Hashtable hashtable, Object[] objArr) {
        this.alpha = ((Double) hashtable.get(Argument.ALPHA)).doubleValue();
        // Written as a negated range test so that NaN is rejected as well.
        if (!(this.alpha > 0.0d && this.alpha <= 1.0d)) {
            throw new IllegalArgumentException("The level of significance should be (strictly) positive and not greater than 1.");
        }
        this.response = (double[]) objArr[0];
        this.covariate = (double[][]) objArr[1];
        this.testStatistic = testStatistic(new Object[]{this.response, this.covariate});
        this.zAlpha = new NormalDistribution().inverse(1.0d - (this.alpha / 2.0d));
        this.confidenceInterval = new double[this.testStatistic.length][2];
        for (int i = 0; i < this.testStatistic.length; i++) {
            double halfWidth = this.zAlpha * Math.sqrt(this.variance[i][i]);
            this.confidenceInterval[i][0] = this.coefficients[i] - halfWidth;
            this.confidenceInterval[i][1] = this.coefficients[i] + halfWidth;
        }
        this.output.put(Output.CONFIDENCE_INTERVAL, this.confidenceInterval);
        return this.confidenceInterval;
    }

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the test statistics
     */
    protected double[] testStatistic(Hashtable hashtable, Object[] objArr) {
        return testStatistic(objArr);
    }

    /**
     * Fits the model and computes, per coefficient, the estimate divided by its standard error.
     * As a side effect {@link #variance}, {@link #coefficientSE} and {@link #correlation}
     * are refreshed.
     *
     * <p>NOTE(review): unlike the fitting loop, the inversion of X W X' here has no
     * safeguard for a (near-)singular matrix.
     *
     * @param objArr {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the test statistics (also stored under {@code Output.TEST_STATISTIC})
     */
    public double[] testStatistic(Object[] objArr) {
        this.response = (double[]) objArr[0];
        this.covariate = (double[][]) objArr[1];
        this.coefficients = coefficients(new Object[]{this.response, this.covariate});
        this.weights = weights(this.coefficients, this.covariate);
        this.weightMatrix = new Matrix(this.weights);
        this.xwxMatrix = this.covariateMatrix.times(this.weightMatrix).times(this.covariateMatrix.transpose());
        this.variance = this.xwxMatrix.inverse().getArray();

        int size = this.variance.length;
        this.coefficientSE = new double[size];
        for (int i = 0; i < size; i++) {
            this.coefficientSE[i] = Math.sqrt(this.variance[i][i]);
        }

        // Fill the upper triangle and mirror it into the lower one.
        this.correlation = new double[size][size];
        for (int row = 0; row < size; row++) {
            for (int col = row; col < size; col++) {
                double scale = Math.sqrt(this.variance[row][row] * this.variance[col][col]);
                this.correlation[row][col] = this.variance[row][col] / scale;
                this.correlation[col][row] = this.correlation[row][col];
            }
        }

        this.testStatistic = new double[this.coefficients.length];
        for (int i = 0; i < this.testStatistic.length; i++) {
            this.testStatistic[i] = this.coefficients[i] / this.coefficientSE[i];
        }
        this.output.put(Output.TEST_STATISTIC, this.testStatistic);
        return this.testStatistic;
    }

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the p-values
     */
    protected double[] pValue(Hashtable hashtable, Object[] objArr) {
        return pValue(objArr);
    }

    /**
     * Fits the model and computes two-sided p-values of the test statistics under the
     * standard normal distribution.
     *
     * @param objArr {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the p-values (also stored under {@code Output.PVALUE})
     */
    public double[] pValue(Object[] objArr) {
        this.response = (double[]) objArr[0];
        this.covariate = (double[][]) objArr[1];
        this.testStatistic = testStatistic(new Object[]{this.response, this.covariate});
        this.linearPredictors = linearPredictors(this.covariate, this.coefficients);
        this.normalDistribution = new NormalDistribution();
        this.pValue = new double[this.testStatistic.length];
        for (int i = 0; i < this.testStatistic.length; i++) {
            double upperTail = 1.0d - this.normalDistribution.cumulative(Math.abs(this.testStatistic[i]));
            this.pValue[i] = 2.0d * upperTail;
        }
        this.output.put(Output.PVALUE, this.pValue);
        return this.pValue;
    }

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} response ({@code double[]}), {@code [1]} covariate ({@code double[][]})
     * @return the response residuals
     */
    protected double[] responseResiduals(Hashtable hashtable, Object[] objArr) {
        return responseResiduals((double[]) objArr[0], (double[][]) objArr[1]);
    }

    /**
     * Fits the model and returns response minus fitted mean for every observation.
     *
     * @param dArr  response vector
     * @param dArr2 covariates, {@code [variable][observation]}
     * @return the response residuals (also stored under {@code Output.RESPONSE_RESIDUALS})
     */
    public double[] responseResiduals(double[] dArr, double[][] dArr2) {
        this.coefficients = coefficients(new Object[]{dArr, dArr2});
        this.means = means(this.coefficients, dArr2);
        this.responseResiduals = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            this.responseResiduals[i] = dArr[i] - this.means[i];
        }
        this.output.put(Output.RESPONSE_RESIDUALS, this.responseResiduals);
        return this.responseResiduals;
    }

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} scaling variances ({@code double[]}), {@code [1]} response
     *                  ({@code double[]}), {@code [2]} covariate ({@code double[][]})
     * @return the Pearson residuals
     */
    protected double[] pearsonResiduals(Hashtable hashtable, Object[] objArr) {
        return pearsonResiduals((double[]) objArr[0], (double[]) objArr[1], (double[][]) objArr[2]);
    }

    /**
     * Fits the model and divides every response residual by the square root of the
     * corresponding entry of {@code dArr}; a zero entry yields a residual of zero.
     * The result is not written to {@code output}.
     *
     * @param dArr  per-observation variances used for scaling
     * @param dArr2 response vector
     * @param dArr3 covariates, {@code [variable][observation]}
     * @return the Pearson residuals
     */
    public double[] pearsonResiduals(double[] dArr, double[] dArr2, double[][] dArr3) {
        this.responseResiduals = responseResiduals(dArr2, dArr3);
        this.pearsonResiduals = new double[dArr2.length];
        for (int i = 0; i < dArr2.length; i++) {
            this.pearsonResiduals[i] = (dArr[i] == 0.0d)
                    ? 0.0d
                    : this.responseResiduals[i] / Math.sqrt(dArr[i]);
        }
        return this.pearsonResiduals;
    }

    /**
     * Supplies the weight matrix for the given coefficients.
     *
     * @param dArr  coefficients
     * @param dArr2 covariates, {@code [variable][observation]}
     * @return a square matrix; this class inverts only its diagonal but multiplies with
     *         the whole matrix
     */
    protected abstract double[][] weights(double[] dArr, double[][] dArr2);

    /**
     * Supplies the fitted means for the given coefficients.
     *
     * @param dArr  coefficients
     * @param dArr2 covariates, {@code [variable][observation]}
     * @return one mean per observation
     */
    protected abstract double[] means(double[] dArr, double[][] dArr2);

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} covariate ({@code double[][]}), {@code [1]} coefficients
     *                  ({@code double[]}), {@code [2]} link function
     * @return the means
     */
    protected double[] means(Hashtable hashtable, Object[] objArr) {
        return means((LinkFunction) objArr[2], (double[]) objArr[1], (double[][]) objArr[0]);
    }

    /**
     * Maps the linear predictors to means by applying the inverse of the link function.
     *
     * @param linkFunction link function of the model
     * @param dArr         coefficients
     * @param dArr2        covariates, {@code [variable][observation]}
     * @return the means (also stored under {@code Output.MEANS})
     * @throws NullPointerException     if {@code linkFunction} is {@code null}
     * @throws IllegalArgumentException if the link function is not one of the supported ones
     */
    public double[] means(LinkFunction linkFunction, double[] dArr, double[][] dArr2) {
        this.linearPredictors = linearPredictors(dArr2, dArr);
        this.means = new double[this.linearPredictors.length];
        if (linkFunction == null) {
            throw new NullPointerException("linkFunction");
        }
        if (linkFunction == LinkFunction.IDENTITY) {
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = this.linearPredictors[i];
            }
        } else if (linkFunction == LinkFunction.LOG) {
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = Math.exp(this.linearPredictors[i]);
            }
        } else if (linkFunction == LinkFunction.INVERSE) {
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = 1.0d / this.linearPredictors[i];
            }
        } else if (linkFunction == LinkFunction.INVERSE_SQUARE) {
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = Math.pow(this.linearPredictors[i], -0.5d);
            }
        } else if (linkFunction == LinkFunction.SQUARE_ROOT) {
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = Math.pow(this.linearPredictors[i], 2.0d);
            }
        } else if (linkFunction == LinkFunction.LOGIT) {
            // 1 / (1 + e^-eta) equals e^eta / (1 + e^eta) but does not turn into
            // Infinity / Infinity = NaN for large positive eta.
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = 1.0d / (1.0d + Math.exp(-this.linearPredictors[i]));
            }
        } else if (linkFunction == LinkFunction.PROBIT) {
            // The probit link is the normal quantile function, so the mean is the normal
            // CDF of the linear predictor. A local distribution is used because the
            // normalDistribution field is only set by pValue.
            NormalDistribution standardNormal = new NormalDistribution();
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = standardNormal.cumulative(this.linearPredictors[i]);
            }
        } else if (linkFunction == LinkFunction.COMPLEMENTARY_LOGLOG) {
            for (int i = 0; i < this.means.length; i++) {
                this.means[i] = 1.0d - Math.exp((-1.0d) * Math.exp(this.linearPredictors[i]));
            }
        } else {
            throw new IllegalArgumentException("Wrong input link function.");
        }
        this.output.put(Output.MEANS, this.means);
        return this.means;
    }

    /**
     * Table-driven entry point; the argument table is ignored.
     *
     * @param hashtable unused
     * @param objArr    {@code [0]} input values ({@code double[]}), {@code [1]} exponential family
     * @return the variance function values
     */
    public double[] responseVariance(Hashtable hashtable, Object[] objArr) {
        return responseVariance((double[]) objArr[0], (ExponentialFamily) objArr[1]);
    }

    /**
     * Evaluates the variance function of the given family element-wise:
     * NORMAL 1, POISSON x, BINOMIAL x(1-x), GAMMA x^2, INVERSE_GAUSSIAN x^3.
     *
     * @param dArr              input values (presumably the fitted means; confirm with callers)
     * @param exponentialFamily response distribution family
     * @return the variance function values (also stored under {@code Output.RESPONSE_VARIANCE})
     * @throws NullPointerException     if {@code exponentialFamily} is {@code null}
     * @throws IllegalArgumentException if the family is not one of the supported ones
     */
    public double[] responseVariance(double[] dArr, ExponentialFamily exponentialFamily) {
        this.responseVariance = new double[dArr.length];
        if (exponentialFamily == null) {
            throw new NullPointerException("exponentialFamily");
        }
        if (exponentialFamily == ExponentialFamily.NORMAL) {
            for (int i = 0; i < dArr.length; i++) {
                this.responseVariance[i] = 1.0d;
            }
        } else if (exponentialFamily == ExponentialFamily.POISSON) {
            for (int i = 0; i < dArr.length; i++) {
                this.responseVariance[i] = dArr[i];
            }
        } else if (exponentialFamily == ExponentialFamily.BINOMIAL) {
            for (int i = 0; i < dArr.length; i++) {
                this.responseVariance[i] = dArr[i] * (1.0d - dArr[i]);
            }
        } else if (exponentialFamily == ExponentialFamily.GAMMA) {
            for (int i = 0; i < dArr.length; i++) {
                this.responseVariance[i] = Math.pow(dArr[i], 2.0d);
            }
        } else if (exponentialFamily == ExponentialFamily.INVERSE_GAUSSIAN) {
            for (int i = 0; i < dArr.length; i++) {
                this.responseVariance[i] = Math.pow(dArr[i], 3.0d);
            }
        } else {
            throw new IllegalArgumentException("Wrong input distribution function.");
        }
        this.output.put(Output.RESPONSE_VARIANCE, this.responseVariance);
        return this.responseVariance;
    }
}
