/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Maximum-likelihood objective for CRF training.
 *
 * <p>Computes the log-likelihood of the training data under the current weights, together with
 * its gradient. It runs a scaled forward-backward pass over every training sequence: alpha and
 * beta vectors are renormalized at each position and their log normalizers are tracked
 * separately.
 *
 * <p>Both the returned value and the gradient are divided by the total number of positions in the
 * training data. The configured {@link CacheProcessor} supplies:
 * <ul>
 *   <li>the feature evaluations,</li>
 *   <li>the invalid-transition flags,</li>
 *   <li>the per-feature sums.</li>
 * </ul>
 *
 * <p>The inner loops rely on {@code orderedPotentials} listing each node potential (index below
 * {@code nStates}) followed by the transition potentials that end in that state.
 *
 * <p>Instances keep mutable working buffers that {@link #apply} overwrites without any
 * synchronization, so an instance must not be used from several threads at once.
 */
public class MaximumLikelihoodGradient
implements CRFObjectiveFunctionGradient {
    private static final Log log = LogFactory.getLog(MaximumLikelihoodGradient.class);
    /** Sampled once per instance; debug output depends on the log level at construction time. */
    boolean debug = log.isDebugEnabled();
    CacheProcessor cacheProcessor = new AllSparseLengthCacheProcessor();
    CacheProcessor.SolverSetup modelInfo;
    CacheProcessor.FeatureEvaluation[] evals;
    /** Flags indexed by overallPosition * nPotentials + potential. */
    boolean[] invalidTransitions;
    int miLength;
    /** For each transition at the current position: exp(transition score + destination node score). */
    double[] mi;
    /** Number of completed calls to {@link #apply}. */
    int iter = 0;
    /** Normalized forward vector of the previous position. */
    double[] prevAlpha;
    /** Normalized forward vector of the current position. */
    double[] alpha;
    /** Normalized backward vectors, one per position of the current sequence. */
    double[][] betas;
    /** Accumulated log normalizers matching {@link #betas}. */
    double[] betaNorms;
    /** Model-expected feature values accumulated during {@link #apply}. */
    double[] expects;
    /** Feature sums captured by the most recent {@link #apply}; null before the first call. */
    private double[] featureSums;

    /**
     * Returns the cache processor used to evaluate features.
     *
     * @return the cache processor
     */
    public CacheProcessor getCacheProcessor() {
        return this.cacheProcessor;
    }

    /**
     * Replaces the cache processor.
     *
     * <p>It only takes effect for training data passed to a later call of
     * {@link #setTrainingData}.
     *
     * @param cacheProcessor the new cache processor
     */
    public void setCacheProcessor(CacheProcessor cacheProcessor) {
        this.cacheProcessor = cacheProcessor;
    }

    /**
     * Hands the training data to the cache processor and sizes all working buffers.
     *
     * <p>The buffers are sized from the resulting solver setup.
     *
     * @param fm   the model manager
     * @param data the training sequences
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.cacheProcessor.setTrainingData(fm, data);
        this.modelInfo = this.cacheProcessor.getSolverSetup();
        this.evals = this.cacheProcessor.getFeatureEvaluations();
        this.invalidTransitions = this.cacheProcessor.getInvalidTransitions();
        this.miLength = this.modelInfo.nTransitions;
        this.mi = new double[this.miLength];
        this.expects = new double[this.modelInfo.nFeatures];
        this.prevAlpha = new double[this.modelInfo.nStates];
        this.alpha = new double[this.modelInfo.nStates];
        this.betas = new double[this.modelInfo.longestSeq][this.modelInfo.nStates];
        this.betaNorms = new double[this.modelInfo.longestSeq];
    }

    /**
     * Evaluates the objective at {@code param} and writes its gradient into {@code grad}.
     *
     * <p>The returned value is {@code (sum_j featureSums[j] * param[j] - sum_seq logZ_seq)},
     * divided by the total number of training positions.
     *
     * <p>{@code grad[j]} is {@code featureSums[j] - expects[j]}, divided by the same count.
     * {@code expects} holds the model-expected feature values computed here.
     *
     * @param param current feature weights, indexed by feature number
     * @param grad  output array; it is cleared first, then filled with the normalized gradient
     * @return the log-likelihood divided by the total number of positions
     */
    @Override
    public double apply(double[] param, double[] grad) {
        Arrays.fill(grad, 0.0);
        Arrays.fill(this.expects, 0.0);
        double result = 0.0;
        for (int i = 0; i < this.modelInfo.nSeqs; ++i) {
            int len = this.modelInfo.seqOffsets[i + 1] - this.modelInfo.seqOffsets[i];
            if (len == 0) {
                // An empty sequence has Z = 1 (log Z = 0) and contributes no expectations.
                continue;
            }
            this.backwardPass(i, len, param);
            result -= this.forwardPass(i, len, param);
        }
        double[] sums = this.cacheProcessor.getFeatureSums();
        for (int j = 0; j < this.modelInfo.nFeatures; ++j) {
            result += sums[j] * param[j];
            grad[j] = sums[j] - this.expects[j];
        }
        double positions = (double) this.modelInfo.totalPositions;
        if (log.isInfoEnabled()) {
            log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad (unnorm): %s", this.iter, MaximumLikelihoodGradient.exp(result / positions), result / positions, ColtUtil.norm(grad) / positions, ColtUtil.format(sums), ColtUtil.format(this.expects), ColtUtil.format(param), ColtUtil.format(grad)));
        }
        ++this.iter;
        result /= positions;
        for (int j = 0; j < grad.length; ++j) {
            grad[j] /= positions;
        }
        this.featureSums = sums;
        return result;
    }

    /**
     * Fills {@code betas[0..len-1]} and {@code betaNorms[0..len-1]} for one sequence.
     *
     * <p>Runs a scaled backward pass. {@code len} must be at least 1.
     */
    private void backwardPass(int seq, int len, double[] param) {
        Arrays.fill(this.betas[len - 1], 1.0);
        this.betaNorms[len - 1] = 0.0;
        for (int pos = len - 1; pos > 0; --pos) {
            this.calcMi(seq, pos, param);
            this.quickBetaUpdate(this.betas[pos], this.betas[pos - 1]);
            double n = this.normalizePotential(this.betas[pos - 1]);
            this.betaNorms[pos - 1] = this.betaNorms[pos] + MaximumLikelihoodGradient.log(n);
        }
    }

    /**
     * Runs the scaled forward pass for one sequence and returns its log partition value.
     *
     * <p>Expected feature values are accumulated into {@link #expects} along the way. It
     * requires {@link #backwardPass} to have filled the betas for the same sequence first.
     */
    private double forwardPass(int seq, int len, double[] param) {
        double logZ = Double.NEGATIVE_INFINITY;
        double alphaNorm = 0.0;
        double prevAlphaNorm = 0.0;
        for (int pos = 0; pos < len; ++pos) {
            double[] beta = this.betas[pos];
            double betaNorm = this.betaNorms[pos];
            if (pos == 0) {
                this.calcStartAlpha(seq, param);
                alphaNorm = MaximumLikelihoodGradient.log(this.normalizePotential(this.alpha));
                logZ = MaximumLikelihoodGradient.log(ColtUtil.dotProduct(this.alpha, beta)) + betaNorm + alphaNorm;
            } else {
                this.calcMi(seq, pos, param);
                this.quickAlphaUpdate(this.prevAlpha, this.alpha);
                alphaNorm = prevAlphaNorm + MaximumLikelihoodGradient.log(this.normalizePotential(this.alpha));
                double newZ = MaximumLikelihoodGradient.log(ColtUtil.dotProduct(this.alpha, beta)) + betaNorm + alphaNorm;
                // Every position must reproduce the same log Z.
                // The absolute floor keeps the check meaningful when logZ is (close to) zero,
                // where a purely relative tolerance would reject even an exact match.
                Assert.a(Math.abs(newZ - logZ) <= 1.0E-7 * Math.max(1.0, Math.abs(logZ)), "New Z:", newZ, " Old was: ", logZ);
            }
            // Scale factors that turn normalized alpha/beta products back into marginal probabilities.
            double nodeNorm = MaximumLikelihoodGradient.exp(alphaNorm + betaNorm - logZ);
            double edgeNorm = MaximumLikelihoodGradient.exp(prevAlphaNorm + betaNorm - logZ);
            this.updateExpectations(seq, pos, nodeNorm, edgeNorm, beta);
            boolean edgeSequence = seq < 2 || seq == this.modelInfo.nSeqs - 1;
            boolean edgePosition = pos < 2 || pos >= len - 2;
            if (this.debug && edgeSequence && edgePosition) {
                log.debug(String.format("Pos: %d expects: %s alphas: %s (norm %f) betas: %s (norm %f)", pos, ColtUtil.format(this.expects), ColtUtil.format(this.alpha), alphaNorm, ColtUtil.format(beta), betaNorm));
            }
            double[] swap = this.prevAlpha;
            this.prevAlpha = this.alpha;
            this.alpha = swap;
            prevAlphaNorm = alphaNorm;
        }
        return logZ;
    }

    @Override
    public void clean() {
    }

    /**
     * Fills {@link #mi} for position {@code pos} of sequence {@code seq}.
     *
     * <p>Each entry is exp(transition score + score of the node potential that precedes it in
     * {@code orderedPotentials}).
     */
    void calcMi(int seq, int pos, double[] lambda) {
        this.cacheProcessor.evaluatePosition(seq, pos);
        int invalidIndex = (this.modelInfo.seqOffsets[seq] + pos) * this.modelInfo.nPotentials;
        double nodeVal = Double.NaN;
        for (short potential : this.modelInfo.orderedPotentials) {
            double features = this.potentialScore(potential, invalidIndex, lambda);
            if (potential < this.modelInfo.nStates) {
                nodeVal = features;
            } else {
                this.mi[potential - this.modelInfo.nStates] = MaximumLikelihoodGradient.exp(features + nodeVal);
            }
        }
    }

    /** Fills {@link #alpha} with exp(node score) at position 0 of sequence {@code seq}. */
    void calcStartAlpha(int seq, double[] lambda) {
        this.cacheProcessor.evaluatePosition(seq, 0);
        int invalidIndex = this.modelInfo.seqOffsets[seq] * this.modelInfo.nPotentials;
        for (short potential : this.modelInfo.orderedPotentials) {
            if (potential >= this.modelInfo.nStates) {
                continue;
            }
            this.alpha[potential] = MaximumLikelihoodGradient.exp(this.potentialScore(potential, invalidIndex, lambda));
        }
    }

    /**
     * Returns the weighted feature sum of one potential at the current position.
     *
     * <p>The result is negative infinity if the potential is flagged invalid, or if its feature
     * list ends with {@code Short.MIN_VALUE}. Callers first invoke
     * {@code cacheProcessor.evaluatePosition} so that {@link #evals} reflects the position whose
     * invalid flags start at {@code invalidIndex}.
     */
    private double potentialScore(short potential, int invalidIndex, double[] lambda) {
        if (this.invalidTransitions[invalidIndex + potential]) {
            return Double.NEGATIVE_INFINITY;
        }
        CacheProcessor.FeatureEvaluation potEvals = this.evals[potential];
        short[] indices = potEvals.index;
        float[] vals = potEvals.value;
        double features = 0.0;
        int i = 0;
        short index = indices[i];
        // Any negative index terminates the list.
        while (index >= 0) {
            features += (double) vals[i] * lambda[index];
            index = indices[++i];
        }
        return index == Short.MIN_VALUE ? Double.NEGATIVE_INFINITY : features;
    }

    /** One backward step: {@code newBeta[from] = sum over transitions of mi * lastBeta[to]}. */
    private void quickBetaUpdate(double[] lastBeta, double[] newBeta) {
        Arrays.fill(newBeta, 0.0);
        double nodeVal = 0.0;
        for (short potential : this.modelInfo.orderedPotentials) {
            if (potential < this.modelInfo.nStates) {
                nodeVal = lastBeta[potential];
            } else {
                int trans = potential - this.modelInfo.nStates;
                newBeta[this.modelInfo.transitionFrom[trans]] += this.mi[trans] * nodeVal;
            }
        }
    }

    /** One forward step: {@code newAlpha[to] = sum over transitions into to of lastAlpha[from] * mi}. */
    private void quickAlphaUpdate(double[] lastAlpha, double[] newAlpha) {
        double nodeVal = 0.0;
        int lastState = -1;
        for (short n : this.modelInfo.orderedPotentials) {
            if (n < this.modelInfo.nStates) {
                if (lastState != -1) {
                    newAlpha[lastState] = nodeVal;
                }
                lastState = n;
                nodeVal = 0.0;
                continue;
            }
            int trans = n - this.modelInfo.nStates;
            nodeVal += lastAlpha[this.modelInfo.transitionFrom[trans]] * this.mi[trans];
        }
        if (lastState != -1) {
            newAlpha[lastState] = nodeVal;
        }
    }

    /**
     * Adds the marginal-weighted feature values of every valid potential at {@code pos} to
     * {@link #expects}.
     *
     * <p>Transition potentials are skipped at position 0.
     */
    void updateExpectations(int seq, int pos, double nodeNorm, double edgeNorm, double[] beta) {
        double currentBeta = 0.0;
        int invalidIndex = (this.modelInfo.seqOffsets[seq] + pos) * this.modelInfo.nPotentials;
        for (short n : this.modelInfo.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + n];
            double prob;
            if (n < this.modelInfo.nStates) {
                // Track the destination state's beta even when the node is invalid.
                // Otherwise the transitions that follow would use the previous state's beta.
                currentBeta = beta[n];
                if (invalid) {
                    continue;
                }
                prob = this.alpha[n] * currentBeta * nodeNorm;
            } else {
                if (invalid || pos == 0) {
                    continue;
                }
                int trans = n - this.modelInfo.nStates;
                prob = this.prevAlpha[this.modelInfo.transitionFrom[trans]] * this.mi[trans] * currentBeta * edgeNorm;
            }
            this.accumulateExpectations(n, prob);
        }
    }

    /**
     * Adds {@code prob * value} for every feature of {@code potential} to {@link #expects}.
     *
     * <p>It stops at the first negative index, the same terminator rule used by
     * {@link #potentialScore}.
     */
    private void accumulateExpectations(short potential, double prob) {
        CacheProcessor.FeatureEvaluation potEvals = this.evals[potential];
        short[] indices = potEvals.index;
        float[] vals = potEvals.value;
        int i = 0;
        short index = indices[i];
        while (index >= 0) {
            this.expects[index] += prob * (double) vals[i];
            index = indices[++i];
        }
    }

    /** Scales {@code vec} in place to sum to one and returns the original sum. */
    private double normalizePotential(double[] vec) {
        double norm = 0.0;
        for (double v : vec) {
            norm += v;
        }
        double mult = 1.0 / norm;
        for (int i = 0; i < vec.length; ++i) {
            vec[i] *= mult;
        }
        return norm;
    }

    static final double exp(double val) {
        return Math.exp(val);
    }

    static final double log(double val) {
        return Math.log(val);
    }

    /**
     * Returns a copy of the feature sums captured by the most recent {@link #apply} call.
     *
     * @return a defensive copy of the feature sums
     * @throws IllegalStateException if {@link #apply} has not been called yet
     */
    public double[] getFeatureSums() {
        if (this.featureSums == null) {
            throw new IllegalStateException("apply() must be called before getFeatureSums()");
        }
        return this.featureSums.clone();
    }
}

