/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.features.tricycle13;

import calhoun.analysis.crf.AbstractFeatureManager;
import calhoun.analysis.crf.BeanModel;
import calhoun.analysis.crf.CacheStrategySpec;
import calhoun.analysis.crf.FeatureList;
import calhoun.analysis.crf.FeatureManagerEdge;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
import calhoun.analysis.crf.features.tricycle13.EmissionMarkovFeature;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.seq.KmerHasher;
import calhoun.util.Assert;
import calhoun.util.DenseBooleanMatrix2D;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class PositionWeightMatrixFeatures
extends AbstractFeatureManager<Character>
implements FeatureManagerEdge<Character> {
    private static final long serialVersionUID = -7659288739348604129L;
    private static final Log log = LogFactory.getLog(PositionWeightMatrixFeatures.class);
    boolean debug = log.isDebugEnabled();
    int startIx;
    ModelManager model;
    int nFeatures;
    int[] span;
    int[] offset;
    int[] nTrans;
    KmerHasher h;
    DenseBooleanMatrix2D[] transitions;
    List<int[]> geometry;
    List<float[][]> logprob;
    boolean dcflag;
    List<int[]> dcc;
    MarkovPredictorLogprob predictorlp;
    transient InputSequence<? extends Character> lastSeq;
    int lastPos;
    float[] vals;
    boolean tieFlag = false;
    int UVCount = 0;
    List<Geometry> pwmGeometry;
    EmissionMarkovFeature.MarkovHistory markovHistory;

    public PositionWeightMatrixFeatures() {
    }

    public PositionWeightMatrixFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
        this.setThingsUp(geometry, dccorrection, markovhistory);
    }

    public PositionWeightMatrixFeatures(List<int[]> geometry2, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> flags) {
        this.tieFlag = true;
        this.setThingsUp(geometry2, dccorrection, markovhistory);
    }

    public void init() {
        ArrayList<int[]> geometry1 = new ArrayList<int[]>(this.pwmGeometry.size());
        ArrayList<int[]> dccorrection = new ArrayList<int[]>(this.pwmGeometry.size());
        for (Geometry g : this.pwmGeometry) {
            int[] params = new int[]{g.getSize(), g.getTransition(), g.getPrev().getIndex(), g.getCurrent().getIndex()};
            geometry1.add(params);
            int[] correction = new int[g.overlapCorrections.size()];
            for (int i = 0; i < correction.length; ++i) {
                correction[i] = g.overlapCorrections.get(i).getIndex();
            }
            dccorrection.add(correction);
        }
        this.setThingsUp(geometry1, dccorrection, this.markovHistory.convert());
    }

    private void setThingsUp(List<int[]> geometry2, List<int[]> dccorrection, List<int[]> markovhistory) {
        this.predictorlp = new MarkovPredictorLogprob(markovhistory);
        this.setupGeometry(geometry2);
        Assert.a(this.geometry.size() == this.nFeatures);
        Assert.a(dccorrection.size() == this.nFeatures);
        for (int j = 0; j < this.nFeatures; ++j) {
            Assert.a(this.nTrans[j] == 1);
            Assert.a(dccorrection.get(j).length == this.span[j]);
        }
        this.setupDoubleCountCorrections(dccorrection, this.predictorlp);
    }

    private void setupDoubleCountCorrections(List<int[]> dccorrection, MarkovPredictorLogprob predictorlp) {
        this.dcflag = true;
        this.predictorlp = predictorlp;
        this.dcc = dccorrection;
    }

    private void setupGeometry(List<int[]> geometry) {
        this.geometry = geometry;
        this.nFeatures = geometry.size();
        this.span = new int[this.nFeatures];
        this.offset = new int[this.nFeatures];
        this.nTrans = new int[this.nFeatures];
        this.vals = new float[this.nFeatures];
        this.h = new KmerHasher(KmerHasher.ACGTN, 1);
        this.logprob = new ArrayList<float[][]>();
        for (int i = 0; i < this.nFeatures; ++i) {
            this.nTrans[i] = (geometry.get(i).length - 2) / 2;
            this.span[i] = geometry.get(i)[0];
            this.offset[i] = geometry.get(i)[1];
            Assert.a(this.offset[i] >= 0);
            float[][] lp = new float[this.span[i]][this.h.range()];
            this.logprob.add(lp);
        }
    }

    @Override
    public int getNumFeatures() {
        if (this.tieFlag) {
            return 1;
        }
        return this.nFeatures;
    }

    @Override
    public String getFeatureName(int featureIndex) {
        if (this.tieFlag) {
            return "tiedPwmFeature";
        }
        int raw = featureIndex - this.startIx;
        int[] X = this.geometry.get(raw);
        String ret = "PWM.span" + X[0] + ".offset" + X[1];
        for (int j = 2; j < X.length; j += 2) {
            ret = ret + ".(" + this.model.getStateName(X[j]) + "," + this.model.getStateName(X[j + 1]) + ")";
        }
        return ret;
    }

    @Override
    public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int previousState, int state, FeatureList result) {
        if (pos == 0) {
            return;
        }
        if (seq != this.lastSeq || pos != this.lastPos) {
            this.lastSeq = seq;
            this.lastPos = pos;
            this.updateVals(seq, pos);
        }
        if (this.tieFlag) {
            for (int j = 0; j < this.nFeatures; ++j) {
                if (!this.transitions[j].getQuick(previousState, state)) continue;
                result.addFeature(this.startIx, this.vals[j]);
            }
        } else {
            for (int j = 0; j < this.nFeatures; ++j) {
                if (!this.transitions[j].getQuick(previousState, state)) continue;
                result.addFeature(this.startIx + j, this.vals[j]);
            }
        }
    }

    void updateVals(InputSequence<? extends Character> seq, int ix) {
        ++this.UVCount;
        for (int j = 0; j < this.nFeatures; ++j) {
            int[] geo = this.geometry.get(j);
            int spn = geo[0];
            int offset1 = geo[1];
            float val = 0.0f;
            if (ix >= offset1 && ix - offset1 + spn <= seq.length()) {
                int i;
                for (i = 0; i < spn; ++i) {
                    int pos = ix - offset1 + i;
                    char c = seq.getX(pos).charValue();
                    val += this.logprob.get(j)[i][this.h.hash(c)];
                }
                if (this.dcflag) {
                    Assert.a(this.nTrans[j] == 1);
                    Assert.a(this.dcc.get(j).length == spn);
                    for (i = 0; i < spn; ++i) {
                        val -= this.predictorlp.logprob(this.dcc.get(j)[i], seq, ix - offset1 + i);
                    }
                }
            }
            this.vals[j] = val;
        }
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
        int i;
        this.startIx = startingIndex;
        this.model = modelInfo;
        for (int i2 = 0; i2 < this.nFeatures; ++i2) {
            float[][] A = new float[this.span[i2]][this.h.range()];
            this.logprob.add(A);
        }
        int nStates = this.model.getNumStates();
        this.transitions = new DenseBooleanMatrix2D[this.nFeatures];
        for (i = 0; i < this.nFeatures; ++i) {
            this.transitions[i] = new DenseBooleanMatrix2D(nStates, nStates);
            for (int j = 2; j < this.geometry.get(i).length; j += 2) {
                this.transitions[i].setQuick(this.geometry.get(i)[j], this.geometry.get(i)[j + 1], true);
            }
        }
        for (i = 0; i < this.nFeatures; ++i) {
            for (int j = 0; j < this.span[i]; ++j) {
                for (int k = 0; k < this.h.range(); ++k) {
                    this.logprob.get((int)i)[j][k] = 1.0f;
                }
            }
        }
        for (TrainingSequence<? extends Character> trainingSequence : data) {
            int len = trainingSequence.length();
            for (int i3 = 0; i3 < this.nFeatures; ++i3) {
                for (int ix = 0; ix < len; ++ix) {
                    if (ix < this.offset[i3] || ix - this.offset[i3] + this.span[i3] > trainingSequence.length() || ix <= 0) continue;
                    int yprev = trainingSequence.getY(ix - 1);
                    int y = trainingSequence.getY(ix);
                    for (int j = 0; j < this.nTrans[i3]; ++j) {
                        if (yprev != this.geometry.get(i3)[2 + 2 * j] || y != this.geometry.get(i3)[2 + 2 * j + 1]) continue;
                        for (int pos = 0; pos < this.span[i3]; ++pos) {
                            char c = trainingSequence.getX(ix - this.offset[i3] + pos).charValue();
                            float[] fArray = this.logprob.get(i3)[pos];
                            int n = this.h.hash(c);
                            fArray[n] = (float)((double)fArray[n] + 1.0);
                        }
                    }
                }
            }
        }
        for (int i4 = 0; i4 < this.nFeatures; ++i4) {
            void var6_14;
            boolean bl = false;
            while (var6_14 < this.span[i4]) {
                int k;
                float norm = 0.0f;
                for (k = 0; k < this.h.range(); ++k) {
                    norm += this.logprob.get(i4)[var6_14][k];
                }
                Assert.a(norm > 0.0f);
                for (k = 0; k < this.h.range(); ++k) {
                    this.logprob.get((int)i4)[var6_14][k] = (float)(Math.log(this.logprob.get(i4)[var6_14][k]) - Math.log(norm));
                }
                ++var6_14;
            }
        }
        if (this.dcflag) {
            this.predictorlp.train(data);
        }
    }

    @Override
    public CacheStrategySpec getCacheStrategy() {
        return new CacheStrategySpec(CacheStrategySpec.CacheStrategy.SPARSE);
    }

    public EmissionMarkovFeature.MarkovHistory getMarkovHistory() {
        return this.markovHistory;
    }

    public void setMarkovHistory(EmissionMarkovFeature.MarkovHistory markovHistory) {
        this.markovHistory = markovHistory;
    }

    public List<Geometry> getPwmGeometry() {
        return this.pwmGeometry;
    }

    public void setPwmGeometry(List<Geometry> pwmGeometry) {
        this.pwmGeometry = pwmGeometry;
    }

    public boolean isTieFlag() {
        return this.tieFlag;
    }

    public void setTieFlag(boolean tieFlag) {
        this.tieFlag = tieFlag;
    }

    public static class Geometry
    implements Serializable {
        private static final long serialVersionUID = 4896358213027322167L;
        int size;
        int transition;
        BeanModel.Node prev;
        BeanModel.Node current;
        List<BeanModel.Node> overlapCorrections;

        public BeanModel.Node getCurrent() {
            return this.current;
        }

        public void setCurrent(BeanModel.Node current) {
            this.current = current;
        }

        public List<BeanModel.Node> getOverlapCorrections() {
            return this.overlapCorrections;
        }

        public void setOverlapCorrections(List<BeanModel.Node> overlapCorrections) {
            this.overlapCorrections = overlapCorrections;
        }

        public BeanModel.Node getPrev() {
            return this.prev;
        }

        public void setPrev(BeanModel.Node prev) {
            this.prev = prev;
        }

        public int getSize() {
            return this.size;
        }

        public void setSize(int size) {
            this.size = size;
        }

        public int getTransition() {
            return this.transition;
        }

        public void setTransition(int transition) {
            this.transition = transition;
        }
    }
}

