/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.check;

import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.check.ArrayFeatureList;
import calhoun.analysis.crf.solver.check.TransitionInfo;
import calhoun.util.Assert;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Evaluates node, edge and length feature potentials through a shared {@link ArrayFeatureList}
 * and returns their lambda-weighted sums.
 *
 * <p>When {@link #resetFeatureSums()} has been called (so that {@code featureSums} is allocated),
 * the raw feature values of every evaluated node/edge that agrees with the labels of the
 * {@link TrainingSequence} are accumulated into {@code featureSums}. Their weighted total is
 * accumulated into {@code weightedFeatureSum}. In that mode every sequence passed to the
 * node/edge methods must be a {@code TrainingSequence}.
 *
 * <p>A single {@code result} list is reused across calls, so each evaluation overwrites the
 * previous one.
 */
public class FeatureCalculator {
    private static final Log log = LogFactory.getLog(FeatureCalculator.class);
    boolean debug = log.isDebugEnabled();
    /** Scratch list filled by each evaluate* call and consumed by {@link #calcRet(boolean)}. */
    public ArrayFeatureList result;
    ModelManager manager;
    /** Feature weights, indexed by feature index. */
    double[] lambda;
    /** Running weighted sum of features matching the training labels (see class doc). */
    double weightedFeatureSum;
    /** Running raw sums per feature; {@code null} until {@link #resetFeatureSums()} is called. */
    double[] featureSums;
    int nFeatures;
    int nStates;
    TransitionInfo transitions;

    /**
     * Creates a calculator bound to a model, its weights and its allowed transitions.
     *
     * @param manager model supplying the feature/state counts and the features themselves
     * @param lambda feature weights, indexed by feature index
     * @param transitions transition table; an index of -1 marks a disallowed transition
     */
    public FeatureCalculator(ModelManager manager, double[] lambda, TransitionInfo transitions) {
        this.manager = manager;
        this.lambda = lambda;
        this.nFeatures = manager.getNumFeatures();
        this.nStates = manager.getNumStates();
        this.result = new ArrayFeatureList(manager);
        this.transitions = transitions;
    }

    /** Returns the weighted sum accumulated since the last {@link #resetFeatureSums()}. */
    public double getWeightedFeatureSum() {
        return this.weightedFeatureSum;
    }

    /**
     * Returns the per-feature raw sums accumulated since the last {@link #resetFeatureSums()}.
     * The internal array is returned directly, not a copy; it is {@code null} if sums were never
     * reset.
     */
    public double[] getFeatureSums() {
        return this.featureSums;
    }

    /**
     * Zeroes the accumulated sums, allocating the per-feature array on first use. After this
     * call, node/edge evaluations start accumulating sums for label-matching potentials.
     */
    public void resetFeatureSums() {
        this.weightedFeatureSum = 0.0;
        if (this.featureSums == null) {
            this.featureSums = new double[this.nFeatures];
        } else {
            Arrays.fill(this.featureSums, 0.0);
        }
    }

    /**
     * Computes the weighted edge potential for a transition into {@code state} at {@code pos}.
     *
     * @return {@link Double#NEGATIVE_INFINITY} if the transition is disallowed or the evaluated
     *     features are invalid; otherwise the lambda-weighted sum of the edge features
     */
    public double calcEdgeValue(InputSequence seq, int pos, int prevState, int state) {
        if (this.debug) {
            log.debug(String.format("Edge - pos %d prevState %d state %d", pos, prevState, state));
        }
        // Disallowed transitions are never evaluated.
        if (!this.isValidTransition(prevState, state)) {
            return Double.NEGATIVE_INFINITY;
        }
        this.result.evaluateEdge(seq, pos, prevState, state);
        return this.calcRet(this.checkForUpdate(seq, pos, prevState, state));
    }

    /** Returns whether the transition table has an entry (index other than -1) for the pair. */
    public boolean isValidTransition(int prevState, int state) {
        return this.transitions.transitionIndex.getQuick(prevState, state) != -1;
    }

    /**
     * Computes the weighted node potential for {@code state} at {@code pos}.
     *
     * @return {@link Double#NEGATIVE_INFINITY} if the evaluated features are invalid; otherwise
     *     the lambda-weighted sum of the node features
     */
    public double calcNodeValue(InputSequence seq, int pos, int state) {
        if (this.debug) {
            log.debug(String.format("Node - pos %d state %d", pos, state));
        }
        this.result.evaluateNode(seq, pos, state);
        // -1 marks "no previous state" so only the current label is checked.
        return this.calcRet(this.checkForUpdate(seq, pos, -1, state));
    }

    /**
     * Fills dense potentials for position {@code pos}.
     *
     * <p>{@code ri[current]} receives each state's node value (skipped when {@code ri} is
     * {@code null}). For {@code pos > 0}, {@code mi[prev][current]} receives node value plus
     * edge value. At {@code pos <= 0} {@code mi} is left untouched.
     */
    public void computeMi(InputSequence seq, int pos, DoubleMatrix2D mi, DoubleMatrix1D ri) {
        for (int current = 0; current < this.nStates; ++current) {
            double nodeVal = this.calcNodeValue(seq, pos, current);
            if (ri != null) {
                ri.setQuick(current, nodeVal);
            }
            if (pos > 0) {
                for (int prev = 0; prev < this.nStates; ++prev) {
                    mi.setQuick(prev, current, nodeVal + this.calcEdgeValue(seq, pos, prev, current));
                }
            }
        }
    }

    /**
     * Fills sparse potentials for position {@code pos} by walking
     * {@code transitions.orderedPotentials}.
     *
     * <p>Entries below {@code nStates} are node potentials for that state. Their values go to
     * {@code ri} when it is non-null. Other entries are edge potentials: transition
     * {@code n - nStates} runs from {@code transitionFrom[n - nStates]} into the most recently
     * visited node state, and its value (node + edge) goes to {@code mi[n - nStates]} when
     * {@code pos > 0}.
     *
     * <p>NOTE(review): this relies on each state's node entry preceding the entries of the
     * transitions into it in {@code orderedPotentials}. Confirm that {@code TransitionInfo}
     * guarantees this ordering.
     */
    public void computeSparseMi(InputSequence seq, int pos, double[] mi, double[] ri) {
        double nodeVal = 0.0;
        int currentState = -1;
        for (int n : this.transitions.orderedPotentials) {
            if (n < this.nStates) {
                nodeVal = this.calcNodeValue(seq, pos, n);
                currentState = n;
                if (ri != null) {
                    ri[n] = nodeVal;
                }
            } else if (pos > 0) {
                int transition = n - this.nStates;
                mi[transition] = nodeVal
                        + this.calcEdgeValue(seq, pos, this.transitions.transitionFrom[transition], currentState);
            }
        }
    }

    /** Returns the weighted node-length potential; never contributes to the feature sums. */
    public double calcNodeLengthValue(InputSequence seq, int pos, int len, int state) {
        this.result.evaluateNodeLength(seq, pos, len, state);
        return this.calcRet(false);
    }

    /** Returns the weighted edge-length potential; never contributes to the feature sums. */
    public double calcEdgeLengthValue(InputSequence seq, int pos, int len, int prevState, int state) {
        this.result.evaluateEdgeLength(seq, pos, len, prevState, state);
        return this.calcRet(false);
    }

    /**
     * Decides whether the current evaluation should be added to the feature sums.
     *
     * <p>It is added only when sums are being collected and {@code state} equals the training
     * label at {@code pos}. The previous state must also equal the label at {@code pos - 1},
     * unless {@code pos == 0} or {@code previousState == -1}.
     *
     * @throws ClassCastException if sums are being collected and {@code seq} is not a
     *     {@link TrainingSequence}
     */
    boolean checkForUpdate(InputSequence seq, int pos, int previousState, int state) {
        if (this.featureSums == null) {
            return false;
        }
        TrainingSequence train = (TrainingSequence) seq;
        boolean prevMatches = pos == 0 || previousState == -1 || previousState == train.getY(pos - 1);
        return prevMatches && state == train.getY(pos);
    }

    /**
     * Returns the lambda-weighted sum of the features currently held in {@code result}.
     *
     * @param updateSum when true, also adds each raw feature value to {@code featureSums} and
     *     each weighted value to {@code weightedFeatureSum}
     * @return {@link Double#NEGATIVE_INFINITY} if {@code result} is marked invalid; otherwise
     *     the weighted sum (asserted not to be NaN)
     */
    public double calcRet(boolean updateSum) {
        if (!this.result.valid) {
            return Double.NEGATIVE_INFINITY;
        }
        final int count = this.result.currentSize;
        final int[] indices = this.result.indices;
        final double[] vals = this.result.values;
        double ret = 0.0;
        for (int i = 0; i < count; ++i) {
            final int index = indices[i];
            final double val = vals[i] * this.lambda[index];
            ret += val;
            if (updateSum) {
                this.featureSums[index] += vals[i];
                this.weightedFeatureSum += val;
            }
            if (this.debug) {
                log.debug("Adding feature " + index + " val: " + val);
            }
        }
        Assert.a(!Double.isNaN(ret));
        return ret;
    }
}

