/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.features.tricycle13;

import calhoun.analysis.crf.AbstractFeatureManager;
import calhoun.analysis.crf.FeatureList;
import calhoun.analysis.crf.FeatureManagerEdge;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
import calhoun.analysis.crf.features.supporting.phylogenetic.ColumnConditionalLogProbability;
import calhoun.analysis.crf.features.supporting.phylogenetic.EvolutionaryModel;
import calhoun.analysis.crf.io.CompositeInput;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.seq.KmerHasher;
import calhoun.util.Assert;
import calhoun.util.DenseBooleanMatrix2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class PWM_evolution
extends AbstractFeatureManager<CompositeInput>
implements FeatureManagerEdge<CompositeInput> {
    private static final long serialVersionUID = -7659288739348604129L;
    private static final Log log = LogFactory.getLog(PWM_evolution.class);
    boolean debug = log.isDebugEnabled();
    int startIx;
    ModelManager model;
    int nFeatures;
    int[] span;
    int[] offset;
    int[] nTrans;
    DenseBooleanMatrix2D[] transitions;
    List<int[]> geometry;
    List<float[][]> logprob;
    List<int[]> dcc;
    MarkovPredictorLogprob predictorlp;
    List<int[]> clusters;
    List<EvolutionaryModel> emodels;
    int[] state2cluster;
    static KmerHasher h = new KmerHasher(KmerHasher.ACGTother, 1);
    ColumnConditionalLogProbability mo;
    boolean tieFlag;
    InputSequence<? extends CompositeInput> lastSeq;
    int lastPos;
    float[] vals;
    private int nUpdate = 0;

    public PWM_evolution(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters) {
        this.tieFlag = false;
        this.PWM_evolution_setup(geometry, dccorrection, markovhistory, clusters);
    }

    public PWM_evolution(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters, List<int[]> flags) {
        this.tieFlag = true;
        this.PWM_evolution_setup(geometry, dccorrection, markovhistory, clusters);
    }

    private void PWM_evolution_setup(List<int[]> geometry1, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters1) {
        this.dcc = dccorrection;
        this.predictorlp = new MarkovPredictorLogprob(markovhistory);
        this.mo = new ColumnConditionalLogProbability(clusters1, 0);
        this.geometry = geometry1;
        this.clusters = clusters1;
        this.setupGeometry();
    }

    private void setupGeometry() {
        this.nFeatures = this.geometry.size();
        this.span = new int[this.nFeatures];
        this.offset = new int[this.nFeatures];
        this.nTrans = new int[this.nFeatures];
        h = new KmerHasher(KmerHasher.ACGTN, 1);
        this.logprob = new ArrayList<float[][]>();
        for (int i = 0; i < this.nFeatures; ++i) {
            this.nTrans[i] = (this.geometry.get(i).length - 2) / 2;
            this.span[i] = this.geometry.get(i)[0];
            this.offset[i] = this.geometry.get(i)[1];
            Assert.a(this.offset[i] >= 0);
            float[][] lp = new float[this.span[i]][h.range()];
            this.logprob.add(lp);
        }
        Assert.a(this.geometry.size() == this.nFeatures);
        Assert.a(this.dcc.size() == this.nFeatures);
        for (int j = 0; j < this.nFeatures; ++j) {
            Assert.a(this.nTrans[j] == 1);
            Assert.a(this.dcc.get(j).length == this.span[j]);
        }
    }

    @Override
    public int getNumFeatures() {
        if (this.tieFlag) {
            return 1;
        }
        return this.nFeatures;
    }

    @Override
    public String getFeatureName(int featureIndex) {
        int raw = featureIndex - this.startIx;
        int[] X = this.geometry.get(raw);
        String ret = "PWM.span" + X[0] + ".offset" + X[1];
        for (int j = 2; j < X.length; j += 2) {
            ret = ret + ".(" + this.model.getStateName(X[j]) + "," + this.model.getStateName(X[j + 1]) + ")";
        }
        return ret;
    }

    @Override
    public void evaluateEdge(InputSequence<? extends CompositeInput> seq, int pos, int previousState, int state, FeatureList result) {
        if (pos == 0) {
            return;
        }
        if (seq != this.lastSeq || pos != this.lastPos) {
            this.lastSeq = seq;
            this.lastPos = pos;
            this.updateVals(seq, pos);
        }
        for (int j = 0; j < this.nFeatures; ++j) {
            if (!this.transitions[j].getQuick(previousState, state)) continue;
            if (this.tieFlag) {
                result.addFeature(this.startIx, this.vals[j]);
                continue;
            }
            result.addFeature(this.startIx + j, this.vals[j]);
        }
    }

    void updateVals(InputSequence<? extends CompositeInput> seq, int ix) {
        ++this.nUpdate;
        for (int j = 0; j < this.nFeatures; ++j) {
            int[] geo = this.geometry.get(j);
            int spn = geo[0];
            int offset1 = geo[1];
            float val = 0.0f;
            if (ix >= offset1 && ix - offset1 + spn <= seq.length()) {
                int i;
                InputSequence<?> CIS = seq.getComponent("ref");
                InputSequence<?> MIS = seq.getComponent("aln");
                for (i = 0; i < spn; ++i) {
                    int pos = ix - offset1 + i;
                    char c = ((Character)CIS.getX(pos)).charValue();
                    val += this.logprob.get(j)[i][h.hash(c)];
                }
                Assert.a(this.nTrans[j] == 1);
                Assert.a(this.dcc.get(j).length == spn);
                for (i = 0; i < spn; ++i) {
                    val -= this.predictorlp.logprob(this.dcc.get(j)[i], CIS, ix - offset1 + i);
                    val = (float)((double)val - this.mo.condLogProb(MIS, ix - offset1 + i, this.dcc.get(j)[i]));
                }
            }
            this.vals[j] = val;
        }
    }

    @Override
    public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends CompositeInput>> data) {
        int i;
        this.startIx = startingIndex;
        this.model = modelInfo;
        this.vals = new float[this.nFeatures];
        ArrayList<TrainingSequence> LTSC = new ArrayList<TrainingSequence>();
        ArrayList<TrainingSequence> LTSMA = new ArrayList<TrainingSequence>();
        for (int j = 0; j < data.size(); ++j) {
            LTSC.add(data.get(j).getTrainingComponent("ref"));
            LTSMA.add(data.get(j).getTrainingComponent("aln"));
        }
        this.predictorlp.train(LTSC);
        this.mo.train(this.model, LTSMA);
        for (int i2 = 0; i2 < this.nFeatures; ++i2) {
            float[][] A = new float[this.span[i2]][h.range()];
            this.logprob.add(A);
        }
        int nStates = this.model.getNumStates();
        this.transitions = new DenseBooleanMatrix2D[this.nFeatures];
        for (i = 0; i < this.nFeatures; ++i) {
            this.transitions[i] = new DenseBooleanMatrix2D(nStates, nStates);
            for (int k = 2; k < this.geometry.get(i).length; k += 2) {
                this.transitions[i].setQuick(this.geometry.get(i)[k], this.geometry.get(i)[k + 1], true);
            }
        }
        for (i = 0; i < this.nFeatures; ++i) {
            for (int j = 0; j < this.span[i]; ++j) {
                for (int k = 0; k < h.range(); ++k) {
                    this.logprob.get((int)i)[j][k] = 1.0f;
                }
            }
        }
        for (TrainingSequence seq : LTSC) {
            int len = seq.length();
            for (int i3 = 0; i3 < this.nFeatures; ++i3) {
                for (int ix = 0; ix < len; ++ix) {
                    if (ix < this.offset[i3] || ix - this.offset[i3] + this.span[i3] > seq.length() || ix <= 0) continue;
                    int yprev = seq.getY(ix - 1);
                    int y = seq.getY(ix);
                    for (int j = 0; j < this.nTrans[i3]; ++j) {
                        if (yprev != this.geometry.get(i3)[2 + 2 * j] || y != this.geometry.get(i3)[2 + 2 * j + 1]) continue;
                        for (int pos = 0; pos < this.span[i3]; ++pos) {
                            char c = ((Character)seq.getX(ix - this.offset[i3] + pos)).charValue();
                            float[] fArray = this.logprob.get(i3)[pos];
                            int n = h.hash(c);
                            fArray[n] = (float)((double)fArray[n] + 1.0);
                        }
                    }
                }
            }
        }
        for (int i4 = 0; i4 < this.nFeatures; ++i4) {
            for (int j = 0; j < this.span[i4]; ++j) {
                int k;
                float norm = 0.0f;
                for (k = 0; k < h.range(); ++k) {
                    norm += this.logprob.get(i4)[j][k];
                }
                Assert.a(norm > 0.0f);
                for (k = 0; k < h.range(); ++k) {
                    this.logprob.get((int)i4)[j][k] = (float)(Math.log(this.logprob.get(i4)[j][k]) - Math.log(norm));
                }
            }
        }
    }
}

