/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver;

import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.util.Assert;
import calhoun.util.DenseBooleanMatrix2D;
import calhoun.util.DenseIntMatrix2D;
import calhoun.util.FileUtil;
import java.io.BufferedWriter;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Base implementation of {@link CacheProcessor} holding the state shared by concrete cache
 * processors: the {@link ModelManager}, the training data, the {@link CacheProcessor.SolverSetup}
 * describing states and transitions, and the reusable feature-evaluation buffers.
 *
 * <p>It also computes the empirical feature sums over the labels of the training data
 * ({@link #computeFeatureSums()}), optionally dumping the per-segment feature values to
 * {@link #getTrainingFile()}.
 */
public abstract class CacheProcessorBasic implements CacheProcessor {
    private static final Log log = LogFactory.getLog(CacheProcessorBasic.class);

    /** Optional path passed to {@code FileUtil.safeOpen} to dump per-segment feature values. */
    String trainingFile = null;
    /** Writer opened on {@link #trainingFile} for the duration of {@link #computeFeatureSums()}. */
    BufferedWriter trainingWriter = null;
    /** Per-state maximum segment length; a value greater than 1 marks a semi-Markov state. May be null. */
    protected short[] maxStateLengths;
    /** Training sequences supplied through {@link #setTrainingData}. */
    protected List<? extends TrainingSequence<?>> data;
    /** Model manager supplying features, states and legal transitions. */
    protected ModelManager fm;
    /** Sizes, sequence offsets and transition tables built by {@link #initSequenceInfo()} / {@link #initTransitions(boolean)}. */
    protected CacheProcessor.SolverSetup modelInfo;
    /**
     * One evaluation buffer per potential: index {@code state} for node potentials and
     * {@code nStates + transitionIndex(from, to)} for edge potentials.
     */
    protected CacheProcessor.FeatureEvaluation[] evals;
    /**
     * Length evaluations per entry of {@code modelInfo.statesWithLookback}; each row is scanned
     * until an entry with {@code lookback == -1}. Not populated in this class.
     */
    protected CacheProcessor.LengthFeatureEvaluation[][] lengthEvals;
    /** Feature sums over all training sequences (filled by {@link #computeFeatureSums()}). */
    protected double[] featureSums;
    /** Feature sums per training sequence, indexed {@code [sequence][feature]}. */
    protected double[][] seqFeatureSums;

    /**
     * Computes the empirical feature sums of the labeled training data.
     *
     * <p>At every position the features of the labeled state (node potential) and, from the second
     * position on, of the labeled transition (edge potential) are accumulated. At the end of every
     * run of identical labels (a segment) the length features of that state are added too, when
     * the state appears in {@code modelInfo.statesWithLookback}. Totals are stored in
     * {@link #featureSums}, per-sequence totals in {@link #seqFeatureSums}.
     *
     * <p>If {@code FileUtil.safeOpen(trainingFile)} yields a writer, the feature values contributed
     * by each segment are written to it. The writer is closed even when evaluation fails.
     *
     * @throws UnsupportedOperationException if a segment other than the first one of a sequence has
     *     explicit length edge evaluations, which are not supported yet
     */
    protected void computeFeatureSums() {
        this.trainingWriter = FileUtil.safeOpen(this.trainingFile);
        try {
            int numFeatures = this.fm.getNumFeatures();
            this.seqFeatureSums = new double[this.data.size()][numFeatures];
            this.featureSums = new double[numFeatures];
            // Totals as of the last dumped segment; used to write per-segment deltas.
            double[] lastSegmentFeatureSums = new double[numFeatures];
            for (int seqnum = 0; seqnum < this.data.size(); ++seqnum) {
                this.accumulateSequence(seqnum, numFeatures, lastSegmentFeatureSums);
            }
            if (log.isDebugEnabled()) {
                log.debug("We just computed the feature sums on the training data.  The feature sums are (id,name,sum)");
                for (int j = 0; j < numFeatures; ++j) {
                    log.debug("(" + j + "," + this.fm.getFeatureName(j) + "," + this.featureSums[j] + ")");
                }
            }
        } finally {
            // Previously skipped when an exception escaped, leaking the file handle.
            FileUtil.safeClose(this.trainingWriter);
        }
    }

    /** Adds the node, edge and length features along the labels of training sequence {@code seqnum}. */
    private void accumulateSequence(int seqnum, int numFeatures, double[] lastSegmentFeatureSums) {
        TrainingSequence<?> seq = this.data.get(seqnum);
        double[] seqSums = this.seqFeatureSums[seqnum];
        int length = seq.length();
        int prevSegmentState = -1;
        int segmentLength = 1;
        for (int pos = 0; pos < length; ++pos) {
            int state = seq.getY(pos);
            this.evaluatePosition(seqnum, pos);
            this.sumFeatures(this.featureSums, seqSums, this.evals[state]);
            if (pos > 0) {
                int pot = this.modelInfo.nStates + this.modelInfo.transitionIndex.getQuick(seq.getY(pos - 1), state);
                this.sumFeatures(this.featureSums, seqSums, this.evals[pot]);
            }
            boolean segmentEnds = pos == length - 1 || state != seq.getY(pos + 1);
            if (!segmentEnds) {
                ++segmentLength;
                continue;
            }
            this.evaluateSegmentsEndingAt(seqnum, pos);
            this.addLengthFeatures(seqnum, pos, state, segmentLength, prevSegmentState, seqSums);
            prevSegmentState = state;
            int segmentStart = pos - segmentLength + 1;
            segmentLength = 1;
            if (this.trainingWriter != null) {
                this.writeSegment(seqnum, segmentStart, pos, state, numFeatures, lastSegmentFeatureSums);
            }
        }
    }

    /** Adds the length features of a segment of {@code state} ending at {@code pos}, if the state has lookbacks. */
    private void addLengthFeatures(int seqnum, int pos, int state, int segmentLength, int prevSegmentState, double[] seqSums) {
        int nodeIndex = this.findLookbackNode(state);
        if (nodeIndex == this.modelInfo.statesWithLookback.length) {
            return; // state has no explicit length features
        }
        CacheProcessor.LengthFeatureEvaluation lengthEval =
                this.lengthEvals[nodeIndex][this.findLookbackIndex(nodeIndex, segmentLength, seqnum, pos)];
        this.sumFeatures(this.featureSums, seqSums, lengthEval.nodeEval);
        if (prevSegmentState != -1 && lengthEval.edgeEvals != null) {
            throw new UnsupportedOperationException("ComputeFeatureSums doesn't handle explicit length edge evals yet.");
        }
    }

    /** Returns the index of {@code state} in {@code statesWithLookback}, or that array's length if absent. */
    private int findLookbackNode(int state) {
        int nodeIndex = 0;
        while (nodeIndex < this.modelInfo.statesWithLookback.length
                && this.modelInfo.statesWithLookback[nodeIndex].state != state) {
            ++nodeIndex;
        }
        return nodeIndex;
    }

    /** Returns the index of the length evaluation whose lookback is {@code segmentLength - 1}. */
    private int findLookbackIndex(int nodeIndex, int segmentLength, int seqnum, int pos) {
        CacheProcessor.LengthFeatureEvaluation[] candidates = this.lengthEvals[nodeIndex];
        int lbIndex = 0;
        while (candidates[lbIndex].lookback != segmentLength - 1) {
            if (candidates[lbIndex].lookback == -1) {
                Assert.a(false, "Lookback not listed. State: ", this.modelInfo.statesWithLookback[nodeIndex].state, " Seq: ", seqnum, " Pos: ", pos, " Len: ", segmentLength, " # Lookbacks: ", lbIndex);
            }
            ++lbIndex;
        }
        return lbIndex;
    }

    /** Writes the feature values accumulated since the previously written segment, one line per feature. */
    private void writeSegment(int seqnum, int start, int end, int state, int numFeatures, double[] lastSegmentFeatureSums) {
        for (int i = 0; i < numFeatures; ++i) {
            FileUtil.safeWrite(this.trainingWriter, String.format("Seq: %d Seg: %d-%d State: %d Feat: %d Val: %f\n", seqnum, start, end, state, i, this.featureSums[i] - lastSegmentFeatureSums[i]));
            lastSegmentFeatureSums[i] = this.featureSums[i];
        }
    }

    /**
     * Adds the (feature, value) pairs of {@code eval} to both accumulators.
     *
     * <p>Entries are read until an index of {@code -1} or the end of {@code eval.index}, whichever
     * comes first (the bound guards against a buffer that is completely full).
     *
     * @param featureSums global accumulator, indexed by feature
     * @param seqFeatureSum per-sequence accumulator, indexed by feature
     * @param eval evaluation whose entries are added
     */
    void sumFeatures(double[] featureSums, double[] seqFeatureSum, CacheProcessor.FeatureEvaluation eval) {
        for (int i = 0; i < eval.index.length && eval.index[i] != -1; ++i) {
            int feature = eval.index[i];
            double value = eval.value[i];
            featureSums[feature] += value;
            seqFeatureSum[feature] += value;
        }
    }

    /**
     * Installs a single unlabeled input sequence as the data set. Every hidden state is set to
     * {@link Integer#MIN_VALUE} as a placeholder, so {@link #computeFeatureSums()} must not be used
     * on this data.
     */
    @Override
    public void setInputData(ModelManager fm, InputSequence<?> seq) {
        int[] dummyHiddenStates = new int[seq.length()];
        Arrays.fill(dummyHiddenStates, Integer.MIN_VALUE);
        this.setTrainingData(fm, Collections.singletonList(new TrainingSequence(seq, dummyHiddenStates)));
    }

    /** Stores the model manager and training data; no caches are computed here. */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.fm = fm;
        this.data = data;
    }

    @Override
    public double[] getFeatureSums() {
        return this.featureSums;
    }

    @Override
    public double[][] getSequenceFeatureSums() {
        return this.seqFeatureSums;
    }

    /**
     * Builds {@link #modelInfo} (sequence info and transition tables) and allocates {@link #evals}.
     *
     * @param allPaths if true, every state-to-state transition is treated as legal
     */
    void basicInit(boolean allPaths) {
        this.initSequenceInfo();
        this.initTransitions(allPaths);
        // NOTE(review): second argument presumably sizes each buffer (at least 5) — confirm against FeatureEvaluation.create.
        this.evals = CacheProcessor.FeatureEvaluation.create(this.modelInfo.nPotentials, Math.max(5, this.modelInfo.nFeatures));
    }

    @Override
    public CacheProcessor.SolverSetup getSolverSetup() {
        return this.modelInfo;
    }

    @Override
    public CacheProcessor.FeatureEvaluation[] getFeatureEvaluations() {
        return this.evals;
    }

    @Override
    public CacheProcessor.LengthFeatureEvaluation[][] getLengthFeatureEvaluations() {
        return this.lengthEvals;
    }

    /**
     * Creates a fresh {@link #modelInfo} holding the feature/state counts, the number of sequences,
     * the prefix offsets of each sequence ({@code seqOffsets[i]} is the total length of sequences
     * {@code 0..i-1}), the longest sequence length and the total number of positions.
     */
    protected void initSequenceInfo() {
        CacheProcessor.SolverSetup info = new CacheProcessor.SolverSetup();
        this.modelInfo = info;
        info.nFeatures = this.fm.getNumFeatures();
        info.nStates = this.fm.getNumStates();
        info.nSeqs = this.data.size();
        info.seqOffsets = new int[info.nSeqs + 1];
        info.seqOffsets[0] = 0;
        info.longestSeq = 0;
        info.totalPositions = 0;
        for (int i = 0; i < info.nSeqs; ++i) {
            int seqLen = this.data.get(i).length();
            info.longestSeq = Math.max(seqLen, info.longestSeq);
            info.seqOffsets[i + 1] = info.seqOffsets[i] + seqLen;
            info.totalPositions += seqLen;
        }
    }

    /** Returns whether {@code state} may span more than one position (max length greater than 1). */
    protected boolean isSemiMarkovState(int state) {
        return this.maxStateLengths != null && this.maxStateLengths[state] > 1;
    }

    /**
     * Builds the transition tables of {@link #modelInfo}.
     *
     * <p>A transition {@code from -> to} exists if it is legal according to the model (or
     * {@code allPaths} is set), or if it is the self transition of a semi-Markov state. Transitions
     * are numbered grouped by destination state; {@code orderedPotentials} lists, for every state,
     * the state's node potential followed by the edge potentials ({@code nStates + transition})
     * entering it.
     *
     * @param allPaths if true, ignore the model's legal transitions and allow all of them
     * @throws IllegalStateException if a potential index would not fit in a {@code short}
     */
    protected void initTransitions(boolean allPaths) {
        CacheProcessor.SolverSetup info = this.modelInfo;
        int nStates = info.nStates;
        info.transitionIndex = new DenseIntMatrix2D(nStates, nStates);
        info.transitionIndex.assign(-1);
        info.selfTransitions = new int[nStates];
        Arrays.fill(info.selfTransitions, -1);
        DenseBooleanMatrix2D transitions = this.fm.getLegalTransitions();
        if (transitions == null || allPaths) {
            transitions = new DenseBooleanMatrix2D(nStates, nStates);
            transitions.assign(true);
        }
        int count = 0;
        for (int from = 0; from < nStates; ++from) {
            for (int to = 0; to < nStates; ++to) {
                if (this.hasTransition(transitions, from, to)) {
                    ++count;
                }
            }
        }
        // Potential and state indices are stored in short[] tables; fail loudly instead of wrapping.
        int maxPotentialIndex = nStates + count - 1;
        if (maxPotentialIndex > Short.MAX_VALUE) {
            throw new IllegalStateException("Model has " + (nStates + count) + " potentials; at most "
                    + (Short.MAX_VALUE + 1) + " can be indexed with short values");
        }
        info.nTransitions = count;
        info.nPotentials = nStates + count;
        info.orderedPotentials = new short[info.nPotentials];
        info.transitionFrom = new short[count];
        info.transitionTo = new short[count];
        count = 0;
        int orderedCount = 0;
        for (int to = 0; to < nStates; ++to) {
            info.orderedPotentials[orderedCount++] = (short) to;
            for (int from = 0; from < nStates; ++from) {
                if (!this.hasTransition(transitions, from, to)) {
                    continue;
                }
                info.orderedPotentials[orderedCount++] = (short) (nStates + count);
                if (from == to) {
                    info.selfTransitions[to] = count;
                }
                info.transitionIndex.setQuick(from, to, count);
                info.transitionFrom[count] = (short) from;
                info.transitionTo[count] = (short) to;
                ++count;
            }
        }
        Assert.a(count == info.nTransitions);
    }

    /** Returns whether {@code from -> to} is legal or is the self transition of a semi-Markov state. */
    private boolean hasTransition(DenseBooleanMatrix2D transitions, int from, int to) {
        return transitions.getQuick(from, to) || (from == to && this.isSemiMarkovState(from));
    }

    public String getTrainingFile() {
        return this.trainingFile;
    }

    /** Sets the file to which {@link #computeFeatureSums()} writes per-segment feature values. */
    public void setTrainingFile(String trainingFile) {
        this.trainingFile = trainingFile;
    }

    @Override
    public List<? extends TrainingSequence<?>> getData() {
        return this.data;
    }
}

