/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.semimarkov;

import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.LogFiles;
import calhoun.analysis.crf.solver.LookbackBuffer;
import calhoun.analysis.crf.solver.semimarkov.CleanMaximumLikelihoodSemiMarkovGradient;
import calhoun.util.Assert;
import calhoun.util.FileUtil;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Forward-pass (alpha) computation for a semi-Markov CRF, run on behalf of a
 * {@link CleanMaximumLikelihoodSemiMarkovGradient}.
 *
 * <p>Each position's alpha vector is built in two stages. First, states whose maximum length
 * is 1 are propagated from the previous position ({@link #regularAlphaUpdate}). Then
 * explicit-length segments ending at the position are added ({@link #lengthAlpha}).
 *
 * <p>Alpha vectors are stored in scaled form. The true value of {@code parent.alphas[pos][s]}
 * is that entry multiplied by {@code exp(NORM_LOG_STEP * parent.alphaNorms[pos])}.
 * {@link #computeAlpha} resets every norm to {@code Integer.MIN_VALUE}. A predecessor position
 * whose norm is still {@code Integer.MIN_VALUE} is skipped when length segments are added.
 *
 * <p>Per-sequence scratch state ({@link #pos}, {@link #alpha}, {@link #alphaNorm}, ...) lives
 * in mutable instance fields with no synchronization, so an instance must not be shared
 * between threads while a computation is running.
 */
final class AlphaLengthFeatureProcessor {
    // NOTE(review): the logger is keyed on the parent class, presumably so both classes share
    // one logging configuration.
    static final Log log = LogFactory.getLog(CleanMaximumLikelihoodSemiMarkovGradient.class);
    static final boolean debug = log.isDebugEnabled();

    /**
     * Width, in natural-log units, of one step of an alpha norm: a norm of {@code k} scales the
     * stored values by {@code exp(k * NORM_LOG_STEP)}.
     */
    private static final int NORM_LOG_STEP = 50;

    private final CleanMaximumLikelihoodSemiMarkovGradient parent;
    /** Offset of the current sequence, used when indexing {@code parent.invalidTransitions}. */
    int seqOffset;
    /** Position currently being computed. */
    int pos;
    /** Alpha vector for {@link #pos}; an alias of {@code parent.alphas[pos]}. */
    double[] alpha;
    /** Norm of {@link #alpha} (see the class documentation for the scaling convention). */
    int alphaNorm;
    /** Stable-state vector of the most recently processed position. */
    double[] stableState;
    CacheProcessor.SolverSetup modelInfo;
    final LogFiles logs;

    /**
     * Creates a processor that reads and writes the alpha tables of {@code parent}.
     *
     * @param parent the gradient object owning the alpha tables, caches and model setup
     */
    public AlphaLengthFeatureProcessor(CleanMaximumLikelihoodSemiMarkovGradient parent) {
        this.parent = parent;
        this.modelInfo = parent.modelInfo;
        this.logs = parent.logs;
    }

    /**
     * Fills {@code parent.alphas[0..len-1]} and {@code parent.alphaNorms[0..len-1]} for one
     * sequence.
     *
     * <p>Before the loop it resets {@code parent.alphaNorms} to {@code Integer.MIN_VALUE} and
     * {@code parent.starterAlpha} to 0. It also pushes one {@code parent.nextBuffer} per
     * position into {@code parent.lookbackBuffer}.
     *
     * @param seqNum index of the sequence in {@code modelInfo.seqOffsets}
     * @param len number of positions in the sequence
     */
    final void computeAlpha(int seqNum, int len) {
        Arrays.fill(this.parent.alphaNorms, Integer.MIN_VALUE);
        Arrays.fill(this.parent.starterAlpha, 0.0);
        double[] prevAlpha = null;
        this.seqOffset = this.modelInfo.seqOffsets[seqNum];
        for (this.pos = 0; this.pos < len; ++this.pos) {
            prevAlpha = this.alpha;
            this.alpha = this.parent.alphas[this.pos];
            this.alphaNorm = this.parent.alphaNorms[this.pos];
            Arrays.fill(this.alpha, 0.0);
            boolean alphaUpdated = false;
            if (this.pos == 0) {
                this.alphaNorm = 0;
                this.calcStartAlpha(this.alpha, seqNum);
                Arrays.fill(this.parent.nextBuffer.stableState, 0.0);
            } else {
                this.parent.cacheMi(seqNum, this.parent.nextBuffer.mi, this.stableState,
                        this.parent.nextBuffer.stableState, this.pos);
                alphaUpdated = this.regularAlphaUpdate(this.pos, this.parent.nextBuffer.mi,
                        prevAlpha, this.alpha);
            }
            this.stableState = this.parent.nextBuffer.stableState;
            // addFirst hands back a buffer to be filled for the next position.
            this.parent.nextBuffer = this.parent.lookbackBuffer.addFirst(this.parent.nextBuffer);
            if (alphaUpdated && this.pos != 0) {
                // The regular update multiplies entries of the previous alpha vector, so the
                // new values are expressed at the previous position's norm.
                this.alphaNorm = this.parent.alphaNorms[this.pos - 1];
            }
            int norm = CleanMaximumLikelihoodSemiMarkovGradient.normalize(this.alpha);
            // Debug-only overflow check on the int norm accumulator.
            if (debug && norm != 0
                    && (this.alphaNorm + norm > this.alphaNorm && norm < 0
                        || this.alphaNorm + norm < this.alphaNorm && norm > 0)) {
                Assert.a(false, "Wraparound, pos=" + this.pos + ", norm=" + norm);
            }
            this.alphaNorm += norm;
            this.lengthAlpha(seqNum, this.pos);
            this.parent.alphaNorms[this.pos] = this.alphaNorm;
        }
    }

    /**
     * Propagates states of maximum length 1 from {@code lastAlpha} into {@code newAlpha}.
     *
     * <p>{@code modelInfo.orderedPotentials} is read as a run of entries. Each node potential
     * ({@code n < nStates}) is followed by the edge potentials ({@code n >= nStates}) whose
     * contributions {@code lastAlpha[from] * exp(mi[trans])} are summed into that node. Nodes
     * whose maximum length exceeds 1 ignore their edges here and get a value of 0. Edges with an
     * infinite {@code mi} entry contribute nothing.
     *
     * @param pos position being computed (used only for logging)
     * @param mi per-transition log potentials for {@code pos}
     * @param lastAlpha alpha vector of {@code pos - 1}
     * @param newAlpha alpha vector of {@code pos}; expected to be zero-filled on entry
     * @return {@code true} if any length-1 node saw at least one edge potential, or if any node
     *     value differs from the value already in {@code newAlpha} by more than 1e-16
     */
    private final boolean regularAlphaUpdate(int pos, double[] mi, double[] lastAlpha, double[] newAlpha) {
        double nodeVal = 0.0;
        int lastState = -1;
        boolean lengthNode = false;
        boolean ret = false;
        for (int n : this.modelInfo.orderedPotentials) {
            if (n < this.modelInfo.nStates) {
                // A new node starts: flush the total accumulated for the previous node.
                if (lastState != -1) {
                    if (Math.abs(nodeVal - newAlpha[lastState]) > 1.0E-16) {
                        ret = true;
                    }
                    newAlpha[lastState] = nodeVal;
                }
                lastState = n;
                nodeVal = 0.0;
                lengthNode = this.modelInfo.maxStateLengths[n] > 1;
                continue;
            }
            if (lengthNode) {
                continue;
            }
            ret = true;
            int trans = n - this.modelInfo.nStates;
            double transVal = mi[trans];
            if (Double.isInfinite(transVal)) {
                continue;
            }
            short from = this.modelInfo.transitionFrom[trans];
            if (this.logs.alphaWriter != null) {
                FileUtil.safeWrite(this.logs.alphaWriter, String.format(
                        "alpha[%d][%d] = %s = %s + alpha[%d][%d] %s * %s exp(%f)\n",
                        pos, lastState,
                        CleanMaximumLikelihoodSemiMarkovGradient.printNorm(nodeVal + lastAlpha[from] * CleanMaximumLikelihoodSemiMarkovGradient.exp(mi[trans]), this.alphaNorm),
                        CleanMaximumLikelihoodSemiMarkovGradient.printNorm(nodeVal, this.alphaNorm),
                        pos - 1, (int) from,
                        CleanMaximumLikelihoodSemiMarkovGradient.printNorm(lastAlpha[from], this.alphaNorm),
                        CleanMaximumLikelihoodSemiMarkovGradient.printNorm(CleanMaximumLikelihoodSemiMarkovGradient.exp(mi[trans]), 0),
                        mi[trans]));
            }
            nodeVal += lastAlpha[from] * CleanMaximumLikelihoodSemiMarkovGradient.exp(transVal);
        }
        // Flush the last node. The guard avoids indexing newAlpha[-1] when no node potential
        // was seen, e.g. when orderedPotentials is empty.
        if (lastState != -1) {
            if (Math.abs(nodeVal - newAlpha[lastState]) > 1.0E-16) {
                ret = true;
            }
            newAlpha[lastState] = nodeVal;
        }
        return ret;
    }

    /**
     * Adds the contributions of explicit-length segments that end at {@code pos} to
     * {@link #alpha}, adjusting {@link #alphaNorm} as needed.
     *
     * <p>For each semi-Markov state, the cached lookbacks are walked until an entry with
     * {@code lookback == -1}. The segment score has two parts. The first is the difference
     * between the current stable-state value and the stable-state value stored in the lookback
     * buffer for the segment start. The second is the weighted length features. A segment that
     * starts at position 0 ({@code prevPos < 0}) adds {@code parent.starterAlpha}. Otherwise
     * each incoming transition from a different state at {@code prevPos} adds its edge score,
     * weighted by that predecessor's alpha value.
     *
     * @param seqNum sequence being processed
     * @param pos position at which the segments end
     */
    private final void lengthAlpha(int seqNum, int pos) {
        this.parent.cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);
        for (int i = 0; i < this.parent.nSemiMarkovStates; ++i) {
            CacheProcessor.LengthFeatureEvaluation[] lookbacksForState = this.parent.lengthEvals[i];
            CacheProcessor.StatePotentials statePotentials = this.modelInfo.statesWithLookback[i];
            byte toNode = statePotentials.state;
            int lbIndex = 0;
            CacheProcessor.LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
            short lookback = lengthEval.lookback;
            while (lookback != -1) {
                // Position just before the segment start.
                int prevPos = pos - lookback - 1;
                // The lookback buffer is indexed circularly, starting at currentStart.
                LookbackBuffer[] ring = (LookbackBuffer[]) this.parent.lookbackBuffer.array;
                LookbackBuffer buffer = ring[(this.parent.lookbackBuffer.currentStart + lookback) % this.parent.lookbackBuffer.length];
                CacheProcessor.FeatureEvaluation nodeEvals = lengthEval.nodeEval;
                short[] indices = nodeEvals.index;
                float[] vals = nodeEvals.value;
                int ix = 0;
                short index = indices[ix];
                double stableValue = this.stableState[toNode] - buffer.stableState[toNode];
                double nodePotential = stableValue;
                // Feature lists end at the first negative index.
                while (index >= 0) {
                    nodePotential += (double) vals[ix] * this.parent.lambda[index];
                    index = indices[++ix];
                }
                if (debug) {
                    Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid");
                }
                if (prevPos < 0) {
                    // Segment starts at the first position of the sequence.
                    double nodeVal = nodePotential + this.parent.starterAlpha[toNode];
                    int norm = (int) nodeVal / NORM_LOG_STEP;
                    nodeVal -= (double) (norm * NORM_LOG_STEP);
                    if (norm > this.alphaNorm) {
                        CleanMaximumLikelihoodSemiMarkovGradient.renormalize(this.alpha, this.alphaNorm, norm);
                        this.alphaNorm = norm;
                    } else if (norm < this.alphaNorm) {
                        // Re-express the log value at the (larger) current norm.
                        nodeVal += (double) (NORM_LOG_STEP * (norm - this.alphaNorm));
                    }
                    if (this.logs.alphaLengthWriter != null) {
                        FileUtil.safeWrite(this.logs.alphaLengthWriter, String.format(
                                "seq: %d alpha[%d][%d] = %s = %s + %s (Pot: %f Starter: %f)\n",
                                seqNum, pos, toNode,
                                CleanMaximumLikelihoodSemiMarkovGradient.printNorm(this.alpha[toNode] + CleanMaximumLikelihoodSemiMarkovGradient.exp(nodeVal), this.alphaNorm),
                                CleanMaximumLikelihoodSemiMarkovGradient.printNorm(this.alpha[toNode], this.alphaNorm),
                                CleanMaximumLikelihoodSemiMarkovGradient.printNorm(CleanMaximumLikelihoodSemiMarkovGradient.exp(nodeVal), this.alphaNorm),
                                nodePotential, this.parent.starterAlpha[toNode]));
                    }
                    this.alpha[toNode] += CleanMaximumLikelihoodSemiMarkovGradient.exp(nodeVal);
                } else {
                    CacheProcessor.FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
                    int nEdges = statePotentials.potentials.length;
                    for (int edgeIx = 0; edgeIx < nEdges; ++edgeIx) {
                        byte potential = statePotentials.potentials[edgeIx];
                        int trans = potential - this.modelInfo.nStates;
                        short fromNode = this.modelInfo.transitionFrom[trans];
                        // Self-transitions are skipped; presumably a segment may not be followed
                        // directly by another segment of the same state.
                        if (fromNode == toNode) {
                            continue;
                        }
                        double edgeVal = 0.0;
                        if (edgeEvals == null) {
                            // No cached edge features: only the invalid-transition mask applies.
                            int invalidIndex = (this.seqOffset + prevPos + 1) * this.modelInfo.nPotentials;
                            if (this.parent.invalidTransitions[invalidIndex + potential]) {
                                continue;
                            }
                        } else {
                            CacheProcessor.FeatureEvaluation potEvals = edgeEvals[edgeIx];
                            short[] edgeIndices = potEvals.index;
                            float[] edgeVals = potEvals.value;
                            int ex = 0;
                            // Start at this edge list's cursor. The original read
                            // edgeIndices[i] (the state loop index), which misaligned weights
                            // and values for every state after the first.
                            short edgeIndex = edgeIndices[ex];
                            while (edgeIndex >= 0) {
                                edgeVal += (double) edgeVals[ex] * this.parent.lambda[edgeIndex];
                                edgeIndex = edgeIndices[++ex];
                            }
                            // A Short.MIN_VALUE terminator marks the transition as invalid.
                            if (edgeIndex == Short.MIN_VALUE) {
                                continue;
                            }
                        }
                        int prevNorm = this.parent.alphaNorms[prevPos];
                        if (prevNorm == Integer.MIN_VALUE) {
                            continue;
                        }
                        double prevAlpha = this.parent.alphas[prevPos][fromNode];
                        double expVal = edgeVal + buffer.mi[trans] + nodePotential;
                        int expNorm = (int) expVal / NORM_LOG_STEP;
                        expVal -= (double) (expNorm * NORM_LOG_STEP);
                        int updateNorm = expNorm + prevNorm;
                        double update = CleanMaximumLikelihoodSemiMarkovGradient.exp(expVal) * prevAlpha;
                        if (update == 0.0) {
                            continue;
                        }
                        // Keep the mantissa of the update within [NORM_MIN, NORM_MAX].
                        if (update < CleanMaximumLikelihoodSemiMarkovGradient.NORM_MIN) {
                            --updateNorm;
                            update *= CleanMaximumLikelihoodSemiMarkovGradient.NORM_MAX;
                        } else if (update > CleanMaximumLikelihoodSemiMarkovGradient.NORM_MAX) {
                            ++updateNorm;
                            update *= CleanMaximumLikelihoodSemiMarkovGradient.NORM_MIN;
                        }
                        // Bring the update and the alpha vector to a common norm.
                        if (updateNorm > this.alphaNorm) {
                            CleanMaximumLikelihoodSemiMarkovGradient.renormalize(this.alpha, this.alphaNorm, updateNorm);
                            this.alphaNorm = updateNorm;
                        } else if (this.alphaNorm > updateNorm) {
                            int expShift = updateNorm - this.alphaNorm;
                            update *= CleanMaximumLikelihoodSemiMarkovGradient.exp(expShift * NORM_LOG_STEP);
                        }
                        if (this.logs.alphaLengthWriter != null) {
                            FileUtil.safeWrite(this.logs.alphaLengthWriter, String.format(
                                    "seq: %d alpha[%d][%d] %s = %s (%g, %d) + %s (%g, %d) alpha[%d][%d] * %s (%g, %d) exp(EdgeLength: %f NodeLength: %f Edge: %f Node: %f )\n",
                                    seqNum, pos, toNode,
                                    CleanMaximumLikelihoodSemiMarkovGradient.printNorm(this.alpha[toNode] + update, this.alphaNorm),
                                    CleanMaximumLikelihoodSemiMarkovGradient.printNorm(this.alpha[toNode], this.alphaNorm), this.alpha[toNode], this.alphaNorm,
                                    CleanMaximumLikelihoodSemiMarkovGradient.printNorm(prevAlpha, this.parent.alphaNorms[prevPos]), prevAlpha, this.parent.alphaNorms[prevPos],
                                    prevPos, this.modelInfo.transitionFrom[trans],
                                    CleanMaximumLikelihoodSemiMarkovGradient.printNorm(CleanMaximumLikelihoodSemiMarkovGradient.exp(expVal), expNorm), CleanMaximumLikelihoodSemiMarkovGradient.exp(expVal), expNorm,
                                    edgeVal, nodePotential - stableValue, buffer.mi[trans], stableValue));
                        }
                        this.alpha[toNode] += update;
                    }
                }
                lengthEval = lookbacksForState[++lbIndex];
                lookback = lengthEval.lookback;
            }
        }
    }

    /**
     * Initialises the alpha vector for position 0 of sequence {@code seq}.
     *
     * <p>For every node potential it sums the weighted features evaluated at position 0. The
     * score is negative infinity if the position-0 {@code invalidTransitions} entry is set, or
     * if the feature list ends in {@code Short.MIN_VALUE}. States whose maximum length exceeds 1
     * store this log-space score in {@code parent.starterAlpha}, which {@link #lengthAlpha}
     * consumes later. All other states store {@code exp(score)} directly in
     * {@code currentAlpha}.
     *
     * @param currentAlpha alpha vector for position 0 (norm 0)
     * @param seq sequence being processed
     */
    void calcStartAlpha(double[] currentAlpha, int seq) {
        this.parent.cacheProcessor.evaluatePosition(seq, 0);
        int invalidIndex = this.seqOffset * this.modelInfo.nPotentials;
        for (short potential : this.modelInfo.orderedPotentials) {
            if (potential >= this.modelInfo.nStates) {
                continue; // edge potentials play no role at the first position
            }
            boolean invalid = this.parent.invalidTransitions[invalidIndex + potential];
            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
            CacheProcessor.FeatureEvaluation potEvals = this.parent.evals[potential];
            short[] indices = potEvals.index;
            float[] vals = potEvals.value;
            int i = 0;
            short index = indices[i];
            while (index != -1) {
                if (index == Short.MIN_VALUE) {
                    features = Double.NEGATIVE_INFINITY;
                    break;
                }
                features += (double) vals[i] * this.parent.lambda[index];
                index = indices[++i];
            }
            if (this.modelInfo.maxStateLengths[potential] > 1) {
                this.parent.starterAlpha[potential] = features;
                continue;
            }
            currentAlpha[potential] = CleanMaximumLikelihoodSemiMarkovGradient.exp(features);
        }
    }
}

