/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.test;

import calhoun.analysis.crf.CacheStrategySpec;
import calhoun.analysis.crf.Conrad;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.SemiMarkovSetup;
import calhoun.analysis.crf.io.IntInput;
import calhoun.analysis.crf.io.StringInput;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
import calhoun.analysis.crf.solver.MaximumLikelihoodSemiMarkovGradient;
import calhoun.analysis.crf.solver.NoCachingCacheProcessor;
import calhoun.analysis.crf.solver.StandardOptimizer;
import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
import calhoun.analysis.crf.test.TestFeatureManager;
import calhoun.util.AbstractTestCase;
import calhoun.util.Assert;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class CacheProcessorTest
extends AbstractTestCase {
    private static final Log log = LogFactory.getLog(CacheProcessorTest.class);

    public void testCPDRejectTrainingDataStatesTooLong() throws Exception {
        this.checkFailure("test/input/interval13/config/shortIntergenicCPD.xml", "test/input/interval13/data/tooLong.txt");
    }

    public void testCPDRejectTrainingDataStatesTooShort() throws Exception {
        this.checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShort.txt");
    }

    public void testAllSparseRejectTrainingDataStatesTooShort() throws Exception {
        this.checkFailure("test/input/interval13/config/lengthDependentAllSparse.xml", "test/input/interval13/data/tooShort.txt");
    }

    public void testCPDRejectTrainingDataStatesTooShortStart() throws Exception {
        this.checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShortStart.txt");
    }

    public void testCPDRejectTrainingDataStatesTooShortEnd() throws Exception {
        this.checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShortEnd.txt");
    }

    public void testCPDRejectTrainingDataStatesViolatesConstraints() throws Exception {
        this.checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/badConstraints.txt");
    }

    public void testCPDAllSequencesInvalid() throws Exception {
        this.checkFailure("test/input/interval13/config/lengthDependentCPDuncommentedDiscard.xml", "test/input/interval13/data/tooShortEnd.txt");
    }

    public void testCPDAllDiscardInvalid() throws Exception {
        Conrad conrad = new Conrad("test/input/interval13/config/lengthDependentCPDuncommentedDiscard.xml");
        conrad.train("test/input/interval13/data/oneGoodOneBad.txt");
    }

    public void testCPDAllDiscardInvalidLocalScore() throws Exception {
        Conrad conrad = new Conrad("test/input/interval13/config/lengthDependentCPDuncommentedDiscardLocalScore.xml");
        conrad.train("test/input/interval13/data/oneGoodOneBad.txt");
    }

    void checkFailure(String configFile, String data) {
        Conrad conrad = new Conrad(configFile);
        boolean fail = false;
        try {
            conrad.trainFeatures(data);
            conrad.trainWeights(conrad.getInputHandler().readTrainingData(data));
        }
        catch (Exception ex) {
            log.warn((Object)ex);
            fail = true;
        }
        CacheProcessorTest.assertTrue((boolean)fail);
    }

    public void testEdgeTrivial() throws Exception {
        int[][] indices = new int[][]{{-1}, {-1}, {-1}, {-1}, {-1}, {-1}};
        float[][] vals = new float[0][0];
        double[] featureSums = new double[]{0.0};
        List<TrainingSequence<?>> data = IntInput.prepareData("0\n0");
        System.out.println("number of data sequences is " + data.size());
        this.doTest(1, data, 0, 0, indices, vals, featureSums);
    }

    public void testEdgeShort() throws Exception {
        int[][] indices = new int[][]{{-1}, {-1}, {0, -1}, {0, -1}, {0, -1}, {0, -1}};
        float[][] vals = new float[][]{new float[0], new float[0], {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f}};
        double[] featureSums = new double[]{-0.4054651f};
        this.doTest(1, IntInput.prepareData("00\n00"), 0, 1, indices, vals, featureSums);
    }

    public void testNode() throws Exception {
        int[][] indices = new int[][]{{0, -1}, {0, -1}, {-1}, {-1}, {-1}, {-1}};
        float[][] vals1 = new float[][]{{-1.0986123f}, {-0.4054651f}, new float[0], new float[0], new float[0], new float[0]};
        double[] featureSums = new double[]{-3.819085121154785};
        this.doTest(0, IntInput.prepareData("001111\n001111"), 0, 0, indices, vals1, featureSums);
        float[][] vals2 = new float[][]{{-1.0986123f}, {-0.4054651f}, new float[0], new float[0], new float[0], new float[0]};
        this.doTest(0, IntInput.prepareData("001111\n001111"), 0, 4, indices, vals2, featureSums);
    }

    public void testTwoFeaturesTrivial() throws Exception {
        int[][] indices1 = new int[][]{{0, -1}, {0, -1}, {-1}, {-1}, {-1}, {-1}};
        float[][] vals1 = new float[][]{{-1.0986123f}, {-0.4054651f}, new float[0], new float[0], new float[0], new float[0]};
        double[] featureSums = new double[]{-2.1972246170043945, -0.4054651f};
        this.doTest(2, IntInput.prepareData("00\n00"), 0, 0, indices1, vals1, featureSums);
        int[][] indices2 = new int[][]{{0, -1}, {0, -1}, {1, -1}, {1, -1}, {1, -1}, {1, -1}};
        float[][] vals2 = new float[][]{{-1.0986123f}, {-0.4054651f}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f}};
        featureSums = new double[]{-4.394449234008789, -0.8109301924705505};
        this.doTest(2, IntInput.prepareData("00\n00\n00\n00"), 1, 1, indices2, vals2, featureSums);
    }

    public void testTwoFeaturesNonTrivial() throws Exception {
        int[][] indices = new int[][]{{0, -1}, {0, -1}, {1, -1}, {1, -1}, {1, -1}, {1, -1}};
        float[][] vals = new float[][]{{-1.0986123f}, {-0.4054651f}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f}};
        double[] featureSums = new double[]{-33.54728, -29.963765};
        this.doTest(2, IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"), 1, 4, indices, vals, featureSums);
    }

    public void testLengthCacheDummy() throws Exception {
        int[][] lookbacks = new int[][]{{0, 1, -1}, {0, 1, -1}};
        int[][] nodeIndices = new int[][]{{0, 0, 0, -1}, {1, 0, 0, -1}};
        float[][] nodeValues = new float[0][4];
        TestFeatureManager m = new TestFeatureManager(1);
        this.doLengthTest(m, IntInput.prepareData("00\n00"), 0, 1, 2, lookbacks, nodeIndices, nodeValues);
        m = new TestFeatureManager(2);
        lookbacks = new int[][]{{0, 1, 2, 3, -1}, {0, 1, 2, 3, -1}};
        nodeIndices = new int[][]{{0, 0, 0, -1}, {0, 1, 0, -1}, {0, 2, 0, -1}, {0, 3, 0, -1}, {1, 0, 0, -1}};
        this.doLengthTest(m, IntInput.prepareData("0000\n0000"), 0, 3, 2, lookbacks, nodeIndices, nodeValues);
    }

    public void testLengthCache() throws Exception {
        int[][] lookbacks = new int[][]{{0, 1, 2, 3, -1}};
        int[][] nodeIndices = new int[0][4];
        float[][] nodeValues = new float[][]{{0.0f, 0.0f, 0.0f, -0.0f}, {0.0f, 1.0f, 0.0f, -0.11157f}, {0.0f, 2.0f, 0.0f, -0.22314f}, {0.0f, 3.0f, 0.0f, -0.33471f}};
        Conrad c = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
        ModelManager m = c.getModel();
        List<? extends TrainingSequence<Character>> data = StringInput.prepareData("00110\nATGCA");
        c.trainFeatures(data);
        CacheProcessor cp = ((MaximumLikelihoodSemiMarkovGradient)((StandardOptimizer)c.getOptimizer()).getObjectiveFunction()).getCacheProcessor();
        cp.setTrainingData(m, data);
        cp.evaluateSegmentsEndingAt(0, 3);
        CacheProcessor.LengthFeatureEvaluation[][] lenEvals = cp.getLengthFeatureEvaluations();
        this.checkLengthEvals(lenEvals, 1, lookbacks, nodeIndices, nodeValues);
    }

    void doLengthTest(ModelManager m, List<? extends TrainingSequence<?>> data, int seq, int pos, int nStates, int[][] lookback, int[][] nodeIndices, float[][] nodeValues) {
        AllSparseLengthCacheProcessor cp = new AllSparseLengthCacheProcessor();
        SemiMarkovSetup setup = new SemiMarkovSetup(new short[]{4, 4});
        setup.setIgnoreSemiMarkovSelfTransitions(true);
        cp.setSemiMarkovSetup(setup);
        cp.setTrainingData(m, data);
        cp.evaluateSegmentsEndingAt(seq, pos);
        CacheProcessor.LengthFeatureEvaluation[][] lenEvals = cp.getLengthFeatureEvaluations();
        this.checkLengthEvals(lenEvals, nStates, lookback, nodeIndices, nodeValues);
    }

    /*
     * WARNING - void declaration
     */
    void checkLengthEvals(CacheProcessor.LengthFeatureEvaluation[][] lenEvals, int nStates, int[][] lookback, int[][] nodeIndices, float[][] nodeValues) {
        void var6_7;
        CacheProcessorTest.assertEquals((int)nStates, (int)lenEvals.length);
        boolean bl = false;
        while (var6_7 < lookback.length) {
            for (int j = 0; j < lookback[var6_7].length; ++j) {
                CacheProcessorTest.assertEquals((int)lookback[var6_7][j], (int)lenEvals[var6_7][j].lookback);
            }
            ++var6_7;
        }
        for (Object[] entry : nodeIndices) {
            CacheProcessorTest.assertEquals((int)entry[3], (int)lenEvals[entry[0]][entry[1]].nodeEval.index[entry[2]]);
        }
        float[][] fArray = nodeValues;
        int len$ = fArray.length;
        for (int i$ = 0; i$ < len$; ++i$) {
            Object[] entry;
            entry = fArray[i$];
            CacheProcessorTest.assertEquals((double)entry[3], (double)lenEvals[(int)entry[0]][(int)entry[1]].nodeEval.value[(int)entry[2]], (double)1.0E-4);
        }
    }

    void doTest(int mmNum, List<? extends TrainingSequence<?>> data, int seq, int pos, int[][] indices, float[][] vals, double[] featureSums) {
        TestFeatureManager m = new TestFeatureManager(mmNum);
        AllSparseLengthCacheProcessor cp = new AllSparseLengthCacheProcessor();
        this.testOneCacheProcessor(cp, m, data, seq, pos, indices, vals, featureSums);
        NoCachingCacheProcessor ncp = new NoCachingCacheProcessor();
        this.testOneCacheProcessor(ncp, m, data, seq, pos, indices, vals, featureSums);
        CacheProcessorDeluxe dcp = new CacheProcessorDeluxe();
        this.testOneCacheProcessor(dcp, m, data, seq, pos, indices, vals, featureSums);
        CacheProcessorDeluxe dcp2 = new CacheProcessorDeluxe(CacheStrategySpec.CacheStrategy.CONSTANT);
        this.testOneCacheProcessor(dcp2, m, data, seq, pos, indices, vals, featureSums);
        CacheProcessorDeluxe dcp3 = new CacheProcessorDeluxe(CacheStrategySpec.CacheStrategy.DENSE);
        this.testOneCacheProcessor(dcp3, m, data, seq, pos, indices, vals, featureSums);
        CacheProcessorDeluxe dcp4 = new CacheProcessorDeluxe(CacheStrategySpec.CacheStrategy.SPARSE);
        this.testOneCacheProcessor(dcp4, m, data, seq, pos, indices, vals, featureSums);
    }

    void testOneCacheProcessor(CacheProcessor dcp, ModelManager m, List<? extends TrainingSequence<?>> data, int seq, int pos, int[][] indices, float[][] vals, double[] featureSums) {
        dcp.setTrainingData(m, data);
        dcp.evaluatePosition(seq, pos);
        if (pos > 0) {
            this.assertEvalEquals(dcp.getFeatureEvaluations(), indices, vals);
        } else {
            this.assertNonedgeEvalEquals(m, dcp.getFeatureEvaluations(), indices, vals);
        }
        this.assertArrayEquals(featureSums, dcp.getFeatureSums(), 1.0E-5);
    }

    private void assertNonedgeEvalEquals(ModelManager m, CacheProcessor.FeatureEvaluation[] evals, int[][] indices, float[][] vals) {
        int j;
        int i;
        log.warn((Object)evals);
        Assert.a(indices.length >= m.getNumStates());
        for (i = 0; i < m.getNumStates(); ++i) {
            for (j = 0; j < indices[i].length; ++j) {
                CacheProcessorTest.assertEquals((int)indices[i][j], (int)evals[i].index[j]);
            }
        }
        for (i = 0; i < vals.length; ++i) {
            for (j = 0; j < vals[i].length; ++j) {
                CacheProcessorTest.assertEquals((double)vals[i][j], (double)evals[i].value[j], (double)1.0E-5);
            }
        }
    }

    private void assertEvalEquals(CacheProcessor.FeatureEvaluation[] evals, int[][] indices, float[][] vals) {
        int j;
        int i;
        log.warn((Object)evals);
        for (i = 0; i < indices.length; ++i) {
            for (j = 0; j < indices[i].length; ++j) {
                CacheProcessorTest.assertEquals((int)indices[i][j], (int)evals[i].index[j]);
            }
        }
        for (i = 0; i < vals.length; ++i) {
            for (j = 0; j < vals[i].length; ++j) {
                CacheProcessorTest.assertEquals((double)vals[i][j], (double)evals[i].value[j], (double)1.0E-5);
            }
        }
    }
}

