/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.check;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.LocalPathSimilarityScore;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.scoring.SimScoreMaxStateAgreement;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.check.FeatureCache;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import calhoun.util.DenseIntMatrix2D;
import calhoun.util.FileUtil;
import java.io.BufferedWriter;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class CachedAOFGradient
implements CRFObjectiveFunctionGradient {
    private static final Log log = LogFactory.getLog(CachedAOFGradient.class);
    // Sampled once when the class is initialised; later log-level changes are not picked up.
    private static final boolean debug = log.isDebugEnabled();
    // Enables verbose System.out tracing inside apply().
    boolean printALot = false;
    CacheProcessor cacheProcessor;
    // Local score summed against edge marginals in apply(): evaluate(yprev, y, sequence, pos).
    LocalPathSimilarityScore score = new SimScoreMaxStateAgreement();
    // Optional trace file names. apply() opens them with FileUtil.safeOpen and only writes when the
    // resulting writer is non-null (presumably safeOpen(null) returns null -- TODO confirm).
    String scoreAlphaFile = null;
    String expectedProductFile = null;
    BufferedWriter scoreAlphaWriter = null;
    BufferedWriter expectedProductWriter = null;
    // Parallel arrays taken from FeatureCache, one entry per cached feature occurrence:
    // feature id, potential index, and feature value.
    short[] id;
    // NOTE(review): byte-typed, but compared against short potential indices; values above 127
    // would sign-extend and never match -- confirm FeatureCache bounds nPotentials.
    byte[] potentialIx;
    float[] val;
    // Number of transitions (length of constMi and mi).
    int miLength;
    // Log of the transition potentials computed from the entries in [0, starts[0]) (see apply()).
    double[] constMi;
    // exp-space transition potentials for the position most recently passed to calcMi.
    double[] mi;
    // Source and destination state of each transition index.
    short[] transitionFrom;
    short[] transitionTo;
    // Potentials in processing order: values < nStates are states, others are transition
    // (value - nStates).
    short[] orderedPotentials;
    // Indexed by overallPosition * nPotentials + potential.
    boolean[] invalidTransitions;
    int totalPositions;
    double[] featureSums;
    // starts[p] is the first cache entry for overall position p; starts[p + 1] ends it.
    int[] starts;
    // seqOffsets[i + 1] - seqOffsets[i] is the length of sequence i.
    int[] seqOffsets;
    int nSeqs;
    int nConstantFeatures;
    List<? extends TrainingSequence<?>> data;
    ModelManager fm;
    int nFeatures;
    int nStates;
    int nPotentials;
    int nTransitions;
    // Number of completed apply() calls (used in log messages).
    int iter = 0;
    // Forward vectors for the previous and current position; swapped after every position.
    double[] prevAlpha;
    double[] alpha;
    // Scaled backward vectors per position and their accumulated log scale factors.
    double[][] betas;
    double[] betaNorms;
    // Feature expectations accumulated over all sequences in the current apply() call.
    double[] expects;
    // Forward/backward score recursions per position and state.
    double[][] scoreAlpha;
    double[][] scoreBeta;
    // Marginals for the current sequence: edgeProb[pos][transition], nodeProb[pos][state].
    double[][] edgeProb;
    double[][] nodeProb;
    // (yprev, y) -> transition index; negative when no such transition exists.
    private DenseIntMatrix2D transitionIndex;
    // Feature expectations for the current sequence only.
    double[] localExpects;
    // Per-state scratch buffers used by apply().
    double[] temp1;
    double[] temp2;
    // Passed to the FeatureCache constructor in setTrainingData().
    boolean allPaths;

    /** Natural exponential; thin wrapper over {@link Math#exp(double)} used by the numeric code. */
    double exp(double val1) {
        final double e = Math.exp(val1);
        return e;
    }

    /** Natural logarithm; thin wrapper over {@link Math#log(double)} used by the numeric code. */
    double log(double val1) {
        final double ln = Math.log(val1);
        return ln;
    }

    /**
     * Chooses the {@code allPaths} flag passed to the {@link FeatureCache} constructor. It only
     * affects caches built by a subsequent {@link #setTrainingData} call.
     *
     * @param allPaths flag forwarded to {@code FeatureCache}
     */
    public void setAllPaths(boolean allPaths) {
        this.allPaths = allPaths;
    }

    /**
     * Stores the model and training data, builds a {@link FeatureCache} over them, and allocates
     * every work buffer used by {@link #apply}. Per-position buffers are sized for the cache's
     * longest sequence.
     *
     * @param fm model supplying the feature and state counts
     * @param data training sequences
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.fm = fm;
        this.data = data;
        this.nFeatures = fm.getNumFeatures();
        this.nStates = fm.getNumStates();
        this.nSeqs = data.size();

        // Buffers whose size depends only on the model.
        this.expects = new double[this.nFeatures];
        this.localExpects = new double[this.nFeatures];
        this.prevAlpha = new double[this.nStates];
        this.alpha = new double[this.nStates];

        final FeatureCache fc = new FeatureCache(fm, data, this.allPaths);
        Assert.a(this.nStates == fc.nStates);

        // Arrays taken directly from the cache (shared, not copied).
        this.id = fc.id;
        this.potentialIx = fc.potentialIx;
        this.val = fc.val;
        this.orderedPotentials = fc.orderedPotentials;
        this.transitionFrom = fc.transitionFrom;
        this.transitionTo = fc.transitionTo;
        this.transitionIndex = fc.transitionIndex;
        this.featureSums = fc.featureSums;
        this.starts = fc.starts;
        this.seqOffsets = fc.seqOffsets;
        this.invalidTransitions = fc.invalidTransitions;

        // Counts.
        this.nConstantFeatures = fc.numConstantFeatures;
        this.nPotentials = fc.nPotentials;
        this.nTransitions = fc.nTransitions;
        this.miLength = fc.nTransitions;
        this.totalPositions = fc.totalPositions;

        // Transition potentials and per-position work buffers.
        this.constMi = new double[this.miLength];
        this.mi = new double[this.miLength];
        final int rows = fc.longestSeq;
        this.betas = new double[rows][this.nStates];
        this.betaNorms = new double[rows];
        this.scoreAlpha = new double[rows][this.nStates];
        this.scoreBeta = new double[rows][this.nStates];
        this.edgeProb = new double[rows][fc.nTransitions];
        this.nodeProb = new double[rows][this.nStates];
        this.temp1 = new double[this.nStates];
        this.temp2 = new double[this.nStates];
    }

    /** Intentionally a no-op: nothing is released here. */
    @Override
    public void clean() {
        // nothing to do
    }

    /**
     * Computes the expected path score summed over all training sequences, together with its
     * gradient with respect to the feature weights.
     *
     * <p>For each sequence this runs:
     * <ul>
     *   <li>a scaled backward pass (betas);</li>
     *   <li>a forward pass that fills node/edge marginals and feature expectations;</li>
     *   <li>forward and backward score recursions (scoreAlpha/scoreBeta), which
     *       {@link #updateGrad} combines into the expected product of features and score.</li>
     * </ul>
     * Both the value and the gradient are divided by the total number of positions before
     * returning. The optional trace writers are always closed, even if a check fails part-way.
     *
     * @param param current feature weights
     * @param grad output array; overwritten with the gradient of the returned value
     * @return the expected score divided by the total number of positions
     */
    @Override
    public double apply(double[] param, double[] grad) {
        this.scoreAlphaWriter = null;
        this.expectedProductWriter = null;
        try {
            this.scoreAlphaWriter = FileUtil.safeOpen(this.scoreAlphaFile);
            this.expectedProductWriter = FileUtil.safeOpen(this.expectedProductFile);
            Arrays.fill(grad, 0.0);
            double result = 0.0;
            double[] seqGrad = new double[grad.length];

            // Transition potentials from the entries in [0, starts[0]). calcMi adds constMi, so it
            // is cleared first and then replaced by the log of the result.
            Arrays.fill(this.constMi, 0.0);
            this.calcMi(-1, 0, this.starts[0], param);
            for (int t = 0; t < this.miLength; ++t) {
                this.constMi[t] = this.log(this.mi[t]);
            }
            Arrays.fill(this.expects, 0.0);

            int seqStart = 0;
            for (int i = 0; i < this.nSeqs; ++i) {
                Arrays.fill(seqGrad, 0.0);
                int len = this.seqOffsets[i + 1] - this.seqOffsets[i];

                // Backward pass: scaled betas, with accumulated log scale factors in betaNorms.
                Arrays.fill(this.betas[len - 1], 1.0);
                this.betaNorms[len - 1] = 0.0;
                int cacheStop = this.starts[seqStart + len];
                for (int pos = len - 1; pos > 0; --pos) {
                    int overallPosition = seqStart + pos;
                    int posCacheStart = this.starts[overallPosition];
                    this.calcMi(overallPosition, posCacheStart, cacheStop, param);
                    cacheStop = posCacheStart;
                    this.quickBetaUpdate(this.betas[pos], this.betas[pos - 1]);
                    double n = this.normalizePotential(this.betas[pos - 1]);
                    this.betaNorms[pos - 1] = this.betaNorms[pos] + this.log(n);
                    Assert.a(!Double.isNaN(this.log(n)));
                }

                // Forward pass: scaled alphas, node/edge marginals and feature expectations.
                double logZ = Double.NEGATIVE_INFINITY;
                double alphaNorm = 0.0;
                double prevAlphaNorm = 0.0;
                int cacheStart = this.starts[seqStart];
                Arrays.fill(this.localExpects, 0.0);
                for (int pos = 0; pos < len; ++pos) {
                    int overallPosition = seqStart + pos;
                    double[] beta = this.betas[pos];
                    double betaNorm = this.betaNorms[pos];
                    cacheStop = this.starts[overallPosition + 1];
                    if (pos == 0) {
                        this.calcStartAlpha(overallPosition, cacheStart, cacheStop, param);
                        alphaNorm = this.log(this.normalizePotential(this.alpha));
                        logZ = this.log(ColtUtil.dotProduct(this.alpha, beta)) + betaNorm + alphaNorm;
                    } else {
                        this.calcMi(overallPosition, cacheStart, cacheStop, param);
                        this.quickAlphaUpdate(this.prevAlpha, this.alpha);
                        alphaNorm = prevAlphaNorm + this.log(this.normalizePotential(this.alpha));
                    }
                    double nodeNorm = this.exp(alphaNorm + betaNorm - logZ);
                    if (Double.isNaN(nodeNorm)) {
                        Assert.a(false, " alphaNorm = " + alphaNorm + " betaNorm " + betaNorm + "  logZ " + logZ);
                    }
                    double edgeNorm = this.exp(prevAlphaNorm + betaNorm - logZ);
                    if (Double.isNaN(nodeNorm) || Double.isNaN(edgeNorm) || Double.isInfinite(nodeNorm) || Double.isInfinite(edgeNorm)) {
                        Assert.a(false, "nodeNorm = " + nodeNorm + "  edgeNorm = " + edgeNorm + "  alphaNorm = " + alphaNorm + "  betaNorm = " + betaNorm + "  prevAlphaNorm = " + prevAlphaNorm + "  logZ = " + logZ);
                    }
                    this.updateExpectations(overallPosition, pos, pos != 0, cacheStart, cacheStop, nodeNorm, edgeNorm, beta);
                    // Only trace the first two / last sequence, and their first two / last two positions.
                    boolean traceSeq = i < 2 || i == this.nSeqs - 1;
                    boolean tracePos = pos < 2 || pos >= len - 2;
                    if (debug && traceSeq && tracePos) {
                        log.debug((Object) String.format("Pos: %d expects: %s alphas: %s (norm %f) betas: %s (norm %f)", pos, ColtUtil.format(this.expects), ColtUtil.format(this.alpha), alphaNorm, ColtUtil.format(beta), betaNorm));
                    }
                    // This position's alpha becomes the next position's prevAlpha.
                    double[] swap = this.prevAlpha;
                    this.prevAlpha = this.alpha;
                    this.alpha = swap;
                    prevAlphaNorm = alphaNorm;
                    cacheStart = cacheStop;
                }

                // Clamp edge marginals to at least Double.MIN_VALUE (also replacing NaN). Raise each
                // node marginal to at least the sum of its outgoing (pos - 1) or incoming (pos) edge
                // marginals, then check that invariant.
                for (int pos = 1; pos < len; ++pos) {
                    Arrays.fill(this.temp1, 0.0);
                    Arrays.fill(this.temp2, 0.0);
                    for (int trans = 0; trans < this.nTransitions; ++trans) {
                        double eps = Double.MIN_VALUE;
                        if (this.edgeProb[pos][trans] < eps || Double.isNaN(this.edgeProb[pos][trans])) {
                            this.edgeProb[pos][trans] = eps;
                        }
                        double ep = this.edgeProb[pos][trans];
                        if (!(ep >= 0.0)) {
                            Assert.a(false, "ep = " + ep + "   eps = " + eps);
                        }
                        this.temp1[this.transitionFrom[trans]] += ep;
                        this.temp2[this.transitionTo[trans]] += ep;
                    }
                    for (int j = 0; j < this.nStates; ++j) {
                        if (this.temp1[j] > this.nodeProb[pos - 1][j]) {
                            this.nodeProb[pos - 1][j] = this.temp1[j];
                        }
                        if (this.temp2[j] > this.nodeProb[pos][j]) {
                            this.nodeProb[pos][j] = this.temp2[j];
                        }
                    }
                    for (int trans = 0; trans < this.nTransitions; ++trans) {
                        short yprev = this.transitionFrom[trans];
                        short y = this.transitionTo[trans];
                        double ep = this.edgeProb[pos][trans];
                        if (this.nodeProb[pos - 1][yprev] < ep) {
                            Assert.a(false, " ep = " + ep + " np = " + this.nodeProb[pos - 1][yprev]);
                        }
                        Assert.a(this.nodeProb[pos][y] >= ep);
                    }
                }

                if (this.printALot) {
                    for (int pos = 0; pos < len; ++pos) {
                        for (int stat = 0; stat < this.nStates; ++stat) {
                            System.out.println("At pos=" + pos + "  the node to y=" + stat + " has probability " + this.nodeProb[pos][stat]);
                        }
                        for (int trans = 0; trans < this.nTransitions; ++trans) {
                            short yprev = this.transitionFrom[trans];
                            short y = this.transitionTo[trans];
                            System.out.println("At pos=" + pos + "  the edge from yprev=" + yprev + " to y=" + y + " has probability " + this.edgeProb[pos][trans]);
                        }
                    }
                }

                // Backward score recursion: scoreBeta[pos - 1][yprev] sums, over transitions out of
                // yprev, edgeProb * score + (edgeProb / nodeProb[pos][y]) * scoreBeta[pos][y].
                Arrays.fill(this.scoreBeta[len - 1], 0.0);
                for (int pos = len - 1; pos > 0; --pos) {
                    double[] target = this.scoreBeta[pos - 1];
                    Arrays.fill(target, 0.0);
                    for (int trans = 0; trans < this.nTransitions; ++trans) {
                        short yprev = this.transitionFrom[trans];
                        short y = this.transitionTo[trans];
                        double ep = this.edgeProb[pos][trans];
                        target[yprev] += ep * this.score.evaluate(yprev, y, this.data.get(i), pos);
                        if (!(ep > 0.0)) {
                            continue;
                        }
                        double np = this.nodeProb[pos][y];
                        Assert.a(np > 0.0);
                        double ratio = ep / np;
                        if (Double.isNaN(ratio)) {
                            Assert.a(false, "   ep = " + ep + "   np = " + np);
                            continue;
                        }
                        target[yprev] += ratio * this.scoreBeta[pos][y];
                    }
                }
                if (this.printALot) {
                    for (int pos = 0; pos < len; ++pos) {
                        for (int y = 0; y < this.nStates; ++y) {
                            System.out.println("scoreAlpha at position " + pos + " and state " + y + " is " + this.scoreAlpha[pos][y]);
                            System.out.println("scoreBeta at position " + pos + " and state " + y + " is " + this.scoreBeta[pos][y]);
                        }
                    }
                }

                // Forward score recursion: scoreAlpha[pos][y] sums, over transitions into y,
                // (edgeProb / nodeProb[pos - 1][yprev]) * scoreAlpha[pos - 1][yprev] + edgeProb * score.
                Arrays.fill(this.scoreAlpha[0], 0.0);
                for (int pos = 1; pos < len; ++pos) {
                    Arrays.fill(this.scoreAlpha[pos], 0.0);
                    for (int trans = 0; trans < this.nTransitions; ++trans) {
                        double ep = this.edgeProb[pos][trans];
                        if (!(ep > 0.0)) {
                            continue;
                        }
                        short yprev = this.transitionFrom[trans];
                        short y = this.transitionTo[trans];
                        double update = ep / this.nodeProb[pos - 1][yprev] * this.scoreAlpha[pos - 1][yprev];
                        update += ep * this.score.evaluate(yprev, y, this.data.get(i), pos);
                        this.scoreAlpha[pos][y] += update;
                        if (this.scoreAlphaWriter != null) {
                            FileUtil.safeWrite(this.scoreAlphaWriter, String.format("Seq: %d alpha[%d][%d] = %g = %g + Pr: %g * alpha[%d][%d] %g + Pr: %g * Score: %g\n", i, pos, (int) y, this.scoreAlpha[pos][y], this.scoreAlpha[pos][y] - update, ep / this.nodeProb[pos - 1][yprev], pos - 1, (int) yprev, this.scoreAlpha[pos - 1][yprev], ep, this.score.evaluate(yprev, y, this.data.get(i), pos)));
                        }
                    }
                }

                // Expected score of this sequence, and the expected product of features and score
                // (accumulated into seqGrad by updateGrad).
                double thisresult = 0.0;
                for (int pos = 0; pos < len; ++pos) {
                    if (pos > 0) {
                        for (int trans = 0; trans < this.nTransitions; ++trans) {
                            short yprev = this.transitionFrom[trans];
                            short y = this.transitionTo[trans];
                            thisresult += this.edgeProb[pos][trans] * this.score.evaluate(yprev, y, this.data.get(i), pos);
                        }
                    }
                    int overallPosition = seqStart + pos;
                    cacheStart = this.starts[overallPosition];
                    cacheStop = this.starts[overallPosition + 1];
                    Assert.a(this.nFeatures > 0);
                    this.updateGrad(overallPosition, pos, i, cacheStart, cacheStop, seqGrad);
                    if (this.printALot) {
                        // updateGrad accumulates into seqGrad; grad is not touched until below.
                        System.out.println("After adding in pos=" + pos + " the gradient is " + ColtUtil.format(seqGrad));
                    }
                }
                if (this.printALot) {
                    System.out.println("For sequence number " + i + " the expectation of S is " + thisresult);
                }
                // grad += expected product - expected features * expected score (for this sequence).
                for (int jfeat = 0; jfeat < this.nFeatures; ++jfeat) {
                    if (this.printALot) {
                        System.out.println("For seq " + i + " and feature " + jfeat + " the local expect is " + this.localExpects[jfeat] + "  and expects is " + this.expects[jfeat]);
                    }
                    grad[jfeat] += seqGrad[jfeat] - this.localExpects[jfeat] * thisresult;
                }
                for (int jfeat = 0; jfeat < this.nFeatures; ++jfeat) {
                    Assert.a(!Double.isNaN(grad[jfeat]));
                }
                if (debug) {
                    log.debug((Object) String.format("Iter: %d Seq: %d Expected Score: %g Grad: %s Expected Features: %s Expected Product: %s", this.iter, i, thisresult, ColtUtil.format(grad), ColtUtil.format(this.localExpects), ColtUtil.format(seqGrad)));
                }
                result += thisresult;
                seqStart += len;
            }

            if (log.isInfoEnabled()) {
                log.info((Object) String.format("Iter: %d Val: %g Grad: %s", this.iter, result, ColtUtil.format(grad)));
            }
            Assert.a(!Double.isNaN(result));
            for (int jfeat = 0; jfeat < this.nFeatures; ++jfeat) {
                Assert.a(!Double.isNaN(grad[jfeat]));
            }
            result /= (double) this.totalPositions;
            for (int k = 0; k < grad.length; ++k) {
                grad[k] = grad[k] / (double) this.totalPositions;
            }
            ++this.iter;
            return result;
        } finally {
            // Always release the optional trace files, even when a check above fails.
            FileUtil.safeClose(this.scoreAlphaWriter);
            FileUtil.safeClose(this.expectedProductWriter);
        }
    }

    /**
     * Adds one sequence position's contribution to the expected product of feature values and path
     * score, accumulating into {@code grad}.
     *
     * <p>Walks {@code orderedPotentials} in order while consuming two feature streams in lockstep:
     * <ul>
     *   <li>the constant features (cache entries {@code [0, nConstantFeatures)});</li>
     *   <li>this position's features (cache entries {@code [posCurrent, posStop)}).</li>
     * </ul>
     * Both streams are expected to be ordered consistently with {@code orderedPotentials}; the
     * final assertions check that every entry was consumed.
     *
     * <p>For an edge (yprev, y) with marginal {@code ep > 0}, the term multiplied by the feature
     * value is:
     * <pre>
     *   ep / nodeProb[pos-1][yprev] * scoreAlpha[pos-1][yprev]
     *     + ep * score(yprev, y)
     *     + ep / nodeProb[pos][y] * scoreBeta[pos][y]
     * </pre>
     * A state feature sums this over all incoming transitions. At position 0 a state feature uses
     * {@code scoreBeta[0][y]}, and constant edge features are skipped.
     *
     * @param overallPos position across all sequences (indexes {@code invalidTransitions})
     * @param pos position within the current sequence
     * @param seqI index of the current sequence in {@code data}
     * @param posCurrent first cache entry of this position's features
     * @param posStop one past the last cache entry of this position's features
     * @param grad accumulator indexed by feature id; updated in place
     */
    void updateGrad(int overallPos, int pos, int seqI, int posCurrent, int posStop, double[] grad) {
        for (int jfeat = 0; jfeat < this.nFeatures; ++jfeat) {
            Assert.a(!Double.isNaN(grad[jfeat]));
        }
        int constCurrent = 0;
        short constId = -1;
        short constPotential = -1;
        double constVal = Double.NaN;
        if (constCurrent < this.nConstantFeatures) {
            constId = this.id[constCurrent];
            constPotential = this.potentialIx[constCurrent];
            constVal = this.val[constCurrent];
            ++constCurrent;
        }
        short posId = -1;
        short posPotential = -1;
        double posVal = Double.NaN;
        if (posCurrent < posStop) {
            posId = this.id[posCurrent];
            posPotential = this.potentialIx[posCurrent];
            posVal = this.val[posCurrent];
            ++posCurrent;
        }
        int invalidIndex = overallPos * this.nPotentials;
        for (short potential : this.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + potential];
            // Constant features are skipped on invalid potentials and on edges at the first position.
            boolean skipConst = invalid || pos <= 0 && potential >= this.nStates;
            while (constPotential == potential) {
                if (!skipConst) {
                    if (potential < this.nStates) {
                        short y = potential;
                        if (pos > 0) {
                            for (int yprev = 0; yprev < this.nStates; ++yprev) {
                                int trans = this.transitionIndex.getQuick(yprev, y);
                                if (trans < 0) {
                                    continue;
                                }
                                Assert.a(yprev == this.transitionFrom[trans]);
                                Assert.a(y == this.transitionTo[trans]);
                                double inner = 0.0;
                                double ep = this.edgeProb[pos][trans];
                                if (ep > 0.0) {
                                    inner += ep / this.nodeProb[pos - 1][yprev] * this.scoreAlpha[pos - 1][yprev];
                                    inner += ep * this.score.evaluate(yprev, y, this.data.get(seqI), pos);
                                    inner += ep / this.nodeProb[pos][y] * this.scoreBeta[pos][y];
                                    grad[constId] += constVal * inner;
                                    if (this.expectedProductWriter != null) {
                                        FileUtil.safeWrite(this.expectedProductWriter, String.format("Seq: %d Pos: %d State: %d\tFeat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n", seqI, pos, (int) y, constId, grad[constId], grad[constId] - inner * constVal, constVal, this.score.evaluate(yprev, y, this.data.get(seqI), pos), this.edgeProb[pos][trans], this.scoreAlpha[pos - 1][yprev], this.edgeProb[pos][trans] / this.nodeProb[pos - 1][yprev], this.scoreBeta[pos][y], this.edgeProb[pos][trans] / this.nodeProb[pos][y]));
                                    }
                                }
                                if (Double.isNaN(grad[constId])) {
                                    Assert.a(false, "constVal = " + constVal + "  inner = " + inner + "pos = " + pos + "   trans = " + trans);
                                }
                            }
                        } else {
                            double inner = this.scoreBeta[pos][y];
                            grad[constId] += constVal * inner;
                            if (this.expectedProductWriter != null) {
                                FileUtil.safeWrite(this.expectedProductWriter, String.format("Seq: %d Pos: %d State: %d\tFeat: %d = %g = %g + Val: %g * Beta[%d][%d]: %g:\n", seqI, pos, (int) y, constId, grad[constId], grad[constId] - inner * constVal, constVal, pos, (int) y, inner));
                            }
                        }
                    } else {
                        int trans = potential - this.nStates;
                        int yprev = this.transitionFrom[trans];
                        short y = this.transitionTo[trans];
                        double inner = 0.0;
                        double ep = this.edgeProb[pos][trans];
                        if (ep > 0.0) {
                            inner += ep / this.nodeProb[pos - 1][yprev] * this.scoreAlpha[pos - 1][yprev];
                            inner += ep * this.score.evaluate(yprev, y, this.data.get(seqI), pos);
                            inner += ep / this.nodeProb[pos][y] * this.scoreBeta[pos][y];
                            grad[constId] += constVal * inner;
                            if (this.expectedProductWriter != null) {
                                FileUtil.safeWrite(this.expectedProductWriter, String.format("Seq: %d Pos: %d Edge: %d-%d\tFeat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n", seqI, pos, yprev, (int) y, constId, grad[constId], grad[constId] - inner * constVal, constVal, this.score.evaluate(yprev, y, this.data.get(seqI), pos), ep, this.scoreAlpha[pos - 1][yprev], ep / this.nodeProb[pos - 1][yprev], this.scoreBeta[pos][y], ep / this.nodeProb[pos][y]));
                            }
                        }
                        if (Double.isNaN(grad[constId])) {
                            Assert.a(false, "constVal = " + constVal + "  inner = " + inner + "pos = " + pos + "   trans = " + trans);
                        }
                    }
                }
                if (constCurrent >= this.nConstantFeatures) {
                    break;
                }
                constId = this.id[constCurrent];
                constVal = this.val[constCurrent];
                constPotential = this.potentialIx[constCurrent];
                ++constCurrent;
            }
            for (int j = 0; j < this.nFeatures; ++j) {
                if (Double.isNaN(grad[j])) {
                    Assert.a(false, "j= " + j);
                }
            }
            if (invalid) {
                continue;
            }
            while (posPotential == potential) {
                if (potential < this.nStates) {
                    short y = potential;
                    if (pos > 0) {
                        for (int yprev = 0; yprev < this.nStates; ++yprev) {
                            int trans = this.transitionIndex.getQuick(yprev, y);
                            if (trans < 0) {
                                continue;
                            }
                            Assert.a(yprev == this.transitionFrom[trans]);
                            Assert.a(y == this.transitionTo[trans]);
                            double inner = 0.0;
                            double ep = this.edgeProb[pos][trans];
                            if (ep > 0.0) {
                                Assert.a(this.nodeProb[pos - 1][yprev] >= ep - 1.0E-8);
                                Assert.a(this.nodeProb[pos - 1][yprev] > 0.0);
                                Assert.a(this.nodeProb[pos][y] >= ep - 1.0E-8);
                                Assert.a(this.nodeProb[pos][y] > 0.0);
                                inner += ep / this.nodeProb[pos - 1][yprev] * this.scoreAlpha[pos - 1][yprev];
                                inner += ep * this.score.evaluate(yprev, y, this.data.get(seqI), pos);
                                inner += ep / this.nodeProb[pos][y] * this.scoreBeta[pos][y];
                                if (Double.isNaN(inner)) {
                                    Assert.a(false, "npyprev = " + this.nodeProb[pos - 1][yprev] + "  npy = " + this.nodeProb[pos][y] + "pos = " + pos + "   ep = " + ep + "  score = " + this.score.evaluate(yprev, y, this.data.get(seqI), pos) + "  alpha= " + this.scoreAlpha[pos - 1][yprev] + "   beta=" + this.scoreBeta[pos][y]);
                                }
                                grad[posId] += posVal * inner;
                                if (this.expectedProductWriter != null) {
                                    FileUtil.safeWrite(this.expectedProductWriter, String.format("Seq: %d Pos: %d State: %d\tFeat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n", seqI, pos, (int) y, posId, grad[posId], grad[posId] - inner * posVal, posVal, this.score.evaluate(yprev, y, this.data.get(seqI), pos), this.edgeProb[pos][trans], this.scoreAlpha[pos - 1][yprev], this.edgeProb[pos][trans] / this.nodeProb[pos - 1][yprev], this.scoreBeta[pos][y], this.edgeProb[pos][trans] / this.nodeProb[pos][y]));
                                }
                            }
                            if (Double.isNaN(grad[posId])) {
                                Assert.a(false, "posVal = " + posVal + "  inner = " + inner + "pos = " + pos + "   trans = " + trans);
                            }
                        }
                    } else {
                        double inner = this.scoreBeta[pos][y];
                        grad[posId] += posVal * inner;
                    }
                } else {
                    int trans = potential - this.nStates;
                    int yprev = this.transitionFrom[trans];
                    short y = this.transitionTo[trans];
                    double inner = 0.0;
                    double ep = this.edgeProb[pos][trans];
                    if (ep > 0.0) {
                        inner += ep / this.nodeProb[pos - 1][yprev] * this.scoreAlpha[pos - 1][yprev];
                        inner += ep * this.score.evaluate(yprev, y, this.data.get(seqI), pos);
                        inner += ep / this.nodeProb[pos][y] * this.scoreBeta[pos][y];
                        grad[posId] += posVal * inner;
                        if (this.expectedProductWriter != null) {
                            FileUtil.safeWrite(this.expectedProductWriter, String.format("Seq: %d Pos: %d Edge: %d-%d\tFeat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n", seqI, pos, yprev, (int) y, posId, grad[posId], grad[posId] - inner * posVal, posVal, this.score.evaluate(yprev, y, this.data.get(seqI), pos), this.edgeProb[pos][trans], this.scoreAlpha[pos - 1][yprev], this.edgeProb[pos][trans] / this.nodeProb[pos - 1][yprev], this.scoreBeta[pos][y], this.edgeProb[pos][trans] / this.nodeProb[pos][y]));
                        }
                    }
                    if (Double.isNaN(grad[posId])) {
                        Assert.a(false, "posVal = " + posVal + "  inner = " + inner + "pos = " + pos + "   trans = " + trans);
                    }
                }
                if (posCurrent >= posStop) {
                    break;
                }
                posId = this.id[posCurrent];
                posVal = this.val[posCurrent];
                posPotential = this.potentialIx[posCurrent];
                ++posCurrent;
            }
        }
        for (int jfeat = 0; jfeat < this.nFeatures; ++jfeat) {
            Assert.a(!Double.isNaN(grad[jfeat]));
        }
        Assert.a(constCurrent == this.nConstantFeatures);
        Assert.a(posCurrent == posStop);
    }

    /**
     * Computes the marginal probability of every potential at one position, stores it in
     * {@code nodeProb[localPos]} / {@code edgeProb[localPos]}, and adds probability-weighted feature
     * values to {@code expects} and {@code localExpects}.
     *
     * <p>Node marginals are {@code alpha[y] * beta[y] * nodeNorm}. Edge marginals are
     * {@code prevAlpha[yprev] * mi[trans] * currentBeta * edgeNorm}, where {@code currentBeta} is
     * the beta of the most recently visited state potential. This presumably relies on
     * {@code orderedPotentials} listing each state before the transitions into it (TODO confirm).
     * Potentials flagged in {@code invalidTransitions} get probability 0.
     *
     * <p>When {@code includeEdges} is false (the first position of a sequence) there is no
     * predecessor. {@code prevAlpha} and {@code mi} then hold values left over from other
     * positions, so edge marginals are set to 0 and edge features contribute nothing.
     *
     * @param overallPos position across all sequences (indexes {@code invalidTransitions})
     * @param localPos position within the current sequence
     * @param includeEdges whether transition potentials exist at this position
     * @param posCurrent first cache entry of this position's features
     * @param posStop one past the last cache entry of this position's features
     * @param nodeNorm scale factor for node marginals
     * @param edgeNorm scale factor for edge marginals
     * @param beta backward vector for this position
     */
    void updateExpectations(int overallPos, int localPos, boolean includeEdges, int posCurrent, int posStop, double nodeNorm, double edgeNorm, double[] beta) {
        int constCurrent = 0;
        int constId = -1;
        int constPotential = -1;
        double constVal = Double.NaN;
        if (constCurrent < this.nConstantFeatures) {
            constId = this.id[constCurrent];
            constPotential = this.potentialIx[constCurrent];
            constVal = this.val[constCurrent];
            ++constCurrent;
        }
        int posId = -1;
        int posPotential = -1;
        double posVal = Double.NaN;
        if (posCurrent < posStop) {
            posId = this.id[posCurrent];
            posPotential = this.potentialIx[posCurrent];
            posVal = this.val[posCurrent];
            ++posCurrent;
        }
        int currentNode = -1;
        double currentBeta = 0.0;
        int invalidIndex = overallPos * this.nPotentials;
        for (int n : this.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + n];
            boolean isNode = n < this.nStates;
            // A potential contributes only if it is valid and, for edges, a predecessor exists.
            boolean active = !invalid && (includeEdges || isNode);
            double prob = 0.0;
            if (isNode) {
                currentNode = n;
                currentBeta = beta[currentNode];
                if (active) {
                    prob = this.alpha[currentNode] * currentBeta * nodeNorm;
                }
                if (Double.isNaN(prob)) {
                    Assert.a(false, " alpha" + this.alpha[currentNode] + " beta=" + currentBeta + "  nodeNorm = " + nodeNorm);
                }
                this.nodeProb[localPos][currentNode] = prob;
            } else {
                int trans = n - this.nStates;
                short yprev = this.transitionFrom[trans];
                if (active) {
                    prob = this.prevAlpha[yprev] * this.mi[trans] * currentBeta * edgeNorm;
                }
                if (Double.isNaN(prob)) {
                    Assert.a(false, "prob=" + prob + "  prevAlpha=" + this.prevAlpha[yprev] + "  mi=" + this.mi[trans] + "  currentBeta=" + currentBeta + "  edgeNorm=" + edgeNorm);
                }
                this.edgeProb[localPos][trans] = prob;
            }
            while (constPotential == n) {
                if (active) {
                    this.expects[constId] += prob * constVal;
                    this.localExpects[constId] += prob * constVal;
                }
                if (constCurrent >= this.nConstantFeatures) {
                    break;
                }
                constId = this.id[constCurrent];
                constVal = this.val[constCurrent];
                constPotential = this.potentialIx[constCurrent];
                ++constCurrent;
            }
            if (invalid) {
                continue;
            }
            while (posPotential == n) {
                if (active) {
                    this.expects[posId] += prob * posVal;
                    this.localExpects[posId] += prob * posVal;
                }
                if (posCurrent >= posStop) {
                    break;
                }
                posId = this.id[posCurrent];
                posVal = this.val[posCurrent];
                posPotential = this.potentialIx[posCurrent];
                ++posCurrent;
            }
        }
        Assert.a(constCurrent == this.nConstantFeatures);
        Assert.a(posCurrent == posStop);
    }

    /**
     * Fills {@code this.mi} with the exponentiated score of every transition potential
     * at one position.
     *
     * <p>The cached feature entries in {@code [current, stop)} are walked in lock-step with
     * {@code orderedPotentials}. Each entry contributes {@code val[i] * lambda[id[i]]} to the
     * potential named by {@code potentialIx[i]}. For a transition, the stored value is
     * {@code exp(transitionFeatures + lastNodeFeatures + constMi[transition])}, where
     * {@code lastNodeFeatures} is the sum for the node potential most recently visited in
     * {@code orderedPotentials}. The code assumes each node precedes its transitions in that
     * ordering; otherwise the NaN initial value leaks into {@code mi}.
     *
     * <p>A potential marked in {@code invalidTransitions} starts from negative infinity, so
     * its exponent is -inf. That presumably becomes 0 via {@code exp}, which is defined
     * elsewhere (TODO confirm). Passing {@code overallPosition == -1} skips the validity
     * lookup entirely.
     *
     * @param overallPosition global position index used to locate the invalid-transition row, or -1
     * @param current first cached feature entry for this position
     * @param stop one past the last cached feature entry for this position
     * @param lambda feature weights indexed by feature id
     */
    void calcMi(int overallPosition, int current, int stop, double[] lambda) {
        // Cursor state: the most recently loaded (not yet consumed) cached entry.
        int cursor = current;
        short pendingPotential = -1;
        double pendingValue = Double.NaN;
        if (cursor < stop) {
            pendingPotential = this.potentialIx[cursor];
            pendingValue = (double) this.val[cursor] * lambda[this.id[cursor]];
            cursor++;
        }
        final boolean checkValidity = overallPosition != -1;
        final int rowStart = overallPosition * this.nPotentials;
        double stateScore = Double.NaN;
        for (short potential : this.orderedPotentials) {
            boolean blocked = checkValidity && this.invalidTransitions[rowStart + potential];
            double sum = blocked ? Double.NEGATIVE_INFINITY : 0.0;
            // Consume every cached entry belonging to this potential; entries must be in the
            // same order as orderedPotentials or they are never consumed.
            while (pendingPotential == potential) {
                sum += pendingValue;
                if (cursor >= stop) {
                    break;
                }
                pendingValue = (double) this.val[cursor] * lambda[this.id[cursor]];
                pendingPotential = this.potentialIx[cursor];
                cursor++;
            }
            if (potential < this.nStates) {
                // Remember the node score so the following transitions can include it.
                stateScore = sum;
            } else {
                int transition = potential - this.nStates;
                this.mi[transition] = this.exp(sum + stateScore + this.constMi[transition]);
            }
        }
        // NOTE(review): the cursor advances when an entry is loaded, not when it is consumed.
        // A final entry whose potential never matches is therefore dropped without tripping
        // this check.
        if (cursor != stop) {
            Assert.a(false, "Pos: ", overallPosition, " Expected ", stop, " features only found ", cursor);
        }
    }

    /**
     * Computes the unnormalized {@code alpha} entry for every node potential at a starting
     * position.
     *
     * <p>Two cursors are merged against {@code orderedPotentials}:
     * <ul>
     *   <li>The constant feature entries {@code [0, nConstantFeatures)} are consumed for every
     *       potential, but only node potentials add their weighted value.</li>
     *   <li>The position-specific entries {@code [posCurrent, posStop)} are consumed only while
     *       visiting node potentials.</li>
     * </ul>
     * Each node's alpha is {@code exp} of its summed weighted features. Invalid nodes start
     * from negative infinity. After the loop, both cursors are asserted to be exhausted.
     *
     * @param overallPosition global position index used to locate the invalid-transition row
     * @param posCurrent first cached feature entry for this position
     * @param posStop one past the last cached feature entry for this position
     * @param lambda feature weights indexed by feature id
     */
    void calcStartAlpha(int overallPosition, int posCurrent, int posStop, double[] lambda) {
        // Cursor over the constant feature entries.
        int constCursor = 0;
        short constPot = -1;
        double constWeighted = Double.NaN;
        if (constCursor < this.nConstantFeatures) {
            constPot = this.potentialIx[constCursor];
            constWeighted = (double) this.val[constCursor] * lambda[this.id[constCursor]];
            constCursor++;
        }
        // Cursor over this position's feature entries.
        int posCursor = posCurrent;
        short posPot = -1;
        double posWeighted = Double.NaN;
        if (posCursor < posStop) {
            posPot = this.potentialIx[posCursor];
            posWeighted = (double) this.val[posCursor] * lambda[this.id[posCursor]];
            posCursor++;
        }
        final int rowStart = overallPosition * this.nPotentials;
        for (short potential : this.orderedPotentials) {
            final boolean isNode = potential < this.nStates;
            double sum = this.invalidTransitions[rowStart + potential] ? Double.NEGATIVE_INFINITY : 0.0;
            // Constant entries are consumed for edges too, but only nodes accumulate them.
            while (constPot == potential) {
                if (isNode) {
                    sum += constWeighted;
                    Assert.a(!Double.isNaN(sum));
                }
                if (constCursor >= this.nConstantFeatures) {
                    break;
                }
                constWeighted = (double) this.val[constCursor] * lambda[this.id[constCursor]];
                constPot = this.potentialIx[constCursor];
                constCursor++;
            }
            if (isNode) {
                while (posPot == potential) {
                    sum += posWeighted;
                    if (posCursor >= posStop) {
                        break;
                    }
                    posWeighted = (double) this.val[posCursor] * lambda[this.id[posCursor]];
                    posPot = this.potentialIx[posCursor];
                    posCursor++;
                }
                this.alpha[potential] = this.exp(sum);
            }
        }
        Assert.a(constCursor == this.nConstantFeatures);
        Assert.a(posCursor == posStop);
    }

    /**
     * Scales {@code vec} in place so its entries sum to one.
     *
     * <p>The scaling multiplies every entry by the reciprocal of the sum. It asserts that the
     * sum is strictly positive, which also rejects a NaN sum and an empty array.
     *
     * @param vec values to normalize; modified in place
     * @return the sum of the entries before scaling
     */
    private double normalizePotential(double[] vec) {
        double total = 0.0;
        for (double entry : vec) {
            total += entry;
        }
        Assert.a(total > 0.0);
        final double scale = 1.0 / total;
        for (int k = 0; k < vec.length; k++) {
            vec[k] *= scale;
        }
        return total;
    }

    /**
     * Computes one backward step: {@code newBeta[from] = sum over transitions of
     * mi[t] * lastBeta[node]}.
     *
     * <p>The node used for a transition is the one most recently visited in
     * {@code orderedPotentials}. This presumably is the transition's destination state,
     * given the ordering — TODO confirm. {@code newBeta} is zeroed first. No normalization
     * is applied.
     *
     * @param lastBeta beta values for the following position
     * @param newBeta receives the beta values for this position
     */
    private void quickBetaUpdate(double[] lastBeta, double[] newBeta) {
        Arrays.fill(newBeta, 0.0);
        double destinationBeta = 0.0;
        for (short potential : this.orderedPotentials) {
            if (potential >= this.nStates) {
                int trans = potential - this.nStates;
                newBeta[this.transitionFrom[trans]] += this.mi[trans] * destinationBeta;
            } else {
                destinationBeta = lastBeta[potential];
            }
        }
    }

    /**
     * Computes one forward step: for each node, {@code newAlpha[node]} is the sum of
     * {@code lastAlpha[from] * mi[t]} over the transitions that follow it in
     * {@code orderedPotentials}.
     *
     * <p>A node's total is written only when the next node (or the end of the ordering) is
     * reached. A node with no following transitions therefore gets 0. The method assumes
     * {@code orderedPotentials} contains at least one node; otherwise the final write uses
     * index -1.
     *
     * @param lastAlpha alpha values for the previous position
     * @param newAlpha receives the alpha values for this position
     */
    private void quickAlphaUpdate(double[] lastAlpha, double[] newAlpha) {
        int state = -1;
        double incoming = 0.0;
        for (int potential : this.orderedPotentials) {
            if (potential >= this.nStates) {
                int trans = potential - this.nStates;
                incoming += lastAlpha[this.transitionFrom[trans]] * this.mi[trans];
                continue;
            }
            // A new node starts: flush the total accumulated for the previous one.
            if (state != -1) {
                newAlpha[state] = incoming;
            }
            state = potential;
            incoming = 0.0;
        }
        newAlpha[state] = incoming;
    }

    /**
     * Returns the configured cache processor.
     *
     * @return the current {@link CacheProcessor}, possibly {@code null} if never set
     */
    public CacheProcessor getCacheProcessor() {
        return cacheProcessor;
    }

    /**
     * Replaces the cache processor.
     *
     * @param cacheProcessor the {@link CacheProcessor} to store
     */
    public void setCacheProcessor(final CacheProcessor cacheProcessor) {
        this.cacheProcessor = cacheProcessor;
    }

    /**
     * Replaces the local path similarity score.
     *
     * @param s the {@link LocalPathSimilarityScore} to store
     */
    public void setLocalPathSimilarityScore(final LocalPathSimilarityScore s) {
        score = s;
    }

    /**
     * Returns the local path similarity score.
     *
     * <p>The field is initialized to a {@code SimScoreMaxStateAgreement} instance.
     *
     * @return the current {@link LocalPathSimilarityScore}
     */
    public LocalPathSimilarityScore getLocalPathSimilarityScore() {
        return score;
    }

    /**
     * Returns the configured score-alpha file name.
     *
     * @return the file name, or {@code null} (the initial value) if unset
     */
    public String getScoreAlphaFile() {
        return scoreAlphaFile;
    }

    /**
     * Sets the score-alpha file name.
     *
     * @param scoreAlphaFile the file name to store; may be {@code null}
     */
    public void setScoreAlphaFile(final String scoreAlphaFile) {
        this.scoreAlphaFile = scoreAlphaFile;
    }

    /**
     * Returns the configured expected-product file name.
     *
     * @return the file name, or {@code null} (the initial value) if unset
     */
    public String getExpectedProductFile() {
        return expectedProductFile;
    }

    /**
     * Sets the expected-product file name.
     *
     * @param expectedProductFile the file name to store; may be {@code null}
     */
    public void setExpectedProductFile(final String expectedProductFile) {
        this.expectedProductFile = expectedProductFile;
    }
}

