/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.features.generic;

import calhoun.analysis.crf.AbstractFeatureManager;
import calhoun.analysis.crf.CacheStrategySpec;
import calhoun.analysis.crf.FeatureList;
import calhoun.analysis.crf.FeatureManagerEdge;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.util.Assert;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Edge feature manager exposing a single feature that fires only when the hidden state changes
 * between consecutive positions.
 *
 * <p>The feature value for a change {@code from -> to} is the natural log of the estimated
 * probability of moving to {@code to} given that the state leaves {@code from}. Estimates come from
 * the labelled training data, with a pseudo-count of 1 added to every off-diagonal transition.
 * Self-transitions never contribute to counts or normalisation. Because the value depends only on
 * the state pair, the input sequence and position are ignored when evaluating.
 */
public class WeightedStateChanges
extends AbstractFeatureManager<Object>
implements FeatureManagerEdge<Object> {
    private static final long serialVersionUID = 8477631359065280630L;
    private static final Log log = LogFactory.getLog(WeightedStateChanges.class);
    boolean debug = log.isDebugEnabled();
    /** Global index of this manager's only feature, assigned by {@link #train}. */
    int startIx;
    /** Model manager supplied to the most recent {@link #train} call. */
    ModelManager manager;
    /**
     * {@code transitions[from][to]} holds the log-probability of a change from {@code from} to
     * {@code to}. The diagonal is 0 and is never read.
     */
    float[][] transitions;

    /** Declares a {@code CONSTANT} cache strategy; the feature value ignores sequence and position. */
    @Override
    public CacheStrategySpec getCacheStrategy() {
        CacheStrategySpec.CacheStrategy strategy = CacheStrategySpec.CacheStrategy.CONSTANT;
        return new CacheStrategySpec(strategy);
    }

    /**
     * Returns the display name of the single feature.
     *
     * @param featureIndex must equal the starting index recorded by {@link #train}; this is checked
     *     with {@code Assert.a}
     * @return the constant name {@code "WeightedEdges"}
     */
    @Override
    public String getFeatureName(int featureIndex) {
        boolean isOwnFeature = featureIndex == startIx;
        Assert.a(isOwnFeature, "Invalid feature index: ", featureIndex, ". Must be ", startIx);
        return "WeightedEdges";
    }

    /** This manager always contributes exactly one feature. */
    @Override
    public int getNumFeatures() {
        return 1;
    }

    /**
     * Adds the learned log-probability for the edge {@code prevState -> state} to {@code result}.
     * Nothing is added for a self-transition.
     */
    @Override
    public void evaluateEdge(InputSequence<?> seq, int pos, int prevState, int state, FeatureList result) {
        if (prevState == state) {
            return;
        }
        float weight = transitions[prevState][state];
        result.addFeature(startIx, weight);
    }

    /**
     * Records the feature index and model, then estimates the state-change weights from
     * {@code data}.
     *
     * @param startingIndex global index to assign to this manager's single feature
     * @param modelInfo model supplying the number of states
     * @param data labelled sequences whose consecutive labels are counted
     */
    @Override
    public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<?>> data) {
        startIx = startingIndex;
        manager = modelInfo;
        final int stateCount = manager.getNumStates();
        transitions = new float[stateCount][stateCount];
        seedPseudoCounts(transitions);
        countObservedChanges(transitions, data);
        toLogProbabilities(transitions);
    }

    /** Sets every off-diagonal cell to a pseudo-count of 1 and every diagonal cell to 0. */
    private static void seedPseudoCounts(float[][] matrix) {
        for (int row = 0; row < matrix.length; row++) {
            float[] cells = matrix[row];
            for (int col = 0; col < cells.length; col++) {
                cells[col] = (row == col) ? 0.0f : 1.0f;
            }
        }
    }

    /** Adds 1 to {@code matrix[from][to]} for every adjacent label pair with {@code from != to}. */
    private static void countObservedChanges(float[][] matrix, List<? extends TrainingSequence<?>> data) {
        for (TrainingSequence<?> sequence : data) {
            for (int pos = 1; pos < sequence.length(); pos++) {
                int from = sequence.getY(pos - 1);
                int to = sequence.getY(pos);
                if (from != to) {
                    matrix[from][to] += 1.0f;
                }
            }
        }
    }

    /**
     * Replaces each off-diagonal count with {@code log(count / rowTotal)}. The row total is taken
     * before any cell in that row is rewritten. The diagonal is left untouched.
     */
    private static void toLogProbabilities(float[][] matrix) {
        for (int row = 0; row < matrix.length; row++) {
            float[] cells = matrix[row];
            float total = 0.0f;
            for (float count : cells) {
                total += count;
            }
            for (int col = 0; col < cells.length; col++) {
                if (col != row) {
                    cells[col] = (float) Math.log(cells[col] / total);
                }
            }
        }
    }
}

