/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Column group in which every row holds the same tuple of values.
 * <p>
 * {@code _dict} stores that single tuple, one entry per column listed in {@code _colIndexes}. Every row of the group
 * decompresses to exactly those values, which is why {@link #getNumValues()} always reports one distinct tuple.
 */
public class ColGroupConst
extends AColGroupCompressed {
    private static final long serialVersionUID = -7387793538322386611L;

    /** The single tuple of values, one entry per column of this group. */
    protected ADictionary _dict;

    /** No-argument constructor; presumably used for deserialization through {@link #readFields(DataInput)}. */
    protected ColGroupConst() {
    }

    /**
     * Creates a constant group over the given columns.
     *
     * @param colIndices the columns covered by this group
     * @param dict       the dictionary holding the single tuple of values
     */
    private ColGroupConst(int[] colIndices, ADictionary dict) {
        super(colIndices);
        this._dict = dict;
    }

    /**
     * Creates a constant group, or an empty group when there is no dictionary.
     *
     * @param colIndices the columns covered by the group
     * @param dict       the dictionary holding the single tuple; {@code null} yields a {@link ColGroupEmpty}
     * @return a {@link ColGroupEmpty} if {@code dict} is null, otherwise a {@code ColGroupConst}
     */
    protected static AColGroup create(int[] colIndices, ADictionary dict) {
        if (dict == null) {
            return new ColGroupEmpty(colIndices);
        }
        return new ColGroupConst(colIndices, dict);
    }

    /**
     * Creates a constant group over the column indices produced by {@code Util.genColsIndices(values.length)}.
     *
     * @param values one value per column
     * @return the constant group
     */
    public static AColGroup create(double[] values) {
        final int[] colIndices = Util.genColsIndices(values.length);
        return ColGroupConst.create(colIndices, values);
    }

    /**
     * Creates a constant group in which every column of every row holds {@code value}.
     * <p>
     * Unlike {@link #create(int, double)}, this does not special-case {@code value == 0.0}.
     *
     * @param cols  the columns covered by the group
     * @param value the value stored for every column
     * @return the constant group
     */
    public static AColGroup create(int[] cols, double value) {
        final int numCols = cols.length;
        final double[] values = new double[numCols];
        for (int i = 0; i < numCols; i++) {
            values[i] = value;
        }
        return ColGroupConst.create(cols, values);
    }

    /**
     * Creates a constant group from explicit per-column values.
     *
     * @param cols   the columns covered by the group
     * @param values one value per entry of {@code cols}
     * @return the constant group
     * @throws DMLCompressionException if {@code cols} and {@code values} differ in length
     */
    public static AColGroup create(int[] cols, double[] values) {
        if (cols.length != values.length) {
            throw new DMLCompressionException("Invalid size of values compared to columns");
        }
        final Dictionary dict = new Dictionary(values);
        return ColGroupConst.create(cols, (ADictionary) dict);
    }

    /**
     * Creates a constant group over generated column indices from an existing dictionary.
     *
     * @param numCols the number of columns; must equal {@code dict.getValues().length}
     * @param dict    the dictionary holding the single tuple
     * @return the constant group
     * @throws DMLCompressionException if {@code numCols} does not match the dictionary's value count
     */
    public static AColGroup create(int numCols, ADictionary dict) {
        if (numCols != dict.getValues().length) {
            throw new DMLCompressionException("Invalid construction of const column group with different number of columns in arguments");
        }
        final int[] colIndices = Util.genColsIndices(numCols);
        return ColGroupConst.create(colIndices, dict);
    }

    /**
     * Creates a group of {@code numCols} generated columns that all hold {@code value}.
     *
     * @param numCols the number of columns; must be positive
     * @param value   the value for every column
     * @return a {@link ColGroupEmpty} when {@code value == 0.0}, otherwise a constant group
     * @throws DMLCompressionException if {@code numCols <= 0}
     */
    public static AColGroup create(int numCols, double value) {
        if (numCols <= 0) {
            throw new DMLCompressionException("Invalid construction of constant column group with cols: " + numCols);
        }
        final int[] colIndices = Util.genColsIndices(numCols);
        if (value == 0.0) {
            return new ColGroupEmpty(colIndices);
        }
        return ColGroupConst.create(colIndices, value);
    }

    /**
     * Folds this group's per-row aggregate into {@code c[rl..ru)} using {@code builtin}.
     * <p>
     * Every row is identical, so the same scalar {@code preAgg[0]} applies to all rows. It is presumably produced by
     * {@link #preAggBuiltinRows(Builtin)}.
     */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        final double rowAggregate = preAgg[0];
        for (int row = rl; row < ru; row++) {
            c[row] = builtin.execute(c[row], rowAggregate);
        }
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.CONST;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.CONST;
    }

    /**
     * Adds the constant tuple to rows {@code [rl, ru)} of {@code db}, shifted by {@code offR} rows and {@code offC}
     * columns.
     */
    @Override
    public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
        if (db.isContiguous() && this._colIndexes.length == db.getDim(1) && offC == 0) {
            this.decompressToDenseBlockAllColumnsContiguous(db, rl, ru, offR, offC);
        } else {
            this.decompressToDenseBlockGeneric(db, rl, ru, offR, offC);
        }
    }

    /**
     * Fast path for a contiguous dense block with exactly as many columns as this group and no column offset.
     * <p>
     * Each output row is written as {@code nCol} consecutive cells, in the order of {@code _dict.getValues()}. This
     * assumes {@code _colIndexes} is then exactly {@code 0..nCol-1} in order (TODO confirm the invariant upstream).
     */
    private void decompressToDenseBlockAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR, int offC) {
        final double[] c = db.values(0);
        final int nCol = this._colIndexes.length;
        final double[] values = this._dict.getValues();
        for (int r = rl; r < ru; r++) {
            final int offStart = (offR + r) * nCol;
            for (int vOff = 0; vOff < nCol; vOff++) {
                c[offStart + vOff] += values[vOff];
            }
        }
    }

    /** General path: adds each value at its column index, resolving the row through {@code db.values/pos}. */
    private void decompressToDenseBlockGeneric(DenseBlock db, int rl, int ru, int offR, int offC) {
        for (int r = rl, offT = rl + offR; r < ru; r++, offT++) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            for (int j = 0; j < this._colIndexes.length; j++) {
                c[off + this._colIndexes[j]] += this._dict.getValue(j);
            }
        }
    }

    /** Appends the constant tuple to rows {@code [rl + offR, ru + offR)} of {@code ret}, shifted by {@code offC}. */
    @Override
    public void decompressToSparseBlock(SparseBlock ret, int rl, int ru, int offR, int offC) {
        final int nCol = this._colIndexes.length;
        for (int r = rl, offT = rl + offR; r < ru; r++, offT++) {
            for (int j = 0; j < nCol; j++) {
                ret.append(offT, this._colIndexes[j] + offC, this._dict.getValue(j));
            }
        }
    }

    /**
     * Returns the value of the group-local column {@code colIdx}.
     * <p>
     * The row {@code r} is ignored because every row is identical.
     */
    @Override
    public double getIdx(int r, int colIdx) {
        return this._dict.getValue(colIdx);
    }

    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        return ColGroupConst.create(this._colIndexes, this._dict.applyScalarOp(op));
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        return ColGroupConst.create(this._colIndexes, this._dict.applyUnaryOp(op));
    }

    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        return ColGroupConst.create(this._colIndexes, this._dict.binOpLeft(op, v, this._colIndexes));
    }

    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        return ColGroupConst.create(this._colIndexes, this._dict.binOpRight(op, v, this._colIndexes));
    }

    /**
     * Adds this group's values into {@code constV} at their global column positions.
     *
     * @param constV a row vector indexed by global column, updated in place
     */
    public void addToCommon(double[] constV) {
        final double[] values = this._dict.getValues();
        for (int i = 0; i < this._colIndexes.length; i++) {
            constV[this._colIndexes[i]] += values[i];
        }
    }

    /** @return the dictionary's backing values (not a copy) */
    public double[] getValues() {
        return this._dict.getValues();
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return this._dict.aggregate(c, builtin);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        this._dict.aggregateCols(c, builtin, this._colIndexes);
    }

    /** Adds the sum over all {@code nRows} rows to {@code c[0]}; the single tuple is counted {@code nRows} times. */
    @Override
    protected void computeSum(double[] c, int nRows) {
        c[0] += this._dict.sum(new int[]{nRows}, this._colIndexes.length);
    }

    @Override
    public void computeColSums(double[] c, int nRows) {
        this._dict.colSum(c, new int[]{nRows}, this._colIndexes);
    }

    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += this._dict.sumSq(new int[]{nRows}, this._colIndexes.length);
    }

    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        this._dict.colSumSq(c, new int[]{nRows}, this._colIndexes);
    }

    /** Adds the pre-aggregated row sum {@code preAgg[0]} to every entry of {@code c[rl..ru)}. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        final double rowSum = preAgg[0];
        for (int row = rl; row < ru; row++) {
            c[row] += rowSum;
        }
    }

    /** @return always 1, since a constant group has a single distinct tuple */
    @Override
    public int getNumValues() {
        return 1;
    }

    /**
     * Converts the dictionary into its matrix-block form, caches it in {@code _dict}, and returns the block.
     * <p>
     * If the conversion yields no dictionary, {@code null} is returned and {@code _dict} is left untouched. The
     * previous code stored the null into {@code _dict} and then threw a NullPointerException, corrupting the group.
     *
     * @return the backing matrix block, or {@code null} if no matrix-block dictionary could be produced
     */
    private synchronized MatrixBlock forceValuesToMatrixBlock() {
        final ADictionary mbDict = this._dict.getMBDict(this._colIndexes.length);
        if (mbDict == null) {
            return null;
        }
        this._dict = mbDict;
        return ((MatrixBlockDictionary) mbDict).getMatrixBlock();
    }

    /**
     * Multiplies the constant tuple (as a 1 x nCol row) with {@code right}.
     *
     * @param right the right-hand matrix; its row count must equal this group's column count
     * @return a constant group over {@code right.getNumColumns()} generated columns, or {@code null} if either
     *         operand or the product is empty
     * @throws NotImplementedException if {@code right}'s row count differs from this group's column count
     */
    @Override
    public AColGroup rightMultByMatrix(MatrixBlock right) {
        if (right.isEmpty()) {
            return null;
        }
        final int rr = right.getNumRows();
        final int cr = right.getNumColumns();
        if (this._colIndexes.length != rr) {
            throw new NotImplementedException("rightMultByMatrix with mismatching dimensions: group columns "
                + this._colIndexes.length + " vs right rows " + rr);
        }
        final MatrixBlock left = this.forceValuesToMatrixBlock();
        if (left == null) {
            return null;
        }
        final MatrixBlock ret = new MatrixBlock(1, cr, false);
        LibMatrixMult.matrixMult(left, right, ret);
        if (ret.isEmpty()) {
            return null;
        }
        final MatrixBlockDictionary d = new MatrixBlockDictionary(ret, cr);
        return ColGroupConst.create(cr, (ADictionary) d);
    }

    /** Delegates to the static {@code tsmm} helper, weighting the single tuple by {@code nRows}. */
    @Override
    public void tsmm(double[] result, int numColumns, int nRows) {
        ColGroupConst.tsmm(result, numColumns, new int[]{nRows}, this._dict, this._colIndexes);
    }

    /** Not supported for constant groups; always throws {@link DMLCompressionException}. */
    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        throw new DMLCompressionException("This method should never be called");
    }

    /** Not supported for constant groups; always throws {@link DMLCompressionException}. */
    @Override
    public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    /** Not supported for constant groups; always throws {@link DMLCompressionException}. */
    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    /**
     * Slices out the group-local column {@code idx} as a single-column group at index 0.
     *
     * @return a {@link ColGroupEmpty} if the column's value is 0, otherwise a single-value constant group
     */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final int[] colIndexes = new int[]{0};
        final double v = this._dict.getValue(idx);
        if (v == 0.0) {
            return new ColGroupEmpty(colIndexes);
        }
        final Dictionary retD = new Dictionary(new double[]{v});
        return ColGroupConst.create(colIndexes, (ADictionary) retD);
    }

    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ADictionary retD = this._dict.sliceOutColumnRange(idStart, idEnd, this._colIndexes.length);
        return ColGroupConst.create(outputCols, retD);
    }

    /** @return a new group sharing the column indexes but holding a clone of the dictionary */
    @Override
    public AColGroup copy() {
        return ColGroupConst.create(this._colIndexes, this._dict.clone());
    }

    @Override
    public boolean containsValue(double pattern) {
        return this._dict.containsValue(pattern);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        return this._dict.getNumberNonZeros(new int[]{nRows}, this._colIndexes.length);
    }

    @Override
    public AColGroup replace(double pattern, double replace) {
        final ADictionary replaced = this._dict.replace(pattern, replace, this._colIndexes.length);
        return ColGroupConst.create(this._colIndexes, replaced);
    }

    /** Reads the superclass state, then the dictionary. */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        this._dict = DictionaryFactory.read(in);
    }

    /** Writes the superclass state, then the dictionary (mirrors {@link #readFields(DataInput)}). */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        this._dict.write(out);
    }

    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + this._dict.getExactSizeOnDisk();
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        this._dict.product(c, new int[]{nRows}, this._colIndexes.length);
    }

    /** Not implemented for constant groups. */
    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException();
    }

    /** Not implemented for constant groups. */
    @Override
    protected void computeColProduct(double[] c, int nRows) {
        throw new NotImplementedException();
    }

    @Override
    protected double[] preAggSumRows() {
        return this._dict.sumAllRowsToDouble(this._colIndexes.length);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return this._dict.sumAllRowsToDoubleSq(this._colIndexes.length);
    }

    /** Not implemented for constant groups. */
    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException();
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return this._dict.aggregateRows(builtin, this._colIndexes.length);
    }

    @Override
    public long estimateInMemorySize() {
        // The extra 8 bytes presumably account for the _dict reference field; verify against the size model.
        return super.estimateInMemorySize() + this._dict.getInMemorySize() + 8L;
    }

    /** Feeds the first dictionary value, weighted by {@code nRows}, into the central-moment function of {@code op}. */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = new CM_COV_Object();
        op.fn.execute(ret, this._dict.getValue(0), nRows);
        return ret;
    }

    /**
     * Expands the constant values into {@code max} columns via the dictionary's {@code rexpandCols}.
     *
     * @return an empty group of {@code max} columns if the dictionary returns null, otherwise a constant group
     */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final ADictionary d = this._dict.rexpandCols(max, ignore, cast, this._colIndexes.length);
        if (d == null) {
            return ColGroupEmpty.create(max);
        }
        return ColGroupConst.create(max, d);
    }

    /** Cost is estimated as a single distinct tuple spread across all {@code nRows}. */
    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nCols = this.getNumCols();
        return e.getCost(nRows, 1, nCols, 1, 1.0);
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s", "Values: " + this._dict.getClass().getSimpleName()));
        sb.append(this._dict.getString(this._colIndexes.length));
        return sb.toString();
    }
}

