/*
 * Decompiled with CFR 0.152.
 */
package org.carrot2.text.vsm;

import org.apache.mahout.math.matrix.DoubleMatrix2D;
import org.apache.mahout.math.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.core.attribute.Processing;
import org.carrot2.matrix.MatrixUtils;
import org.carrot2.matrix.factorization.IMatrixFactorization;
import org.carrot2.matrix.factorization.IMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.IterationNumberGuesser;
import org.carrot2.matrix.factorization.IterativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.KMeansMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.LocalNonnegativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationEDFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationKLFactory;
import org.carrot2.matrix.factorization.PartialSingularValueDecompositionFactory;
import org.carrot2.text.vsm.ReducedVectorSpaceModelContext;
import org.carrot2.text.vsm.VectorSpaceModelContext;
import org.carrot2.util.attribute.Attribute;
import org.carrot2.util.attribute.AttributeLevel;
import org.carrot2.util.attribute.Bindable;
import org.carrot2.util.attribute.Group;
import org.carrot2.util.attribute.Input;
import org.carrot2.util.attribute.Label;
import org.carrot2.util.attribute.Level;
import org.carrot2.util.attribute.Required;
import org.carrot2.util.attribute.constraint.ImplementingClasses;

/**
 * Reduces the dimensionality of a term-document matrix by factorizing it with a
 * configurable {@link IMatrixFactorizationFactory}. The resulting base (U) and
 * coefficient (V) matrices are stored in a {@link ReducedVectorSpaceModelContext}.
 */
@Bindable(prefix="TermDocumentMatrixReducer")
public class TermDocumentMatrixReducer {
    /**
     * Factory that creates the matrix factorization used for the reduction.
     * Defaults to {@link NonnegativeMatrixFactorizationEDFactory}.
     */
    @Input
    @Processing
    @Attribute
    @Required
    @ImplementingClasses(classes={PartialSingularValueDecompositionFactory.class, NonnegativeMatrixFactorizationEDFactory.class, NonnegativeMatrixFactorizationKLFactory.class, LocalNonnegativeMatrixFactorizationFactory.class, KMeansMatrixFactorizationFactory.class}, strict=false)
    @Label(value="Factorization method")
    @Level(value=AttributeLevel.ADVANCED)
    @Group(value="Matrix model")
    public IMatrixFactorizationFactory factorizationFactory = new NonnegativeMatrixFactorizationEDFactory();

    /**
     * Desired factorization quality, used to estimate the number of iterations.
     * Only consulted when {@link #factorizationFactory} is an
     * {@link IterativeMatrixFactorizationFactory}. Defaults to {@code HIGH}.
     */
    @Input
    @Processing
    @Required
    @Attribute
    @Label(value="Factorization quality")
    @Level(value=AttributeLevel.ADVANCED)
    @Group(value="Matrix model")
    public IterationNumberGuesser.FactorizationQuality factorizationQuality = IterationNumberGuesser.FactorizationQuality.HIGH;

    /**
     * Factorizes the term-document matrix of {@code reducedVectorSpaceModelContext.vsmContext}
     * and stores the base and coefficient matrices, limited to {@code n} columns, back into
     * {@code reducedVectorSpaceModelContext}.
     * <p>
     * Side effects:
     * <ul>
     * <li>The term-document matrix is passed to {@code MatrixUtils.normalizeColumnL2}, and the
     * call's result is discarded. NOTE(review): this presumably normalizes the columns in
     * place; confirm.</li>
     * <li>When {@link #factorizationFactory} is iterative, the factory itself is reconfigured
     * ({@code setK} and the estimated iteration count). A single reducer instance should
     * therefore not be shared by concurrent callers.</li>
     * </ul>
     *
     * @param reducedVectorSpaceModelContext context holding the input term-document matrix
     *        (via {@code vsmContext}) and receiving {@code baseMatrix} / {@code coefficientMatrix}
     * @param n target number of dimensions (columns) of the reduced matrices
     */
    public void reduce(ReducedVectorSpaceModelContext reducedVectorSpaceModelContext, int n) {
        final VectorSpaceModelContext vsmContext = reducedVectorSpaceModelContext.vsmContext;
        final DoubleMatrix2D termDocumentMatrix = vsmContext.termDocumentMatrix;

        if (termDocumentMatrix.columns() == 0 || termDocumentMatrix.rows() == 0) {
            // Nothing to factorize: publish an empty base matrix of the same shape.
            // NOTE(review): coefficientMatrix is left untouched on this path; confirm callers
            // do not rely on it for empty input.
            reducedVectorSpaceModelContext.baseMatrix = new DenseDoubleMatrix2D(
                termDocumentMatrix.rows(), termDocumentMatrix.columns());
            return;
        }

        if (factorizationFactory instanceof IterativeMatrixFactorizationFactory) {
            final IterativeMatrixFactorizationFactory iterativeFactory =
                (IterativeMatrixFactorizationFactory) factorizationFactory;
            // Request exactly n base vectors up front; trim() skips iterative results for this reason.
            iterativeFactory.setK(n);
            IterationNumberGuesser.setEstimatedIterationsNumber(
                iterativeFactory, termDocumentMatrix, factorizationQuality);
        }

        MatrixUtils.normalizeColumnL2(termDocumentMatrix, null);
        final IMatrixFactorization factorization = factorizationFactory.factorize(termDocumentMatrix);

        reducedVectorSpaceModelContext.baseMatrix = trim(factorization.getU(), n);
        reducedVectorSpaceModelContext.coefficientMatrix = trim(factorization.getV(), n);
    }

    /**
     * Limits {@code matrix} to its first {@code n} columns.
     * <p>
     * Results from an {@link IterativeMatrixFactorizationFactory} are returned unchanged,
     * because {@link #reduce} already asked that factory for {@code n} columns via
     * {@code setK}. A matrix with at most {@code n} columns is also returned unchanged.
     * Otherwise the result comes from {@code viewPart}, which is presumably a view rather
     * than a copy.
     *
     * @param matrix factorization result to trim
     * @param n maximum number of columns to keep
     * @return {@code matrix} itself, or its leading {@code n} columns
     */
    private DoubleMatrix2D trim(DoubleMatrix2D matrix, int n) {
        if (!(factorizationFactory instanceof IterativeMatrixFactorizationFactory) && matrix.columns() > n) {
            return matrix.viewPart(0, 0, matrix.rows(), n);
        }
        return matrix;
    }
}

