/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver;

import calhoun.analysis.crf.CRFInference;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.RecyclingBuffer;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Viterbi decoder for semi-Markov CRF models.
 *
 * <p>Decoding proceeds left to right over the input. At each position:
 * <ul>
 *   <li>states in {@code statesWithoutLookback} are scored with a one-step update from the best
 *       scores at the previous position;</li>
 *   <li>states in {@code statesWithLookback} are scored over every segment length the
 *       {@link CacheProcessor} reports for segments ending at that position.</li>
 * </ul>
 * The best path is then recovered by following back pointers and segment lengths, starting from
 * the state that {@code ColtUtil.maxInColumn} selects in the last position's row of
 * {@code bestScore}.
 *
 * <p>{@link #predict} stores per-call state in instance fields, so an instance is not safe to use
 * from several threads at once.
 */
public class SemiMarkovViterbi implements CRFInference {
    private static final Log log = LogFactory.getLog(SemiMarkovViterbi.class);

    /** Cached at construction; guards the per-cell debug logging inside the inner loops. */
    boolean debug = log.isDebugEnabled();
    /** Feature weights for the current {@link #predict} call. */
    double[] lambda;
    /** Best score of any path ending in a state at a position, indexed {@code pos * nStates + state}. */
    private double[] bestScore;
    /**
     * Predecessor state of the best segment ending in each cell (same indexing as
     * {@link #bestScore}).
     * <ul>
     *   <li>{@code -1}: the segment starts at position 0.</li>
     *   <li>{@code -2}: no valid option was found for the cell.</li>
     * </ul>
     */
    private int[] backPointers;
    int nStates;
    CacheProcessor.SolverSetup modelInfo;
    CacheProcessor cacheProcessor;
    CacheProcessor.FeatureEvaluation[] evals;
    CacheProcessor.LengthFeatureEvaluation[][] lengthEvals;
    /** Invalid-potential flags, indexed {@code pos * modelInfo.nPotentials + potential}. */
    boolean[] invalidTransitions;
    /** Per-state index of that state's self-transition; {@code -1} is treated as "no self-transition". */
    int[] selfTransitions;
    /** Recent stable-state vectors; {@code get(0)} is the vector for the current position. */
    RecyclingBuffer<double[]> stableStates;
    /** Scratch vector filled for the next position before being pushed onto {@link #stableStates}. */
    double[] stableVector;

    /** @return the cache processor that supplies feature evaluations */
    public CacheProcessor getCacheProcessor() {
        return this.cacheProcessor;
    }

    /** @param cacheProcessor the cache processor that supplies feature evaluations */
    public void setCacheProcessor(CacheProcessor cacheProcessor) {
        this.cacheProcessor = cacheProcessor;
    }

    /**
     * Finds the highest-scoring hidden-state sequence for {@code seq}.
     *
     * @param fm model manager handed to the cache processor together with {@code seq}
     * @param seq input sequence to decode; must be non-empty
     * @param lambda feature weights
     * @return the decoded state for every position, plus a copy of the best score of each state at
     *     the last position
     */
    @Override
    public CRFInference.InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
        this.lambda = lambda;
        this.cacheProcessor.setInputData(fm, seq);
        this.modelInfo = this.cacheProcessor.getSolverSetup();
        this.nStates = this.modelInfo.nStates;
        Assert.a(this.modelInfo.maxStateLengths.length == this.nStates, "Maximum state lengths array was length (" + this.modelInfo.maxStateLengths.length + ").  Must have one entry for each state (" + this.modelInfo.nStates + ")");
        this.evals = this.cacheProcessor.getFeatureEvaluations();
        this.lengthEvals = this.cacheProcessor.getLengthFeatureEvaluations();
        this.invalidTransitions = this.cacheProcessor.getInvalidTransitions();
        int len = seq.length();
        // The traceback and the final-score copy both index the last position, which an empty input lacks.
        Assert.a(len > 0, "Cannot decode an empty sequence");
        this.selfTransitions = new int[this.nStates];
        for (int s = 0; s < this.nStates; ++s) {
            this.selfTransitions[s] = this.modelInfo.selfTransitions[s];
        }
        // Ring buffers backed by maxLookback preallocated arrays; get(k) is the entry added k positions ago.
        RecyclingBuffer<double[]> mis = new RecyclingBuffer<double[]>(new double[this.modelInfo.maxLookback][this.modelInfo.nTransitions]);
        double[] nextMi = new double[this.modelInfo.nTransitions];
        this.stableStates = new RecyclingBuffer<double[]>(new double[this.modelInfo.maxLookback][this.nStates]);
        this.stableVector = new double[this.nStates];
        this.bestScore = new double[len * this.nStates];
        this.backPointers = new int[len * this.nStates];
        // Number of positions covered by the best segment ending in each cell.
        int[] backLengths = new int[len * this.nStates];
        for (int pos = 0; pos < len; ++pos) {
            if (pos == 0) {
                // No incoming transitions at position 0: node potentials seed the stable vector.
                this.computeSparseMi(seq, pos, null, this.stableVector);
                this.stableVector = this.stableStates.addFirst(this.stableVector);
            } else {
                this.computeSparseMi(seq, pos, nextMi, null);
                this.updateStableBuffer(nextMi);
                // addFirst hands back an array to reuse as scratch (presumably the evicted oldest entry).
                nextMi = mis.addFirst(nextMi);
            }
            double[] latestStable = this.stableStates.get(0);
            double[] latestMi = mis.get(0);
            this.scoreStatesWithoutLookback(pos, latestStable, latestMi, backLengths);
            this.scoreStatesWithLookback(pos, mis, latestStable, backLengths);
        }
        int[] hidden = this.traceBack(len, backLengths);
        CRFInference.InferenceResult inferenceResult = new CRFInference.InferenceResult();
        inferenceResult.hiddenStates = hidden;
        inferenceResult.bestScores = new double[this.nStates];
        System.arraycopy(this.bestScore, this.nStates * (len - 1), inferenceResult.bestScores, 0, this.nStates);
        return inferenceResult;
    }

    /**
     * Scores each state in {@code modelInfo.statesWithoutLookback} at {@code pos}.
     *
     * <p>Each such state covers exactly one position per segment. At position 0 its score is the
     * stable (node) value. Afterwards its score is the best previous-position score plus the
     * transition potential.
     */
    private void scoreStatesWithoutLookback(int pos, double[] latestStable, double[] latestMi, int[] backLengths) {
        int invalidIndex = pos * this.modelInfo.nPotentials;
        for (CacheProcessor.StatePotentials potentials : this.modelInfo.statesWithoutLookback) {
            byte state = potentials.state;
            double max = Double.NEGATIVE_INFINITY;
            int bestPrevState = -2;
            if (!this.invalidTransitions[invalidIndex + state]) {
                if (pos == 0) {
                    if (this.debug) {
                        log.debug(String.format("Pos: %d State: %d %.2f", pos, (int) state, latestStable[state]));
                    }
                    max = latestStable[state];
                    bestPrevState = -1;
                } else {
                    for (byte edgePotential : potentials.potentials) {
                        // Skip transitions flagged invalid at this position (the check previously re-tested the state).
                        if (this.invalidTransitions[invalidIndex + edgePotential]) {
                            continue;
                        }
                        int transition = edgePotential - this.nStates;
                        int prevState = this.modelInfo.transitionFrom[transition];
                        double transitionCost = latestMi[transition];
                        if (Double.isInfinite(transitionCost)) {
                            continue;
                        }
                        double previous = this.bestScore[this.nStates * (pos - 1) + prevState];
                        double current = previous + transitionCost;
                        if (this.debug) {
                            log.debug(String.format("Pos: %d Trans: %d-%d %.2f (Prev: %.2f + Trans: %.2f)", pos, prevState, (int) state, current, previous, transitionCost));
                        }
                        if (current > max) {
                            max = current;
                            bestPrevState = prevState;
                        }
                    }
                }
            }
            int cell = pos * this.nStates + state;
            this.bestScore[cell] = max;
            this.backPointers[cell] = bestPrevState;
            backLengths[cell] = 1;
        }
    }

    /**
     * Scores each state in {@code modelInfo.statesWithLookback} at {@code pos}.
     *
     * <p>For every segment length the cache processor reports, the node potential is added to
     * either the stable value (segments starting at 0) or the best score before the segment plus
     * edge features, stable difference and transition potential.
     */
    private void scoreStatesWithLookback(int pos, RecyclingBuffer<double[]> mis, double[] latestStable, int[] backLengths) {
        for (int i = 0; i < this.modelInfo.statesWithLookback.length; ++i) {
            CacheProcessor.StatePotentials potentials = this.modelInfo.statesWithLookback[i];
            CacheProcessor.LengthFeatureEvaluation[] lookbacksForState = this.lengthEvals[i];
            byte state = potentials.state;
            double max = Double.NEGATIVE_INFINITY;
            int bestLookback = -1;
            int bestPrevState = -2;
            // NOTE(review): arguments do not depend on i, so this could likely be hoisted out of the
            // loop; left in place until evaluateSegmentsEndingAt is confirmed idempotent.
            this.cacheProcessor.evaluateSegmentsEndingAt(0, pos);
            int lbIndex = 0;
            CacheProcessor.LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
            short lookback = lengthEval.lookback;
            // The lookback list is terminated by an entry whose lookback is -1.
            while (lookback != -1) {
                double[] lookbackMi = mis.get(lookback);
                double[] lookbackStable = this.stableStates.get(lookback);
                CacheProcessor.FeatureEvaluation nodeEvals = lengthEval.nodeEval;
                short[] nodeIndices = nodeEvals.index;
                float[] nodeVals = nodeEvals.value;
                int nx = 0;
                short nodeIndex = nodeIndices[nx];
                double nodePotential = 0.0;
                while (nodeIndex >= 0) {
                    nodePotential += (double) nodeVals[nx] * this.lambda[nodeIndex];
                    nodeIndex = nodeIndices[++nx];
                }
                Assert.a(nodeIndex != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid");
                int prevPos = pos - lookback - 1;
                if (prevPos < 0) {
                    // Segment starts at position 0, so there is no predecessor.
                    double current = latestStable[state] + nodePotential;
                    if (this.debug) {
                        log.debug(String.format("Pos: %d Lb: %d State: %d %.2f (Stable: %.2f + Node: %.2f)", pos, (int) lookback, (int) state, current, latestStable[state], nodePotential));
                    }
                    if (current > max) {
                        max = current;
                        bestLookback = lookback;
                        bestPrevState = -1;
                    }
                } else {
                    CacheProcessor.FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
                    int nEdges = potentials.potentials.length;
                    for (int edgeIx = 0; edgeIx < nEdges; ++edgeIx) {
                        byte potential = potentials.potentials[edgeIx];
                        int trans = potential - this.modelInfo.nStates;
                        int fromNode = this.modelInfo.transitionFrom[trans];
                        // A segment is only entered from a different state.
                        if (fromNode == state) {
                            continue;
                        }
                        double edgeVal = 0.0;
                        if (edgeEvals == null) {
                            if (this.invalidTransitions[(prevPos + 1) * this.modelInfo.nPotentials + potential]) {
                                continue;
                            }
                        } else {
                            CacheProcessor.FeatureEvaluation potEvals = edgeEvals[edgeIx];
                            short[] edgeIndices = potEvals.index;
                            float[] edgeVals = potEvals.value;
                            int ex = 0;
                            // Start at the first feature (previously indexed with the outer state counter).
                            short edgeIndex = edgeIndices[ex];
                            if (edgeIndex == Short.MIN_VALUE) {
                                log.info("SHORT.MIN_VALUE");
                                continue;
                            }
                            while (edgeIndex != -1) {
                                edgeVal += (double) edgeVals[ex] * this.lambda[edgeIndex];
                                edgeIndex = edgeIndices[++ex];
                            }
                        }
                        double prevBest = this.bestScore[this.nStates * prevPos + fromNode];
                        double stable = latestStable[state] - lookbackStable[state];
                        double current = prevBest + nodePotential + edgeVal + stable + lookbackMi[trans];
                        if (this.debug) {
                            log.debug(String.format("Pos: %d Lb: %d Trans: %d-%d %.4f (Prev: %.4f + Stable: %.4f + Trans: %.4f + Node: %.4f + Edge: %.4f)", pos, (int) lookback, fromNode, (int) state, current, prevBest, stable, lookbackMi[trans], nodePotential, edgeVal));
                        }
                        if (current != Double.NEGATIVE_INFINITY && current > max) {
                            max = current;
                            bestLookback = lookback;
                            bestPrevState = fromNode;
                        }
                    }
                }
                lengthEval = lookbacksForState[++lbIndex];
                lookback = lengthEval.lookback;
            }
            int cell = pos * this.nStates + state;
            this.bestScore[cell] = max;
            this.backPointers[cell] = bestPrevState;
            backLengths[cell] = bestLookback + 1;
        }
    }

    /**
     * Recovers the best state sequence by walking back pointers from the last position.
     *
     * @param len sequence length (positive)
     * @param backLengths segment length recorded for each cell
     * @return the decoded state at each position
     */
    private int[] traceBack(int len, int[] backLengths) {
        int[] ret = new int[len];
        int pos = len - 1;
        int state = ColtUtil.maxInColumn(this.bestScore, this.nStates, len - 1);
        Assert.a(state != -2, "No valid paths");
        while (pos >= 0) {
            // -1/-2 back pointers or empty segments here mean no valid path reaches this cell;
            // continuing would read a neighbouring cell or loop forever.
            Assert.a(state >= 0, "No valid paths");
            int cell = pos * this.nStates + state;
            int segmentLength = backLengths[cell];
            Assert.a(segmentLength > 0 && segmentLength <= pos + 1, "Invalid segment length in traceback");
            int prevState = this.backPointers[cell];
            for (int k = 0; k < segmentLength; ++k) {
                ret[pos] = state;
                --pos;
            }
            state = prevState;
        }
        Assert.a(pos == -1);
        return ret;
    }

    /**
     * Pushes the stable-state vector for the current position onto {@link #stableStates}.
     *
     * <p>A state is updated only if it may span more than one position and has a self-transition.
     * For such a state, the value is the previous vector's value plus the self-transition
     * potential from {@code nextMi}. An infinite potential leaves the previous value unchanged.
     *
     * @param nextMi transition potentials for the current position, indexed by transition
     */
    void updateStableBuffer(double[] nextMi) {
        double[] previous = this.stableStates.get(0);
        for (int s = 0; s < this.nStates; ++s) {
            if (this.modelInfo.maxStateLengths[s] <= 1) {
                continue;
            }
            int selfTrans = this.selfTransitions[s];
            if (selfTrans == -1) {
                continue;
            }
            double step = nextMi[selfTrans];
            this.stableVector[s] = Double.isInfinite(step) ? previous[s] : previous[s] + step;
        }
        // NOTE(review): skipped states are not written, so they keep whatever the array returned by
        // the previous addFirst held — confirm those entries are never read.
        this.stableVector = this.stableStates.addFirst(this.stableVector);
    }

    /**
     * Evaluates every potential at {@code pos} and writes node and/or transition scores.
     *
     * <p>A potential's score is the sum of {@code value * lambda[index]} over its features, read up
     * to the {@code -1} terminator. The score is {@code -Infinity} in two cases:
     * <ul>
     *   <li>the potential is flagged invalid at this position;</li>
     *   <li>any of its feature indices is {@code Short.MIN_VALUE}.</li>
     * </ul>
     * A transition's score also adds the most recently visited node score. This assumes
     * {@code orderedPotentials} lists each node before the transitions that should include it —
     * verify in CacheProcessor.
     *
     * @param seq not read by this method
     * @param pos position to evaluate
     * @param mi output transition scores indexed by {@code potential - nStates}, or null to skip
     * @param ri output node scores indexed by state, or null to skip
     */
    void computeSparseMi(InputSequence<?> seq, int pos, double[] mi, double[] ri) {
        this.cacheProcessor.evaluatePosition(0, pos);
        double nodeVal = Double.NaN;
        int invalidIndex = pos * this.modelInfo.nPotentials;
        for (short potential : this.modelInfo.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + potential];
            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
            CacheProcessor.FeatureEvaluation potEvals = this.evals[potential];
            short[] indices = potEvals.index;
            float[] vals = potEvals.value;
            int f = 0;
            short index = indices[f];
            while (index != -1) {
                features += index == Short.MIN_VALUE ? Double.NEGATIVE_INFINITY : (double) vals[f] * this.lambda[index];
                index = indices[++f];
            }
            if (potential < this.modelInfo.nStates) {
                nodeVal = features;
                if (ri != null) {
                    ri[potential] = nodeVal;
                }
            } else if (mi != null) {
                mi[potential - this.modelInfo.nStates] = features + nodeVal;
            }
        }
    }
}

