/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.codegen;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofCUDARowwise;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFactory;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Base class for generated row-wise fused operators ("RA" spoof operators).
 * <p>
 * Concrete subclasses implement the two {@code genexec} row kernels (a dense-row and a
 * sparse-row variant). This class validates the inputs, allocates a dense output whose shape
 * is determined by the {@link RowType}, sets up and cleans up the thread-local intermediate
 * vectors via {@link LibSpoofPrimitives}, and runs the kernels over row ranges either on the
 * calling thread or on a thread pool.
 */
public abstract class SpoofRowwise extends SpoofOperator {
    private static final long serialVersionUID = 6242910797139642998L;

    /** Total input size below which the multi-threaded {@code execute} falls back to one thread. */
    private static final long PAR_MIN_INPUT_SIZE = 1L << 20;

    /**
     * Cell count (rows * cols of the main input) below which {@code true} is passed as the last
     * argument of {@link UtilFunctions#getBalancedBlockSizesDefault}.
     */
    private static final long PAR_SMALL_INPUT_CELLS = 16L * PAR_MIN_INPUT_SIZE;

    /** Output/aggregation type of this operator; determines the output shape. */
    protected final RowType _type;
    /** Constant second dimension, used by the *_CONST types and (if non-negative) by B1 types. */
    protected final long _constDim2;
    /** Flag forwarded to {@code prepInputMatrices} for the side inputs (presumably "transposed" — confirm). */
    protected final boolean _tB1;
    /** Number of thread-local intermediate vectors to set up per executing thread; 0 skips setup. */
    protected final int _reqVectMem;

    /**
     * @param type       row type defining the output shape and aggregation
     * @param constDim2  constant second dimension (relevant for *_CONST and B1 row types)
     * @param tB1        flag forwarded to {@code prepInputMatrices} for the side inputs
     * @param reqVectMem number of thread-local intermediate vectors required by the kernel
     */
    public SpoofRowwise(RowType type, long constDim2, boolean tB1, int reqVectMem) {
        this._type = type;
        this._constDim2 = constDim2;
        this._tB1 = tB1;
        this._reqVectMem = reqVectMem;
    }

    /** @return the row type of this operator */
    public RowType getRowType() {
        return this._type;
    }

    /** @return the constant second dimension passed at construction */
    public long getConstDim2() {
        return this._constDim2;
    }

    /** @return the number of thread-local intermediate vectors this operator requests */
    public int getNumIntermediates() {
        return this._reqVectMem;
    }

    /**
     * Returns {@code "RA"} followed by the second dot-separated component of the runtime class
     * name (for a generated class {@code pkg.Name} this is {@code "RA" + "Name"}).
     */
    @Override
    public String getSpoofType() {
        return "RA" + this.getClass().getName().split("\\.")[1];
    }

    /** Creates the CUDA counterpart of this operator with the same row-wise configuration. */
    @Override
    public SpoofCUDAOperator createCUDAInstrcution(Integer opID, SpoofCUDAOperator.PrecisionProxy ep) {
        return new SpoofCUDARowwise(this._type, this._constDim2, this._tB1, this._reqVectMem, opID, ep);
    }

    /**
     * Executes the operator into a 1x1 block and returns cell (0,0) as a scalar.
     *
     * @param k degree of parallelism; values &gt; 1 use the multi-threaded path
     */
    @Override
    public ScalarObject execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, int k) {
        MatrixBlock target = new MatrixBlock(1, 1, false);
        MatrixBlock out = (k > 1)
            ? this.execute(inputs, scalarObjects, target, k)
            : this.execute(inputs, scalarObjects, target);
        return new DoubleObject(out.quickGetValue(0, 0));
    }

    /**
     * Single-threaded execution with thread-local memory setup, no incremental aggregation and
     * a global row offset of 0.
     */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) {
        return this.execute(inputs, scalarObjects, out, true, false, 0L);
    }

    /**
     * Single-threaded execution over all rows of the main input ({@code inputs.get(0)}).
     *
     * @param inputs        main input followed by the matrix side inputs
     * @param scalarObjects scalar inputs
     * @param out           output block; (re)allocated dense unless {@code aggIncr} and already allocated
     * @param allocTmp      whether to set up/clean up thread-local intermediate vectors here
     * @param aggIncr       if true, keep an already allocated output and skip the final
     *                      nnz recomputation and sparsity check
     * @param rix           global row offset added to the row index passed to the kernels
     * @return the output block (a new block if the result had to be transposed)
     * @throws RuntimeException if {@code inputs} is null/empty or {@code out} is null
     */
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, boolean allocTmp, boolean aggIncr, long rix) {
        if (inputs == null || inputs.size() < 1 || out == null) {
            throw new RuntimeException("Invalid input arguments.");
        }
        MatrixBlock a = inputs.get(0);
        int m = a.getNumRows();
        int n = a.getNumColumns();
        // Fix: n2 was left unassigned (compile error) when the row type has a constant dim2.
        int n2 = this.deriveDim2(inputs);

        if (!aggIncr || !out.isAllocated()) {
            this.allocateOutputMatrix(m, n, n2, out);
        }
        DenseBlock c = out.getDenseBlock();
        boolean flipOut = this._type.isRowTypeB1ColumnAgg()
            && LibSpoofPrimitives.isFlipOuter(out.getNumRows(), out.getNumColumns());
        SpoofOperator.SideInput[] b = this.prepInputMatrices(inputs, 1, inputs.size() - 1, false, this._tB1);
        double[] scalars = SpoofRowwise.prepInputScalars(scalarObjects);

        boolean manageTmp = allocTmp && this._reqVectMem > 0;
        if (manageTmp) {
            LibSpoofPrimitives.setupThreadLocalMemory(this._reqVectMem, n, n2);
        }
        try {
            this.executeRows(a, b, scalars, c, n, 0, m, rix);
        }
        finally {
            // release thread-local vectors even if the generated kernel throws
            if (manageTmp) {
                LibSpoofPrimitives.cleanupThreadLocalMemory();
            }
        }

        if (flipOut) {
            SpoofRowwise.fixTransposeDimensions(out);
            out = LibMatrixReorg.transpose(out, new MatrixBlock(out.getNumColumns(), out.getNumRows(), false));
        }
        if (!aggIncr) {
            out.recomputeNonZeros();
            out.examSparsity();
        }
        return out;
    }

    /**
     * Multi-threaded execution. Falls back to the single-threaded path for {@code k <= 1},
     * invalid arguments, column aggregations that do not satisfy the multi-threading
     * constraints, or total input size below {@link #PAR_MIN_INPUT_SIZE}.
     * <p>
     * Column/full aggregations give each task a private partial output that is summed into
     * {@code out}; all other types write disjoint row ranges of {@code out} directly.
     *
     * @throws RuntimeException    on invalid arguments (via the single-threaded path)
     * @throws DMLRuntimeException wrapping any failure of the parallel execution
     */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) {
        // Validate before touching inputs.get(0); invalid arguments are reported by the
        // single-threaded path ("Invalid input arguments.") instead of surfacing as an NPE.
        if (k <= 1 || inputs == null || inputs.size() < 1 || out == null
            || (this._type.isColumnAgg() && !LibMatrixMult.satisfiesMultiThreadingConstraints(inputs.get(0), k))
            || SpoofRowwise.getTotalInputSize(inputs) < PAR_MIN_INPUT_SIZE) {
            return this.execute(inputs, scalarObjects, out);
        }
        MatrixBlock a = inputs.get(0);
        int m = a.getNumRows();
        int n = a.getNumColumns();
        int n2 = this.deriveDim2(inputs);
        this.allocateOutputMatrix(m, n, n2, out);
        boolean flipOut = this._type.isRowTypeB1ColumnAgg()
            && LibSpoofPrimitives.isFlipOuter(out.getNumRows(), out.getNumColumns());
        SpoofOperator.SideInput[] b = this.prepInputMatrices(inputs, 1, inputs.size() - 1, false, this._tB1);
        double[] scalars = SpoofRowwise.prepInputScalars(scalarObjects);
        ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(
            m, k, (long) m * (long) n < PAR_SMALL_INPUT_CELLS);

        ExecutorService pool = CommonThreadPool.get(k);
        try {
            if (this._type.isColumnAgg() || this._type == RowType.FULL_AGG) {
                int outLen = out.getNumRows() * out.getNumColumns();
                List<ParColAggTask> tasks = new ArrayList<>();
                int lb = 0;
                for (int blklen : blklens) {
                    tasks.add(new ParColAggTask(a, b, scalars, n, n2, outLen, lb, lb + blklen));
                    lb += blklen;
                }
                int aggLen = this._type.isColumnAgg() ? outLen : 1;
                double[] cvals = out.getDenseBlockValues();
                for (Future<DenseBlock> task : pool.invokeAll(tasks)) {
                    LibMatrixMult.vectAdd(task.get().valuesAt(0), cvals, 0, 0, aggLen);
                }
                out.recomputeNonZeros();
            }
            else {
                List<ParExecTask> tasks = new ArrayList<>();
                int lb = 0;
                for (int blklen : blklens) {
                    tasks.add(new ParExecTask(a, b, out, scalars, n, n2, lb, lb + blklen));
                    lb += blklen;
                }
                long nnz = 0L;
                for (Future<Long> task : pool.invokeAll(tasks)) {
                    nnz += task.get();
                }
                out.setNonZeros(nnz);
            }
            if (flipOut) {
                SpoofRowwise.fixTransposeDimensions(out);
                out = LibMatrixReorg.transpose(out, new MatrixBlock(out.getNumColumns(), out.getNumRows(), false));
            }
            out.examSparsity();
        }
        catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                Thread.currentThread().interrupt(); // preserve interrupt status for callers
            }
            throw new DMLRuntimeException(ex);
        }
        finally {
            // always release the pool, also when a task failed
            pool.shutdown();
        }
        return out;
    }

    /** @return true if any side input ({@code inputs[1..]}) has more than one column */
    public static boolean hasMatrixSideInput(ArrayList<MatrixBlock> inputs) {
        return IntStream.range(1, inputs.size())
            .anyMatch(i -> inputs.get(i).getNumColumns() > 1);
    }

    /** @return the minimum column count among side inputs with more than one column, or 1 if none */
    protected static int getMinColsMatrixSideInputs(ArrayList<MatrixBlock> inputs) {
        return IntStream.range(1, inputs.size())
            .map(i -> inputs.get(i).getNumColumns())
            .filter(ncol -> ncol > 1)
            .min()
            .orElse(1);
    }

    /** {@link MatrixObject} variant of {@link #hasMatrixSideInput}. */
    public static boolean hasMatrixObjectSideInput(ArrayList<MatrixObject> inputs) {
        return IntStream.range(1, inputs.size())
            .anyMatch(i -> inputs.get(i).getNumColumns() > 1L);
    }

    /** {@link MatrixObject} variant of {@link #getMinColsMatrixSideInputs} (column counts cast to int). */
    protected static int getMinColsMatrixObjectSideInputs(ArrayList<MatrixObject> inputs) {
        return IntStream.range(1, inputs.size())
            .map(i -> (int) inputs.get(i).getNumColumns())
            .filter(ncol -> ncol > 1)
            .min()
            .orElse(1);
    }

    /**
     * Derives the second dimension {@code n2}: the constant dim2 if the row type uses it,
     * otherwise the minimum multi-column side-input width for B1 types or when such a side
     * input exists, otherwise -1.
     */
    private int deriveDim2(ArrayList<MatrixBlock> inputs) {
        if (this._type.isConstDim2(this._constDim2)) {
            return (int) this._constDim2;
        }
        return (this._type.isRowTypeB1() || SpoofRowwise.hasMatrixSideInput(inputs))
            ? SpoofRowwise.getMinColsMatrixSideInputs(inputs)
            : -1;
    }

    /** Resets {@code out} to the row-type-specific output shape and allocates a dense block. */
    private void allocateOutputMatrix(int m, int n, int n2, MatrixBlock out) {
        OutputDimensions dims = new OutputDimensions(m, n, n2);
        out.reset(dims.rows, dims.cols, false);
        out.allocateDenseBlock();
    }

    /**
     * Swaps the row/column counts of {@code out} in place and sets nnz to rows*cols, ahead of
     * the transpose in the flip-outer case.
     */
    private static void fixTransposeDimensions(MatrixBlock out) {
        int rlen = out.getNumRows();
        out.setNumRows(out.getNumColumns());
        out.setNumColumns(rlen);
        // widen before multiplying to avoid int overflow for large outputs
        out.setNonZeros((long) out.getNumRows() * out.getNumColumns());
    }

    /** Dispatches rows [rl, ru) of {@code a} to the dense or sparse row loop. */
    private void executeRows(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, DenseBlock c, int n, int rl, int ru, long rix) {
        if (a.isInSparseFormat()) {
            this.executeSparse(a.getSparseBlock(), b, scalars, c, n, rl, ru, rix);
        }
        else {
            this.executeDense(a.getDenseBlock(), b, scalars, c, n, rl, ru, rix);
        }
    }

    /** Runs the dense kernel on rows [rl, ru); a null dense block is treated as all-empty rows. */
    private void executeDense(DenseBlock a, SpoofOperator.SideInput[] b, double[] scalars, DenseBlock c, int n, int rl, int ru, long rix) {
        if (a == null) {
            this.executeSparse(null, b, scalars, c, n, rl, ru, rix);
            return;
        }
        SpoofOperator.SideInput[] lb = SpoofRowwise.createSparseSideInputs(b, true);
        for (int i = rl; i < ru; ++i) {
            this.genexec(a.values(i), a.pos(i), lb, scalars, c.values(i), c.pos(i), n, rix + (long) i, i);
        }
    }

    /** Runs the sparse kernel on rows [rl, ru); empty (or null) rows get a zero-length row. */
    private void executeSparse(SparseBlock a, SpoofOperator.SideInput[] b, double[] scalars, DenseBlock c, int n, int rl, int ru, long rix) {
        SpoofOperator.SideInput[] lb = SpoofRowwise.createSparseSideInputs(b, true);
        SparseRow empty = new SparseRowVector(1);
        double[] emptyVals = empty.values();
        int[] emptyIx = empty.indexes();
        for (int i = rl; i < ru; ++i) {
            if (a != null && !a.isEmpty(i)) {
                this.genexec(a.values(i), a.indexes(i), a.pos(i), lb, scalars, c.values(i), c.pos(i), a.size(i), n, rix + (long) i, i);
            }
            else {
                this.genexec(emptyVals, emptyIx, 0, lb, scalars, c.values(i), c.pos(i), 0, n, rix + (long) i, i);
            }
        }
    }

    /** Dense kernel convenience overload using {@code rix} as both global and local row index. */
    protected final void genexec(double[] a, int ai, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int ci, int len, int rix) {
        this.genexec(a, ai, b, scalars, c, ci, len, rix, rix);
    }

    /** Sparse kernel convenience overload using {@code rix} as both global and local row index. */
    protected final void genexec(double[] avals, int[] aix, int ai, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int ci, int alen, int n, int rix) {
        this.genexec(avals, aix, ai, b, scalars, c, ci, alen, n, rix, rix);
    }

    /**
     * Generated dense-row kernel.
     *
     * @param a       values array of the input row
     * @param ai      offset of the row in {@code a}
     * @param b       side inputs
     * @param scalars scalar inputs
     * @param c       output values array
     * @param ci      output offset for this row
     * @param len     number of input columns
     * @param grix    global row index ({@code rix} offset + local row)
     * @param rix     local row index
     */
    protected abstract void genexec(double[] a, int ai, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix);

    /**
     * Generated sparse-row kernel.
     *
     * @param avals   non-zero values of the input row
     * @param aix     column indexes of the non-zeros
     * @param ai      offset of the row in {@code avals}/{@code aix}
     * @param b       side inputs
     * @param scalars scalar inputs
     * @param c       output values array
     * @param ci      output offset for this row
     * @param alen    number of non-zeros in the row
     * @param n       number of input columns
     * @param grix    global row index ({@code rix} offset + local row)
     * @param rix     local row index
     */
    protected abstract void genexec(double[] avals, int[] aix, int ai, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int ci, int alen, int n, long grix, int rix);

    /** Task writing rows [rl, ru) directly into the shared output; returns that range's nnz. */
    private class ParExecTask implements Callable<Long> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final MatrixBlock _c;
        private final double[] _scalars;
        private final int _clen;
        private final int _clen2;
        private final int _rl;
        private final int _ru;

        protected ParExecTask(MatrixBlock a, SpoofOperator.SideInput[] b, MatrixBlock c, double[] scalars, int clen, int clen2, int rl, int ru) {
            this._a = a;
            this._b = b;
            this._c = c;
            this._scalars = scalars;
            this._clen = clen;
            this._clen2 = clen2;
            this._rl = rl;
            this._ru = ru;
        }

        @Override
        public Long call() {
            boolean manageTmp = SpoofRowwise.this._reqVectMem > 0;
            if (manageTmp) {
                LibSpoofPrimitives.setupThreadLocalMemory(SpoofRowwise.this._reqVectMem, this._clen, this._clen2);
            }
            try {
                SpoofRowwise.this.executeRows(this._a, this._b, this._scalars, this._c.getDenseBlock(), this._clen, this._rl, this._ru, 0L);
            }
            finally {
                if (manageTmp) {
                    LibSpoofPrimitives.cleanupThreadLocalMemory();
                }
            }
            return this._c.recomputeNonZeros(this._rl, this._ru - 1, 0, this._c.getNumColumns() - 1);
        }
    }

    /** Task aggregating rows [rl, ru) into a private 1 x outLen dense partial result. */
    private class ParColAggTask implements Callable<DenseBlock> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final double[] _scalars;
        private final int _clen;
        private final int _clen2;
        private final int _outLen;
        private final int _rl;
        private final int _ru;

        protected ParColAggTask(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, int clen, int clen2, int outLen, int rl, int ru) {
            this._a = a;
            this._b = b;
            this._scalars = scalars;
            this._clen = clen;
            this._clen2 = clen2;
            this._outLen = outLen;
            this._rl = rl;
            this._ru = ru;
        }

        @Override
        public DenseBlock call() {
            boolean manageTmp = SpoofRowwise.this._reqVectMem > 0;
            if (manageTmp) {
                LibSpoofPrimitives.setupThreadLocalMemory(SpoofRowwise.this._reqVectMem, this._clen, this._clen2);
            }
            try {
                DenseBlock c = DenseBlockFactory.createDenseBlock(1, this._outLen);
                SpoofRowwise.this.executeRows(this._a, this._b, this._scalars, c, this._clen, this._rl, this._ru, 0L);
                return c;
            }
            finally {
                if (manageTmp) {
                    LibSpoofPrimitives.cleanupThreadLocalMemory();
                }
            }
        }
    }

    /**
     * Output shape for a given main-input size (m x n) and second dimension n2, per row type.
     * Unknown types yield 0 x 0.
     */
    protected class OutputDimensions {
        public final int rows;
        public final int cols;

        OutputDimensions(int m, int n, int n2) {
            switch (SpoofRowwise.this._type) {
                case NO_AGG:
                    this.rows = m;
                    this.cols = n;
                    break;
                case NO_AGG_B1:
                    this.rows = m;
                    this.cols = n2;
                    break;
                case NO_AGG_CONST:
                    this.rows = m;
                    this.cols = (int) SpoofRowwise.this._constDim2;
                    break;
                case FULL_AGG:
                    this.rows = 1;
                    this.cols = 1;
                    break;
                case ROW_AGG:
                    this.rows = m;
                    this.cols = 1;
                    break;
                case COL_AGG:
                    this.rows = 1;
                    this.cols = n;
                    break;
                case COL_AGG_T:
                    this.rows = n;
                    this.cols = 1;
                    break;
                case COL_AGG_B1:
                    this.rows = n2;
                    this.cols = n;
                    break;
                case COL_AGG_B1_T:
                    this.rows = n;
                    this.cols = n2;
                    break;
                case COL_AGG_B1R:
                    this.rows = 1;
                    this.cols = n2;
                    break;
                case COL_AGG_CONST:
                    this.rows = 1;
                    this.cols = (int) SpoofRowwise.this._constDim2;
                    break;
                default:
                    this.rows = 0;
                    this.cols = 0;
            }
        }
    }

    /** Row-wise operator types with a stable integer code (see {@link #valueOf(int)}). */
    public static enum RowType {
        NO_AGG(0),
        NO_AGG_B1(1),
        NO_AGG_CONST(2),
        FULL_AGG(3),
        ROW_AGG(4),
        COL_AGG(5),
        COL_AGG_T(6),
        COL_AGG_B1(7),
        COL_AGG_B1_T(8),
        COL_AGG_B1R(9),
        COL_AGG_CONST(10);

        private final int value;
        private static final HashMap<Integer, RowType> map = new HashMap<>();

        static {
            for (RowType rowType : RowType.values()) {
                map.put(rowType.value, rowType);
            }
        }

        private RowType(int value) {
            this.value = value;
        }

        /** @return the type with the given integer code, or null if none matches */
        public static RowType valueOf(int rowType) {
            return map.get(rowType);
        }

        /** @return the integer code of this type */
        public int getValue() {
            return this.value;
        }

        /** @return true for all COL_AGG* types */
        public boolean isColumnAgg() {
            return this == COL_AGG || this == COL_AGG_T || this == COL_AGG_B1
                || this == COL_AGG_B1_T || this == COL_AGG_B1R || this == COL_AGG_CONST;
        }

        /** @return true for the types whose second dimension comes from a B1 side input */
        public boolean isRowTypeB1() {
            return this == NO_AGG_B1 || this == COL_AGG_B1 || this == COL_AGG_B1_T || this == COL_AGG_B1R;
        }

        /** @return true for the two-dimensional B1 column aggregations (candidates for flip-outer) */
        public boolean isRowTypeB1ColumnAgg() {
            return this == COL_AGG_B1 || this == COL_AGG_B1_T;
        }

        /** @return true if the second dimension is the constant {@code dim2} for this type */
        public boolean isConstDim2(long dim2) {
            return this == NO_AGG_CONST || this == COL_AGG_CONST
                || (dim2 >= 0L && this.isRowTypeB1());
        }
    }
}

