/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.semimarkov;

import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.LogFiles;
import calhoun.analysis.crf.solver.LookbackBuffer;
import calhoun.analysis.crf.solver.semimarkov.CleanMaximumLikelihoodSemiMarkovGradient;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import calhoun.util.FileUtil;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

class BetaLengthFeatureProcessor {
    /** Logger. Note that it is registered under the parent gradient class, not under this class. */
    static final Log log = LogFactory.getLog(CleanMaximumLikelihoodSemiMarkovGradient.class);
    /** Debug flag, read from the logger once at class initialisation; guards the extra assertions in this class. */
    static final boolean debug = log.isDebugEnabled();
    /** Owning gradient object; this class reads its alphas, lambda, lookback buffer and cache processor and writes its expects array. */
    final CleanMaximumLikelihoodSemiMarkovGradient parent;
    /** Model description, copied from {@code parent.modelInfo} in the constructor. */
    final CacheProcessor.SolverSetup modelInfo;
    /** Optional diagnostic writers, copied from {@code parent.logs}; each writer is null-checked before use. */
    final LogFiles logs;
    /** {@code modelInfo.seqOffsets[seqNum]} of the sequence being processed; used to index {@code parent.invalidTransitions}. */
    int seqOffset;
    /** Probability most recently computed by {@code lengthBetaHandling}; read back by {@code lengthBeta}. */
    double prob;
    /** Per-state node marginals for the current position (length {@code modelInfo.nStates}); zeroed at the start of each sequence. */
    double[] nodeProb;
    /** True once {@link #setGlobalArrays} has been called; enables copying per-position results into the arrays below. */
    boolean globalArrays = false;
    /** Caller-supplied output, written as {@code betas[state][pos]}. */
    double[][] betas;
    /** Caller-supplied output, written as {@code betaNorms[pos]}. */
    int[] betaNorms;
    /** Caller-supplied output, written as {@code allNodeProb[state][pos]}. */
    double[][] allNodeProb;
    /** Caller-supplied output, written as {@code allEdgeProb[transition][pos]}. */
    double[][] allEdgeProb;
    /** Not read or written anywhere in this class. NOTE(review): possibly vestigial - confirm before removing. */
    double[][] cumulativeStableLogProb;
    public BetaLengthFeatureProcessor(CleanMaximumLikelihoodSemiMarkovGradient parent) {
        this.parent = parent;
        this.modelInfo = parent.modelInfo;
        this.logs = parent.logs;
        this.nodeProb = new double[this.modelInfo.nStates];
    }

    /**
     * Registers caller-owned arrays that receive a copy of the per-position results and
     * switches that copying on.
     *
     * @param betas       receives beta values, written as {@code betas[state][pos]}
     * @param betaNorms   receives the beta normalisation exponent, written as {@code betaNorms[pos]}
     * @param allNodeProb receives node marginals, written as {@code allNodeProb[state][pos]}
     * @param allEdgeProb receives edge marginals, written as {@code allEdgeProb[transition][pos]}
     */
    public void setGlobalArrays(double[][] betas, int[] betaNorms, double[][] allNodeProb, double[][] allEdgeProb) {
        this.allEdgeProb = allEdgeProb;
        this.allNodeProb = allNodeProb;
        this.betaNorms = betaNorms;
        this.betas = betas;
        this.globalArrays = true;
    }

    /**
     * Walks one sequence from its last position down to position 0. At each position it updates
     * betas for states without lookback ({@code regularBetaUpdate}) and for states with lookback
     * ({@code lengthBeta}), derives the self-transition marginals, verifies the marginals and adds
     * the feature expectations.
     *
     * <p>Lookback buffers are taken from {@code parent.nextBuffer} and pushed onto
     * {@code parent.lookbackBuffer}. Buffers are initialised ahead of the current position, down
     * to {@code modelInfo.maxLookback + 1} positions before it.
     *
     * @param seqNum index of the sequence; selects {@code modelInfo.seqOffsets[seqNum]}
     * @param len    number of positions in the sequence
     */
    void computeBetasAndExpectations(int seqNum, int len) {
        this.seqOffset = this.modelInfo.seqOffsets[seqNum];
        Arrays.fill(this.nodeProb, 0.0);

        // Seed the buffer for the final position: beta = 1 with norm 0 and an all-zero stable state.
        LookbackBuffer posLookback = this.parent.nextBuffer;
        posLookback.clear();
        Arrays.fill(posLookback.beta, 1.0);
        posLookback.betaNorm = 0;
        Arrays.fill(posLookback.stableState, 0.0);
        posLookback.pos = len - 1;
        double[] lastStableState = posLookback.stableState;
        this.parent.nextBuffer = this.parent.lookbackBuffer.addFirst(posLookback);

        int lastInitPos = len - 2 - this.modelInfo.maxLookback;
        int miPos = len - 2;
        for (int pos = len - 1; pos >= 0; --pos) {
            // Initialise buffers (and cache their transition values) for every position down to lastInitPos.
            while (miPos >= 0 && miPos >= lastInitPos) {
                LookbackBuffer newLookback = this.parent.nextBuffer;
                newLookback.clear();
                newLookback.pos = miPos;
                this.parent.cacheMi(seqNum, newLookback.mi, lastStableState, newLookback.stableState, miPos + 1);
                lastStableState = newLookback.stableState;
                this.parent.nextBuffer = this.parent.lookbackBuffer.addFirst(newLookback);
                --miPos;
            }
            --lastInitPos;
            if (debug) {
                Assert.a(posLookback.pos == pos, "Wrong lookback buffer: was ", posLookback.pos, " should be ", pos);
            }

            // Buffer for position pos - 1; stays null at the first position.
            LookbackBuffer nextLookback = null;
            if (pos != 0) {
                nextLookback = this.parent.lookbackBuffer.get(pos - (miPos + 1) - 1);
                // Report the buffer that was actually checked (the original message printed posLookback.pos).
                Assert.a(nextLookback.pos == pos - 1, "Wrong next lookback buffer: was ", nextLookback.pos, " should be ", pos - 1);
            }

            this.regularBetaUpdate(pos, nextLookback, posLookback.beta, posLookback.betaNorm, posLookback.transitionProb);
            this.lengthBeta(seqNum, pos, miPos, posLookback);

            if (pos > 0) {
                // Self-transition marginal of a lookback state = its node marginal minus the
                // marginals of all of its listed potentials (transitions).
                for (CacheProcessor.StatePotentials lb : this.modelInfo.statesWithLookback) {
                    byte state = lb.state;
                    int selfTransIndex = this.modelInfo.selfTransitions[state];
                    double selfTransProb = this.nodeProb[state];
                    for (byte pot : lb.potentials) {
                        double lbTrans = nextLookback.transitionProb[pot - this.modelInfo.nStates];
                        selfTransProb -= lbTrans;
                    }
                    Assert.a(nextLookback.transitionProb[selfTransIndex] == 0.0);
                    nextLookback.transitionProb[selfTransIndex] = selfTransProb;
                }
            }

            this.verifyMarginals(seqNum, pos, nextLookback);
            this.updateExpectations(seqNum, pos, nextLookback);

            if (this.globalArrays) {
                // The norm is per position, so store it once instead of once per state.
                this.betaNorms[pos] = posLookback.betaNorm;
                for (int i = 0; i < this.modelInfo.nStates; ++i) {
                    this.allNodeProb[i][pos] = this.nodeProb[i];
                    this.betas[i][pos] = posLookback.beta[i];
                }
                for (int i = 0; i < this.modelInfo.nTransitions; ++i) {
                    this.allEdgeProb[i][pos] = posLookback.transitionProb[i];
                }
            }

            if (pos > 0) {
                // Carry the self-transition marginal forward as the starting node marginal
                // of each lookback state for the next (earlier) position.
                for (CacheProcessor.StatePotentials lb : this.modelInfo.statesWithLookback) {
                    int selfTransIndex = this.modelInfo.selfTransitions[lb.state];
                    this.nodeProb[lb.state] = nextLookback.transitionProb[selfTransIndex];
                }
            }
            posLookback = nextLookback;
        }
    }

    /**
     * Beta and marginal update for the states listed in {@code modelInfo.statesWithoutLookback}.
     *
     * <p>For each such state the node marginal at {@code pos} is stored in {@code nodeProb}.
     * When {@code pos > 0}, every finite transition into the state also adds its contribution
     * to {@code nextLookback.beta} and sets the matching entry of
     * {@code nextLookback.transitionProb}; transitions with an infinite cached value get
     * marginal 0. Finally {@code nextLookback.beta} is renormalised.
     *
     * <p>Does nothing at all when the model has no states without lookback.
     *
     * @param pos            position being processed
     * @param nextLookback   buffer for position {@code pos - 1}; only dereferenced when {@code pos > 0}
     * @param oldBeta        beta values at {@code pos}
     * @param oldNorm        normalisation exponent that goes with {@code oldBeta}
     * @param transitionProb unused by this method
     */
    private void regularBetaUpdate(int pos, LookbackBuffer nextLookback, double[] oldBeta, int oldNorm, double[] transitionProb) {
        if (this.modelInfo.statesWithoutLookback.length == 0) {
            return;
        }
        final boolean hasPrevious = pos > 0;
        final double[] alphaHere = this.parent.alphas[pos];
        final double nodeScale = CleanMaximumLikelihoodSemiMarkovGradient.exp((this.parent.alphaNorms[pos] + oldNorm - this.parent.zNorm) * 50) * this.parent.zInv;

        double[] alphaBefore = null;
        double edgeScale = Double.NaN;
        double miShift = 0.0;
        if (hasPrevious) {
            alphaBefore = this.parent.alphas[pos - 1];
            if (oldNorm <= nextLookback.betaNorm) {
                // The target buffer already has the larger exponent: shift each contribution instead.
                miShift = (oldNorm - nextLookback.betaNorm) * 50;
            } else {
                CleanMaximumLikelihoodSemiMarkovGradient.renormalize(nextLookback.beta, nextLookback.betaNorm, oldNorm);
                nextLookback.betaNorm = oldNorm;
            }
            edgeScale = CleanMaximumLikelihoodSemiMarkovGradient.exp((this.parent.alphaNorms[pos - 1] + nextLookback.betaNorm - this.parent.zNorm) * 50) * this.parent.zInv;
        }

        for (CacheProcessor.StatePotentials statePots : this.modelInfo.statesWithoutLookback) {
            final byte state = statePots.state;
            final double betaHere = oldBeta[state];
            final double marginal = alphaHere[state] * betaHere * nodeScale;
            if (this.logs.nodeMarginalWriter != null) {
                FileUtil.safeWrite(this.logs.nodeMarginalWriter, String.format("NodeMarg[%d][%d] = %f = %f * %f * %f (aN: %d bN: %d zN: %d 1/z: %f)\n", pos, state, marginal, alphaHere[state], betaHere, nodeScale, this.parent.alphaNorms[pos], oldNorm, this.parent.zNorm, this.parent.zInv));
            }
            this.nodeProb[state] = marginal;
            if (!hasPrevious) {
                continue;
            }
            for (final byte potential : statePots.potentials) {
                final int trans = potential - this.modelInfo.nStates;
                final double cachedMi = nextLookback.mi[trans];
                if (Double.isInfinite(cachedMi)) {
                    nextLookback.transitionProb[trans] = 0.0;
                    continue;
                }
                final double potentialValue = CleanMaximumLikelihoodSemiMarkovGradient.exp(cachedMi + miShift);
                final int from = this.modelInfo.transitionFrom[trans];
                nextLookback.beta[from] += potentialValue * betaHere;
                nextLookback.transitionProb[trans] = alphaBefore[from] * potentialValue * betaHere * edgeScale;
            }
        }

        if (hasPrevious) {
            try {
                nextLookback.betaNorm += CleanMaximumLikelihoodSemiMarkovGradient.normalize(nextLookback.beta);
            }
            catch (RuntimeException ex) {
                CleanMaximumLikelihoodSemiMarkovGradient.log.warn((Object)("Normalization problem at " + pos + " " + ColtUtil.format(nextLookback.beta)));
                throw ex;
            }
        }
    }

    /**
     * Beta update for the states listed in {@code modelInfo.statesWithLookback}, over every
     * segment that ends at {@code pos}.
     *
     * <p>For each such state with non-zero beta and each lookback length returned by the cache
     * (the list is terminated by a lookback of -1), the segment's node potential is the
     * stable-state difference plus the weighted node features. If the segment starts before
     * position 0, only {@code parent.starterAlpha} is combined with it. Otherwise every
     * non-self, valid transition into the state:
     * <ul>
     *   <li>has its segment probability computed by {@code lengthBetaHandling} (left in {@code this.prob});</li>
     *   <li>adds a beta contribution to the buffer at the segment's previous position, reconciling the
     *       integer normalisation exponents first;</li>
     *   <li>adds {@code prob * value} to {@code parent.expects} for each edge feature;</li>
     *   <li>adds {@code prob} to that buffer's {@code transitionProb}.</li>
     * </ul>
     *
     * <p>Fixes relative to the decompiled original: the edge-feature scans now start at
     * {@code indices[ix]} (they started at {@code indices[i]}, {@code i} being the state-loop
     * counter); the expectation log reads {@code vals[ix]} and indexes the transition tables by
     * transition rather than by edge ordinal; the beta log line is formatted once instead of twice.
     *
     * @param seqNum      index of the sequence being processed
     * @param pos         position at which the segments end
     * @param miPos       position of the most recently initialised lookback buffer minus one; used to
     *                    convert a position into an offset within {@code parent.lookbackBuffer}
     * @param posLookback buffer holding beta, norm and stable state for {@code pos}
     */
    private void lengthBeta(int seqNum, int pos, int miPos, LookbackBuffer posLookback) {
        double[] beta = posLookback.beta;
        int betaNorm = posLookback.betaNorm;
        double[] lengthStable = posLookback.stableState;
        this.parent.cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);
        int nSemiMarkovStates = this.modelInfo.statesWithLookback.length;
        for (int stateIx = 0; stateIx < nSemiMarkovStates; ++stateIx) {
            CacheProcessor.LengthFeatureEvaluation[] lookbacksForState = this.parent.lengthEvals[stateIx];
            CacheProcessor.StatePotentials statePotentials = this.modelInfo.statesWithLookback[stateIx];
            byte toNode = statePotentials.state;
            // A zero beta contributes nothing to any segment ending in this state.
            if (beta[toNode] == 0.0) {
                continue;
            }
            int lbArrayIndex = 0;
            CacheProcessor.LengthFeatureEvaluation lengthEval = lookbacksForState[lbArrayIndex];
            short lookback = lengthEval.lookback;
            while (lookback != -1) {
                // Position just before the segment, and its offset in the circular lookback buffer.
                int prevPos = pos - lookback - 1;
                int lbIndex = prevPos - miPos - 1;
                LookbackBuffer segBegin = null;
                if (prevPos >= 0) {
                    segBegin = ((LookbackBuffer[])this.parent.lookbackBuffer.array)[(this.parent.lookbackBuffer.currentStart + lbIndex) % this.parent.lookbackBuffer.length];
                }
                LookbackBuffer stableBuffer = ((LookbackBuffer[])this.parent.lookbackBuffer.array)[(this.parent.lookbackBuffer.currentStart + lbIndex + 1) % this.parent.lookbackBuffer.length];

                // Node potential = stable-state difference + weighted node features.
                double stableValue = stableBuffer.stableState[toNode] - lengthStable[toNode];
                double nodePotential = stableValue;
                CacheProcessor.FeatureEvaluation nodeEvals = lengthEval.nodeEval;
                short[] nodeIndices = nodeEvals.index;
                float[] nodeVals = nodeEvals.value;
                int nodeIx = 0;
                short nodeIndex = nodeIndices[nodeIx];
                while (nodeIndex >= 0) {
                    nodePotential += (double)nodeVals[nodeIx] * this.parent.lambda[nodeIndex];
                    nodeIndex = nodeIndices[++nodeIx];
                }
                if (debug) {
                    Assert.a(nodeIndex != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid.  They can be invalid because a node is invalid or a self-transition edge is invalid.");
                }

                if (prevPos < 0) {
                    // The segment reaches back past position 0: no incoming transition, use starterAlpha.
                    double expVal = nodePotential + this.parent.starterAlpha[toNode];
                    this.lengthBetaHandling(seqNum, prevPos, pos, expVal, -1, toNode, 1.0, 0, beta[toNode], betaNorm, nodeEvals);
                } else {
                    CacheProcessor.FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
                    int nEdges = statePotentials.potentials.length;
                    for (int edgeIx = 0; edgeIx < nEdges; ++edgeIx) {
                        byte potential = statePotentials.potentials[edgeIx];
                        int trans = potential - this.modelInfo.nStates;
                        short fromNode = this.modelInfo.transitionFrom[trans];
                        // Self-transitions are represented by the segment length itself.
                        if (fromNode == toNode) {
                            continue;
                        }
                        double edgeVal = 0.0;
                        short[] edgeIndices = null;
                        float[] edgeVals = null;
                        if (edgeEvals == null) {
                            // No cached edge features: only skip transitions flagged invalid at the segment start.
                            int invalidIndex = (this.seqOffset + prevPos + 1) * this.modelInfo.nPotentials;
                            if (this.parent.invalidTransitions[invalidIndex + potential]) {
                                continue;
                            }
                        } else {
                            CacheProcessor.FeatureEvaluation potEvals = edgeEvals[edgeIx];
                            edgeIndices = potEvals.index;
                            edgeVals = potEvals.value;
                            int ix = 0;
                            // Start at the cursor, not at the outer state counter.
                            short index = edgeIndices[ix];
                            while (index >= 0) {
                                edgeVal += (double)edgeVals[ix] * this.parent.lambda[index];
                                index = edgeIndices[++ix];
                            }
                            // A Short.MIN_VALUE terminator marks the edge as invalid.
                            if (index == Short.MIN_VALUE) {
                                continue;
                            }
                        }
                        if (debug) {
                            Assert.a(segBegin.pos == pos - lookback - 1, "Expected ", pos - lookback - 1, " was ", segBegin.pos);
                        }

                        double expVal = edgeVal + segBegin.mi[trans] + nodePotential;
                        double prevAlpha = this.parent.alphas[prevPos][fromNode];
                        int prevAlphaNorm = this.parent.alphaNorms[prevPos];
                        double origBeta = segBegin.beta[fromNode];
                        int origBetaNorm = segBegin.betaNorm;
                        // Sets this.prob and updates the node expectations; returns the exponent split off expVal.
                        int expNorm = this.lengthBetaHandling(seqNum, prevPos, pos, expVal, fromNode, toNode, prevAlpha, prevAlphaNorm, beta[toNode], betaNorm, nodeEvals);
                        double transPotential = CleanMaximumLikelihoodSemiMarkovGradient.exp(expVal - (double)(expNorm * 50));

                        // Beta contribution and its exponent, nudged by one norm unit if it left [NORM_MIN, NORM_MAX].
                        double update = transPotential * beta[toNode];
                        int updateNorm = expNorm + betaNorm;
                        if (update < CleanMaximumLikelihoodSemiMarkovGradient.NORM_MIN) {
                            --updateNorm;
                            update *= CleanMaximumLikelihoodSemiMarkovGradient.NORM_MAX;
                        } else if (update > CleanMaximumLikelihoodSemiMarkovGradient.NORM_MAX) {
                            ++updateNorm;
                            update *= CleanMaximumLikelihoodSemiMarkovGradient.NORM_MIN;
                        }
                        // Bring the target buffer and the update onto the same exponent before adding.
                        if (updateNorm > segBegin.betaNorm) {
                            CleanMaximumLikelihoodSemiMarkovGradient.renormalize(segBegin.beta, segBegin.betaNorm, updateNorm);
                            segBegin.betaNorm = updateNorm;
                        } else if (segBegin.betaNorm > updateNorm) {
                            int expShift = updateNorm - segBegin.betaNorm;
                            update *= CleanMaximumLikelihoodSemiMarkovGradient.exp(expShift * 50);
                        }
                        segBegin.beta[fromNode] = segBegin.beta[fromNode] + update;

                        if (edgeEvals != null) {
                            // Accumulate the expectation of every edge feature, weighted by the segment probability.
                            int ix = 0;
                            short index = edgeIndices[ix];
                            while (index != -1) {
                                if (this.logs.expectLengthWriter != null) {
                                    FileUtil.safeWrite(this.logs.expectLengthWriter, String.format("Seq %d Pos %d-%d State: %d-%d Expect #%d: %e = %e + Prob: %e * EdgeVal: %e\n", seqNum, prevPos + 1, pos, this.modelInfo.transitionFrom[trans], this.modelInfo.transitionTo[trans], index, this.parent.expects[index] + this.prob * (double)edgeVals[ix], this.parent.expects[index], this.prob, Float.valueOf(edgeVals[ix])));
                                }
                                this.parent.expects[index] = this.parent.expects[index] + this.prob * (double)edgeVals[ix];
                                index = edgeIndices[++ix];
                            }
                        }
                        segBegin.transitionProb[trans] = segBegin.transitionProb[trans] + this.prob;
                        if (this.logs.betaLengthWriter == null) {
                            continue;
                        }
                        FileUtil.safeWrite(this.logs.betaLengthWriter, String.format("Beta[%d][%d] = %s (%g, %d)= %s (%g, %d) + %s (%g, %d) beta[%d][%d] * %s (%g, %d) exp(Edge: %f Node: %f Stable: %f Trans: %f)\n", prevPos, (int)fromNode, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(segBegin.beta[fromNode], segBegin.betaNorm), segBegin.beta[fromNode], segBegin.betaNorm, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(origBeta, origBetaNorm), origBeta, origBetaNorm, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(beta[toNode], betaNorm), beta[toNode], betaNorm, pos, toNode, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(transPotential, expNorm), transPotential, expNorm, edgeVal, nodePotential - stableValue, stableValue, segBegin.mi[trans]));
                    }
                    segBegin.betaNorm += CleanMaximumLikelihoodSemiMarkovGradient.normalize(segBegin.beta);
                }
                lengthEval = lookbacksForState[++lbArrayIndex];
                lookback = lengthEval.lookback;
            }
        }
    }

    /**
     * Computes the probability of one segment and folds it into the running totals.
     *
     * <p>The exponent {@code expVal} is split into a whole number of 50-unit steps (the returned
     * norm) and a remainder. The probability
     * {@code prevAlpha * betaVal * exp(remainder + 50 * (prevAlphaNorm + norm + betaNorm - zNorm)) * zInv}
     * is stored in {@code this.prob}; it is forced to 0 when either {@code prevAlpha} or
     * {@code betaVal} is exactly 0. A NaN or infinite result is logged and then asserted on.
     * The probability is added to {@code nodeProb[toNode]} and, when non-zero, to
     * {@code parent.expects} for each node feature (weighted by the feature value).
     *
     * @param seqNum        sequence index (diagnostics only)
     * @param prevPos       position just before the segment (diagnostics only)
     * @param pos           position at which the segment ends (diagnostics only)
     * @param expVal        exponent of the segment potential
     * @param fromNode      state before the segment (diagnostics only)
     * @param toNode        state of the segment
     * @param prevAlpha     alpha value that precedes the segment
     * @param prevAlphaNorm normalisation exponent of {@code prevAlpha}
     * @param betaVal       beta value at the end of the segment
     * @param betaNorm      normalisation exponent of {@code betaVal}
     * @param nodeEvals     node feature list, terminated by an index of -1
     * @return the number of 50-unit steps split off {@code expVal}
     */
    int lengthBetaHandling(int seqNum, int prevPos, int pos, double expVal, int fromNode, int toNode, double prevAlpha, int prevAlphaNorm, double betaVal, int betaNorm, CacheProcessor.FeatureEvaluation nodeEvals) {
        final int norm = (int)expVal / 50;
        expVal -= (double)(norm * 50);
        final double afterExp = CleanMaximumLikelihoodSemiMarkovGradient.exp(expVal + (double)(50 * (prevAlphaNorm + norm + betaNorm - this.parent.zNorm)));
        if (prevAlpha == 0.0 || betaVal == 0.0) {
            this.prob = 0.0;
        } else {
            this.prob = prevAlpha * betaVal * afterExp * this.parent.zInv;
        }
        if (Double.isNaN(this.prob) || Double.isInfinite(this.prob)) {
            CleanMaximumLikelihoodSemiMarkovGradient.log.info((Object)String.format("NaN = Alpha: %s * Beta: %s * Seg: %s / Z: %s", CleanMaximumLikelihoodSemiMarkovGradient.printNorm(prevAlpha, prevAlphaNorm), CleanMaximumLikelihoodSemiMarkovGradient.printNorm(betaVal, betaNorm), CleanMaximumLikelihoodSemiMarkovGradient.printNorm(CleanMaximumLikelihoodSemiMarkovGradient.exp(expVal), norm), CleanMaximumLikelihoodSemiMarkovGradient.printNorm(1.0 / this.parent.zInv, this.parent.zNorm)));
            Assert.a(false, String.format("Seq: %d Pos: %d-%d: Bad prob (NaN) = Alpha: %e * Beta[%d] %e * %e exp(%f Norm  a:%d n:%d b:%d z:%d) * %e", seqNum, prevPos, pos, prevAlpha, toNode, betaVal, afterExp, expVal, prevAlphaNorm, norm, betaNorm, this.parent.zNorm, this.parent.zInv));
        }

        // Node feature expectations; the list ends at the first index equal to -1.
        final short[] featureIds = nodeEvals.index;
        final float[] featureVals = nodeEvals.value;
        for (int k = 0; featureIds[k] != -1; ++k) {
            final short featureId = featureIds[k];
            if (this.logs.expectLengthWriter != null) {
                FileUtil.safeWrite(this.logs.expectLengthWriter, String.format("Seq %d Pos %d-%d State: %d Expect #%d: %e = %e + Prob: %e * NodeVal: %e\n", seqNum, prevPos + 1, pos, toNode, featureId, this.parent.expects[featureId] + this.prob * (double)featureVals[k], this.parent.expects[featureId], this.prob, Float.valueOf(featureVals[k])));
            }
            if (this.prob != 0.0) {
                this.parent.expects[featureId] += this.prob * (double)featureVals[k];
            }
        }

        if (this.logs.nodeMarginalWriter != null) {
            FileUtil.safeWrite(this.logs.nodeMarginalWriter, String.format("NodeMarg[%d][%d] = %f = %f + Alpha[%d][%d]: %s (%g n: %d) * Beta[%d][%d]: %s (%g n: %d) * seg: %s (%g n: %d) / Z: %s\n", pos, toNode, this.nodeProb[toNode] + this.prob, this.nodeProb[toNode], prevPos, fromNode, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(prevAlpha, prevAlphaNorm), prevAlpha, prevAlphaNorm, pos, toNode, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(betaVal, betaNorm), betaVal, betaNorm, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(expVal, norm), expVal, norm, CleanMaximumLikelihoodSemiMarkovGradient.printNorm(1.0 / this.parent.zInv, this.parent.zNorm)));
        }
        this.nodeProb[toNode] += this.prob;
        return norm;
    }

    /**
     * Adds the feature expectations for one position to {@code parent.expects}.
     *
     * <p>After asking the cache processor to evaluate the position, every potential in
     * {@code modelInfo.orderedPotentials} that is not flagged in {@code parent.invalidTransitions}
     * contributes {@code marginal * featureValue} for each of its features. The marginal is
     * {@code nodeProb} for node potentials (ids below {@code nStates}) and
     * {@code nextLookback.transitionProb} for edge potentials. Edge potentials are skipped at
     * position 0.
     *
     * @param seqNum       index of the sequence being processed
     * @param pos          position being processed
     * @param nextLookback buffer holding the edge marginals; only dereferenced when {@code pos != 0}
     */
    void updateExpectations(int seqNum, int pos, LookbackBuffer nextLookback) {
        this.parent.cacheProcessor.evaluatePosition(seqNum, pos);
        final int invalidBase = (this.seqOffset + pos) * this.modelInfo.nPotentials;
        for (short potential : this.modelInfo.orderedPotentials) {
            if (this.parent.invalidTransitions[invalidBase + potential]) {
                continue;
            }
            final boolean isNode = potential < this.modelInfo.nStates;
            if (!isNode && pos == 0) {
                continue;
            }
            final double marginal = isNode
                    ? this.nodeProb[potential]
                    : nextLookback.transitionProb[potential - this.modelInfo.nStates];

            // The feature list ends at the first index equal to -1.
            final CacheProcessor.FeatureEvaluation features = this.parent.evals[potential];
            final short[] featureIds = features.index;
            final float[] featureVals = features.value;
            for (int k = 0; featureIds[k] != -1; ++k) {
                final short featureId = featureIds[k];
                if (this.logs.expectWriter != null) {
                    FileUtil.safeWrite(this.logs.expectWriter, String.format("Seq %d Pos %d Expect #%d: %e = %e + Prob: %e * Val: %e\n", seqNum, pos, featureId, this.parent.expects[featureId] + marginal * (double)featureVals[k], this.parent.expects[featureId], marginal, Float.valueOf(featureVals[k])));
                }
                this.parent.expects[featureId] += marginal * (double)featureVals[k];
            }
        }
    }

    /**
     * Sanity-checks the marginals computed for one position.
     *
     * <p>Each node marginal must lie in [-1e-4, 1.0001] and the node marginals must sum to 1
     * within 1e-4. When debug logging is enabled and {@code nextLookback} is non-null, the same
     * range check is applied to each entry of {@code nextLookback.transitionProb}, whose sum must
     * equal 1 within 1e-3. Violations are reported through {@code Assert.a(false, ...)}.
     *
     * @param seqNum       index of the sequence (diagnostics only)
     * @param pos          position being checked (diagnostics only)
     * @param nextLookback buffer holding the edge marginals, or null when there are none to check
     */
    final void verifyMarginals(int seqNum, int pos, LookbackBuffer nextLookback) {
        double nodeTotal = 0.0;
        for (int s = 0; s < this.nodeProb.length; ++s) {
            final double p = this.nodeProb[s];
            if (p < -1.0E-4 || p > 1.0001) {
                Assert.a(false, "Iter ", this.parent.iter, " Seq: ", seqNum, " Pos: " + pos + " Node marginals not valid " + p);
            }
            nodeTotal += p;
        }
        if (Math.abs(1.0 - nodeTotal) > 1.0E-4) {
            Assert.a(false, "Iter ", this.parent.iter, " Pos: " + pos + " Node marginals sum to " + nodeTotal + " not 1: ", ColtUtil.format(this.nodeProb), " at ", seqNum, " ", pos);
        }

        // Edge marginals are only checked in debug mode, and only when a buffer is supplied.
        if (!debug || nextLookback == null) {
            return;
        }
        final double[] edgeMarginals = nextLookback.transitionProb;
        double edgeTotal = 0.0;
        for (int t = 0; t < edgeMarginals.length; ++t) {
            final double p = edgeMarginals[t];
            if (p < -1.0E-4 || p > 1.0001) {
                Assert.a(false, "Iter ", this.parent.iter, " Pos: " + pos + " Self-trans marginal not valid " + p);
            }
            edgeTotal += p;
        }
        if (Math.abs(1.0 - edgeTotal) > 0.001) {
            Assert.a(false, "Seq: ", seqNum, " pos: ", pos, " Edge marginals don't sum to 1.  Sum to: ", edgeTotal, ".  They are: ", ColtUtil.format(edgeMarginals));
        }
    }
}

