/*
 * Decompiled with CFR 0.152.
 */
package net.loomchild.maligna.filter.aligner.align.hmm.fb;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import net.loomchild.maligna.calculator.Calculator;
import net.loomchild.maligna.coretypes.Alignment;
import net.loomchild.maligna.coretypes.Category;
import net.loomchild.maligna.filter.aligner.align.AlignAlgorithm;
import net.loomchild.maligna.filter.aligner.align.hmm.Util;
import net.loomchild.maligna.matrix.Matrix;
import net.loomchild.maligna.matrix.MatrixFactory;
import net.loomchild.maligna.matrix.MatrixIterator;
import net.loomchild.maligna.progress.ProgressManager;
import net.loomchild.maligna.progress.ProgressMeter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Aligner based on the forward-backward algorithm.
 *
 * <p>Two score matrices of size
 * {@code (sourceSegmentList.size() + 1) x (targetSegmentList.size() + 1)} are filled: a forward
 * matrix (from the origin towards the end) and a backward matrix (from the end towards the
 * origin). Moving between cells corresponds to one {@link Category}, i.e. a number of source and
 * target segments consumed together. The cost of such a step is the category's score from
 * {@code categoryMap} plus the {@link Calculator} score of the covered segments. All candidate
 * costs reaching a cell are merged with {@code net.loomchild.maligna.util.Util.scoreSum}.
 * The trace output prints {@code exp(-score)}, which suggests scores are negative
 * log-probabilities (NOTE(review): confirm against {@code Util.scoreSum}).
 *
 * <p>The alignment is then decoded greedily from {@code (0, 0)}. At every cell, the category
 * whose destination cell has the lowest {@code forward + backward - total} score is taken.
 */
public class ForwardBackwardAlgorithm implements AlignAlgorithm {

    private static final Log log = LogFactory.getLog(ForwardBackwardAlgorithm.class);

    private final Map<Category, Float> categoryMap;
    private final Calculator calculator;
    private final MatrixFactory matrixFactory;

    /**
     * Creates the algorithm.
     *
     * @param calculator scores a pair of source/target segment lists covered by one step
     * @param categoryMap allowed step categories, each mapped to its own additive score
     * @param matrixFactory creates the forward and backward score matrices
     */
    public ForwardBackwardAlgorithm(Calculator calculator, Map<Category, Float> categoryMap, MatrixFactory matrixFactory) {
        this.matrixFactory = matrixFactory;
        this.calculator = calculator;
        this.categoryMap = categoryMap;
    }

    /**
     * Aligns the source segments with the target segments.
     *
     * <p>Progress (one task per forward cell plus one per backward cell) is reported through a
     * {@link ProgressMeter}. The meter is unregistered when this method returns or throws.
     *
     * @param sourceSegmentList segments of the source text
     * @param targetSegmentList segments of the target text
     * @return the decoded alignments, in order; empty when both lists are empty
     * @throws IllegalStateException if the final cell has no forward score, or decoding reaches a
     *     cell from which no category leads to a scored cell
     */
    @Override
    public List<Alignment> align(List<String> sourceSegmentList, List<String> targetSegmentList) {
        int sourceDim = sourceSegmentList.size() + 1;
        int targetDim = targetSegmentList.size() + 1;
        Matrix<Float> forwardMatrix = this.matrixFactory.createMatrix(sourceDim, targetDim);
        ProgressMeter progress = new ProgressMeter("Forward-Backward Align", forwardMatrix.getSize() * 2);
        ProgressManager.getInstance().registerProgressMeter(progress);
        // Unregister the meter even when scoring or decoding fails, so it does not linger
        // in the shared ProgressManager.
        try {
            this.fillForwardMatrix(forwardMatrix, sourceSegmentList, targetSegmentList, progress);
            Matrix<Float> backwardMatrix = this.matrixFactory.createMatrix(sourceDim, targetDim);
            this.fillBackwardMatrix(backwardMatrix, sourceSegmentList, targetSegmentList, progress);
            return this.decode(forwardMatrix, backwardMatrix, sourceSegmentList, targetSegmentList);
        } finally {
            ProgressManager.getInstance().unregisterProgressMeter(progress);
        }
    }

    /** Fills every cell of the forward matrix in iterator order, reporting one task per cell. */
    private void fillForwardMatrix(Matrix<Float> forwardMatrix, List<String> sourceSegmentList,
            List<String> targetSegmentList, ProgressMeter progress) {
        MatrixIterator iterator = forwardMatrix.getIterator();
        while (iterator.hasNext()) {
            iterator.next();
            int x = iterator.getX();
            int y = iterator.getY();
            float value = this.createForwardData(x, y, sourceSegmentList, targetSegmentList, forwardMatrix);
            forwardMatrix.set(x, y, Float.valueOf(value));
            progress.completeTask();
        }
    }

    /** Fills every cell of the backward matrix in reverse iterator order, reporting one task per cell. */
    private void fillBackwardMatrix(Matrix<Float> backwardMatrix, List<String> sourceSegmentList,
            List<String> targetSegmentList, ProgressMeter progress) {
        MatrixIterator iterator = backwardMatrix.getIterator();
        iterator.afterLast();
        while (iterator.hasPrevious()) {
            iterator.previous();
            int x = iterator.getX();
            int y = iterator.getY();
            float value = this.createBackwardData(x, y, sourceSegmentList, targetSegmentList, backwardMatrix);
            backwardMatrix.set(x, y, Float.valueOf(value));
            progress.completeTask();
        }
    }

    /**
     * Greedily walks from {@code (0, 0)} to the final cell, choosing at each step the category
     * whose destination cell minimises {@code forward + backward - total}.
     *
     * <p>Ties are resolved in favour of the first such category in {@code categoryMap} key
     * iteration order.
     */
    private List<Alignment> decode(Matrix<Float> forwardMatrix, Matrix<Float> backwardMatrix,
            List<String> sourceSegmentList, List<String> targetSegmentList) {
        int sourceSize = sourceSegmentList.size();
        int targetSize = targetSegmentList.size();
        Float totalCell = forwardMatrix.get(sourceSize, targetSize);
        if (totalCell == null) {
            throw new IllegalStateException("Forward matrix has no score for the final cell ("
                    + sourceSize + ", " + targetSize + ")");
        }
        float totalScore = totalCell.floatValue();
        List<Alignment> alignmentList = new ArrayList<Alignment>();
        int x = 0;
        int y = 0;
        while (x < sourceSize || y < targetSize) {
            float bestScore = Float.POSITIVE_INFINITY;
            Category bestCategory = null;
            for (Category category : this.categoryMap.keySet()) {
                int sourceStep = category.getSourceSegmentCount();
                int targetStep = category.getTargetSegmentCount();
                // A category that consumes nothing would never advance the walk; if it won
                // once it would win forever and the loop would never terminate.
                if (sourceStep == 0 && targetStep == 0) {
                    continue;
                }
                int newX = x + sourceStep;
                int newY = y + targetStep;
                if (newX > sourceSize || newY > targetSize) {
                    continue;
                }
                Float forwardScore = forwardMatrix.get(newX, newY);
                Float backwardScore = backwardMatrix.get(newX, newY);
                // Cells the matrix does not store (null) cannot be stepped to.
                if (forwardScore == null || backwardScore == null) {
                    continue;
                }
                float score = forwardScore.floatValue() + backwardScore.floatValue() - totalScore;
                if (score < bestScore) {
                    bestScore = score;
                    bestCategory = category;
                }
            }
            // Previously this surfaced as a bare NullPointerException below.
            if (bestCategory == null) {
                throw new IllegalStateException("No category leads to a scored cell from ("
                        + x + ", " + y + ")");
            }
            int nextX = x + bestCategory.getSourceSegmentCount();
            int nextY = y + bestCategory.getTargetSegmentCount();
            List<String> sourceList = this.createSubList(sourceSegmentList, x, nextX);
            List<String> targetList = this.createSubList(targetSegmentList, y, nextY);
            alignmentList.add(new Alignment(sourceList, targetList, bestScore));
            x = nextX;
            y = nextY;
            // Guarded so the message (and Math.exp) is only built when trace logging is on.
            if (log.isTraceEnabled()) {
                log.trace("(" + x + ", " + y + ") - s: " + bestScore + " (" + Math.exp(-bestScore) + ")");
            }
        }
        return alignmentList;
    }

    /**
     * Computes the forward score of cell {@code (x, y)}.
     *
     * <p>For every category, the step ending at {@code (x, y)} starts at
     * {@code (x - sourceCount, y - targetCount)}. Its cost is the category score plus the
     * calculator score of the covered segments plus the start cell's forward score. Candidates
     * whose start cell does not exist are skipped, and the rest are merged with
     * {@code Util.scoreSum}.
     */
    private float createForwardData(int x, int y, List<String> sourceSegmentList, List<String> targetSegmentList, Matrix<Float> matrix) {
        List<Float> scoreList = new ArrayList<Float>(this.categoryMap.size());
        for (Map.Entry<Category, Float> entry : this.categoryMap.entrySet()) {
            Category category = entry.getKey();
            float categoryScore = entry.getValue().floatValue();
            int startX = x - category.getSourceSegmentCount();
            int startY = y - category.getTargetSegmentCount();
            // Presumably rejects out-of-range or unstored cells (see hmm.Util.elementExists);
            // the subList calls below rely on startX/startY being valid when it passes.
            if (!Util.elementExists(matrix, startX, startY)) {
                continue;
            }
            List<String> sourceList = sourceSegmentList.subList(startX, x);
            List<String> targetList = targetSegmentList.subList(startY, y);
            float stepScore = categoryScore + this.calculator.calculateScore(sourceList, targetList);
            float pathScore = stepScore + matrix.get(startX, startY).floatValue();
            scoreList.add(Float.valueOf(pathScore));
        }
        return net.loomchild.maligna.util.Util.scoreSum(scoreList);
    }

    /**
     * Computes the backward score of cell {@code (x, y)}.
     *
     * <p>This mirrors {@link #createForwardData}: the step for each category starts at
     * {@code (x, y)} and ends at {@code (x + sourceCount, y + targetCount)}. Its cost is added to
     * the end cell's backward score. Candidates whose end cell does not exist are skipped.
     */
    private float createBackwardData(int x, int y, List<String> sourceSegmentList, List<String> targetSegmentList, Matrix<Float> matrix) {
        List<Float> scoreList = new ArrayList<Float>(this.categoryMap.size());
        for (Map.Entry<Category, Float> entry : this.categoryMap.entrySet()) {
            Category category = entry.getKey();
            float categoryScore = entry.getValue().floatValue();
            int endX = x + category.getSourceSegmentCount();
            int endY = y + category.getTargetSegmentCount();
            if (!Util.elementExists(matrix, endX, endY)) {
                continue;
            }
            List<String> sourceList = sourceSegmentList.subList(x, endX);
            List<String> targetList = targetSegmentList.subList(y, endY);
            float stepScore = categoryScore + this.calculator.calculateScore(sourceList, targetList);
            float pathScore = stepScore + matrix.get(endX, endY).floatValue();
            scoreList.add(Float.valueOf(pathScore));
        }
        return net.loomchild.maligna.util.Util.scoreSum(scoreList);
    }

    /** Returns an independent, mutable copy of {@code list[start, end)}. */
    private List<String> createSubList(List<String> list, int start, int end) {
        return new ArrayList<String>(list.subList(start, end));
    }
}

