package fmph.features.mitochondrion1;

import calhoun.analysis.crf.AbstractFeatureManager;
import calhoun.analysis.crf.CacheStrategySpec;
import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
import calhoun.analysis.crf.FeatureList;
import calhoun.analysis.crf.FeatureManagerNode;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.features.supporting.phylogenetic.EvolutionaryModel;
import calhoun.analysis.crf.features.supporting.phylogenetic.HKY85Model;
import calhoun.analysis.crf.features.supporting.phylogenetic.Kimura80Model;
import calhoun.analysis.crf.features.supporting.phylogenetic.PhylogeneticTreeFelsensteinOrder;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.MultipleAlignmentInputSequence;
import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
import calhoun.analysis.crf.io.TrainingSequence;

import calhoun.seq.KmerHasher;

import calhoun.util.Assert;

import flanagan.math.Minimisation;
import flanagan.math.MinimisationFunction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Node feature manager whose feature value is the log-probability of a multiple-alignment column
 * under an evolutionary model selected by the hidden state.
 *
 * <p>State handling, as implemented in {@link #evaluateNode}:
 * <ul>
 *   <li>state 0 uses the intergenic model;</li>
 *   <li>states 1-3 use a codon-position specific exonic model via {@code logprob};</li>
 *   <li>states 4-6 use the intronic model via {@code logprob};</li>
 *   <li>states 7-9 use a codon-position specific exonic model via {@code logprobRC};</li>
 *   <li>states 10-12 use the intronic model via {@code logprobRC}.</li>
 * </ul>
 *
 * <p>Each of the five models (intergenic, intronic, exonic position 0/1/2) is either trained from the
 * labelled data in {@link #train} or, when the corresponding parameter array has been set, built
 * directly from those parameters.
 */
public class PhylogeneticLogprobMitochondrion1 extends AbstractFeatureManager<MultipleAlignmentColumn> implements FeatureManagerNode<MultipleAlignmentColumn> {
    private static final long serialVersionUID = -7659288739348604129L;
    private static final Log log = LogFactory.getLog(PhylogeneticLogprobMitochondrion1.class);


    int startIx; // The index of the first feature managed by this FeatureManager
    ModelManager model;
    boolean multipleFeatures = false; // when true, five separate features are reported instead of one
    boolean k80Model = false; // when true, trained models use Kimura80Model; otherwise HKY85Model
    int maxIter = 100; // value handed to Minimisation.setNmax() when training

    // Optional pre-computed parameters. When an array is non-null the matching model is not trained:
    // the whole array is passed to HKY85Model and elements [2], [3], [4] are used as the first three
    // base frequencies (the fourth is 1 minus their sum). Each array therefore needs at least 5 entries.
    double[] emodelIntergenicParams;
    double[] emodelIntronicParams;
    double[] emodelExonic0Params;
    double[] emodelExonic1Params;
    double[] emodelExonic2Params;

    EvolutionaryModel emodelIntergenic; // one model for a column of aligned sequence in intergenic region
    EvolutionaryModel emodelIntronic; // one model for intronic regions
    ArrayList<EvolutionaryModel> emodelExonic; // one model per codon position 0,1,2 in a coding exon

    static KmerHasher hforward = new KmerHasher(KmerHasher.ACGTother, 1); // a character hasher for forward strand
    static KmerHasher hbackward = new KmerHasher(KmerHasher.ACGTotherRC, 1); // a character hasher for reverse strand

    ///////////////////////////////////////////////////////////////////////////////

    /** Creates an untrained feature manager; configure it through the setters and then call {@link #train}. */
    public PhylogeneticLogprobMitochondrion1() {
    }

    /**
     * @return 5 when {@code multipleFeatures} is set, otherwise 1.
     */
    public int getNumFeatures() {
        return multipleFeatures ? 5 : 1;
    }

    /**
     * Returns a display name for one of the features managed here.
     *
     * @param featureIndex absolute feature index; {@code startIx} is subtracted to obtain the local index.
     * @return the feature name.
     */
    public String getFeatureName(int featureIndex) {
        if (multipleFeatures) {
            // NOTE(review): these labels do not line up with the offsets written by evaluateNode, which
            // uses 0 = intergenic, 1 = intron (both strands) and 2..4 = exon codon position 0..2.
            // The strings are left untouched because it is not visible here whether anything depends on
            // them; confirm and relabel if they are display-only.
            String[] vals = new String[] { "Intergenic", "Exon pos.", "Intron pos.", "Exon neg.", "Intron neg." };
            int feat = featureIndex - startIx;
            String table = vals[feat];
            return table + " phylogeny";
        } else {
            return "PhylogeneticLogProbInterval13";
        }
    }


    /**
     * Adds to {@code result} the log-probability of the alignment column at {@code pos} under the model
     * that corresponds to {@code state}.
     *
     * <p>When {@code multipleFeatures} is set the value is stored at {@code startIx + offset}, with offset
     * 0 for intergenic, 1 for intronic and {@code 2 + codonPosition} for exonic states; otherwise it is
     * always stored at {@code startIx}.
     *
     * @param seq    input sequence of alignment columns.
     * @param pos    position within {@code seq}.
     * @param state  hidden state, expected to be in 0..12 and below {@code model.getNumStates()}.
     * @param result receiver of the feature value.
     */
    public void evaluateNode(InputSequence<? extends MultipleAlignmentColumn> seq, int pos, int state, FeatureList result) {

        Assert.a(state < model.getNumStates());
        MultipleAlignmentColumn col = seq.getX(pos);

        double val = 0.0;
        int ephase;
        int featureOffset = Integer.MIN_VALUE;
        switch (state) {
        case 0:
            val = emodelIntergenic.logprob(col, true);
            featureOffset = 0;
            break;
        case 1:
        case 2:
        case 3:
            // Codon position on the forward strand; the double modulo keeps the result in 0..2.
            ephase = ((pos - state + 1) % 3 + 3) % 3;
            val = emodelExonic.get(ephase).logprob(col, true);
            featureOffset = 2 + ephase;
            break;
        case 4:
        case 5:
        case 6:
            val = emodelIntronic.logprob(col, true);
            featureOffset = 1;
            break;
        case 7:
        case 8:
        case 9:
            // Codon position for the states scored with logprobRC; mirrors the forward formula.
            ephase = ((-pos + state + 1) % 3 + 3) % 3;
            val = emodelExonic.get(ephase).logprobRC(col, true);
            featureOffset = 2 + ephase;
            break;
        case 10:
        case 11:
        case 12:
            val = emodelIntronic.logprobRC(col, true);
            featureOffset = 1;
            break;
        default:
            Assert.a(false);
        }

        result.addFeature(startIx + (multipleFeatures ? featureOffset : 0), val);
    }


    /**
     * Builds the intergenic, intronic and three exonic evolutionary models.
     *
     * <p>For each model: if the matching parameter array is {@code null} the model is fitted to the
     * training columns whose label selects that model; otherwise it is constructed directly from the
     * supplied parameters (always as an HKY85 model, regardless of {@code k80Model}).
     *
     * @param startingIndex index of the first feature owned by this manager.
     * @param modelInfo     model manager, used later for the state-count check in {@link #evaluateNode}.
     * @param data          labelled training sequences; the tree ordering is read from the first column
     *                      of the first sequence, so the list must not be empty.
     */
    public void train(int startingIndex, ModelManager modelInfo, final List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data) {
        startIx = startingIndex;
        model = modelInfo;

        final PhylogeneticTreeFelsensteinOrder felsOrder = data.get(0).getX(0).getMultipleAlignment().getFelsensteinOrder();

        // One pair of per-position masks per training sequence; they are rewritten before each fit.
        ArrayList<boolean[]> flagsForward = new ArrayList<boolean[]>();
        ArrayList<boolean[]> flagsBackward = new ArrayList<boolean[]>();
        for (int seqNum = 0; seqNum < data.size(); seqNum++) {
            int len = data.get(seqNum).length();
            flagsForward.add(new boolean[len]);
            flagsBackward.add(new boolean[len]);
        }


        log.debug("Training model for intergenic regions...");
        if (emodelIntergenicParams == null) {
            flagIntergenic(data, flagsForward, flagsBackward);
            emodelIntergenic = trainEvolutionaryModel(felsOrder, data, flagsForward, flagsBackward, k80Model);
        } else {
            emodelIntergenic = modelFromParams(felsOrder, emodelIntergenicParams);
        }

        log.debug("Evolutionary model for intergenic regions:");
        emodelIntergenic.summarize();


        log.debug("Training model for intronic regions...");
        // Fixed: this branch used to test emodelIntergenicParams, which either dereferenced a null
        // emodelIntronicParams or ignored intronic parameters that had been supplied.
        if (emodelIntronicParams == null) {
            flagIntronic(data, flagsForward, flagsBackward);
            emodelIntronic = trainEvolutionaryModel(felsOrder, data, flagsForward, flagsBackward, k80Model);
        } else {
            emodelIntronic = modelFromParams(felsOrder, emodelIntronicParams);
        }
        log.debug("Evolutionary model for intronic regions:");
        emodelIntronic.summarize();

        // evaluateNode uses:
        //    ephase = ((pos-state+1)%3+3)%3;  for states 1,2,3
        //    ephase = ((-pos+state+1)%3+3)%3; for states 7,8,9
        // flagExonic() inverts these formulas so each phase model sees exactly the columns it will score.

        emodelExonic = new ArrayList<EvolutionaryModel>();
        final double[][] exonicParams = { emodelExonic0Params, emodelExonic1Params, emodelExonic2Params };

        for (int phase = 0; phase < 3; phase++) {
            double[] supplied = exonicParams[phase];
            if (supplied == null) {
                log.debug("Training model for exonic regions...");
                flagExonic(data, flagsForward, flagsBackward, phase);
                emodelExonic.add(trainEvolutionaryModel(felsOrder, data, flagsForward, flagsBackward, k80Model));
            } else {
                emodelExonic.add(modelFromParams(felsOrder, supplied));
            }
            log.debug("Evolutionary model for exonic regions:");
            emodelExonic.get(phase).summarize();
        }

        log.debug("Just trained all evolutionary models");
    }

    /** Marks, in both masks, every position labelled with state 0. */
    private static void flagIntergenic(List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data,
                                       List<boolean[]> flagsForward, List<boolean[]> flagsBackward) {
        for (int seqNum = 0; seqNum < data.size(); seqNum++) {
            TrainingSequence<? extends MultipleAlignmentColumn> aln = data.get(seqNum);
            int len = aln.length();

            boolean[] ff = flagsForward.get(seqNum);
            Assert.a(ff.length == len);

            boolean[] fb = flagsBackward.get(seqNum);
            Assert.a(fb.length == len);

            for (int pos = 0; pos < len; pos++) {
                boolean intergenic = aln.getY(pos) == 0;
                ff[pos] = intergenic;
                fb[pos] = intergenic;
            }
        }
    }

    /** Marks states 4-6 in the forward mask and states 10-12 in the backward mask. */
    private static void flagIntronic(List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data,
                                     List<boolean[]> flagsForward, List<boolean[]> flagsBackward) {
        for (int seqNum = 0; seqNum < data.size(); seqNum++) {
            TrainingSequence<? extends MultipleAlignmentColumn> aln = data.get(seqNum);
            int len = aln.length();

            boolean[] ff = flagsForward.get(seqNum);
            Assert.a(ff.length == len);

            boolean[] fb = flagsBackward.get(seqNum);
            Assert.a(fb.length == len);

            for (int pos = 0; pos < len; pos++) {
                int y = aln.getY(pos);
                ff[pos] = (y == 4) || (y == 5) || (y == 6);
                fb[pos] = (y == 10) || (y == 11) || (y == 12);
            }
        }
    }

    /**
     * Marks the exonic positions that fall on codon position {@code phase}: in the forward mask those
     * labelled with the matching state in 1-3, in the backward mask those labelled with the matching
     * state in 7-9.
     */
    private static void flagExonic(List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data,
                                   List<boolean[]> flagsForward, List<boolean[]> flagsBackward, int phase) {
        for (int seqNum = 0; seqNum < data.size(); seqNum++) {
            TrainingSequence<? extends MultipleAlignmentColumn> aln = data.get(seqNum);
            int len = aln.length();

            boolean[] ff = flagsForward.get(seqNum);
            Assert.a(ff.length == len);

            boolean[] fb = flagsBackward.get(seqNum);
            Assert.a(fb.length == len);

            for (int pos = 0; pos < len; pos++) {
                int y = aln.getY(pos);
                // The state that, at this position, would yield `phase` in evaluateNode.
                int pstate = ((pos - phase) % 3 + 3) % 3 + 1;
                int mstate = ((phase + pos - 2) % 3 + 3) % 3 + 7;
                ff[pos] = (y == pstate);
                fb[pos] = (y == mstate);
            }
        }
    }

    /**
     * Builds an HKY85-based model from a caller-supplied parameter array.
     *
     * @param params passed unchanged to {@code HKY85Model}; elements 2-4 double as the first three
     *               base frequencies and the fourth is computed as one minus their sum.
     */
    private static EvolutionaryModel modelFromParams(PhylogeneticTreeFelsensteinOrder felsOrder, double[] params) {
        double[] pi = new double[] { params[2], params[3], params[4], 1 - params[2] - params[3] - params[4] };
        return new EvolutionaryModel(felsOrder, pi, new HKY85Model(params));
    }

    /**
     * Estimates base frequencies from nucleotide 0 of every flagged column, hashing with
     * {@code hforward} for forward-flagged positions and {@code hbackward} for backward-flagged ones.
     * Every base starts with a count of one; hash values of 4 or more are skipped.
     *
     * @return a four-element array that sums to one.
     */
    private static double[] estimateBaseFrequencies(List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data,
                                                    List<boolean[]> flagsForward, List<boolean[]> flagsBackward) {
        final double[] pi = new double[] { 1.0, 1.0, 1.0, 1.0 };
        for (int seqNum = 0; seqNum < data.size(); seqNum++) {
            TrainingSequence<? extends MultipleAlignmentColumn> aln = data.get(seqNum);
            int len = aln.length();

            boolean[] ff = flagsForward.get(seqNum);
            Assert.a(ff.length == len);

            boolean[] fb = flagsBackward.get(seqNum);
            Assert.a(fb.length == len);

            for (int ix = 0; ix < len; ix++) {
                if (ff[ix]) {
                    int x = hforward.hash(aln.getX(ix).nucleotide(0));
                    if (x < 4) {
                        pi[x] += 1.0;
                    }
                }

                if (fb[ix]) {
                    int x = hbackward.hash(aln.getX(ix).nucleotide(0));
                    if (x < 4) {
                        pi[x] += 1.0;
                    }
                }
            }
        }
        double total = pi[0] + pi[1] + pi[2] + pi[3];
        pi[0] /= total;
        pi[1] /= total;
        pi[2] /= total;
        pi[3] /= total;
        return pi;
    }

    /**
     * Builds a model from two log-space rate parameters.
     *
     * @param logRates two values that are exponentiated before use, which keeps the rates positive.
     * @param useK80   {@code true} for {@code Kimura80Model}; {@code false} for {@code HKY85Model}, whose
     *                 parameter array additionally carries {@code pi[0..2]}.
     */
    private static EvolutionaryModel buildModel(PhylogeneticTreeFelsensteinOrder felsOrder, double[] pi,
                                                double[] logRates, boolean useK80) {
        double rate0 = Math.exp(logRates[0]);
        double rate1 = Math.exp(logRates[1]);
        if (useK80) {
            return new EvolutionaryModel(felsOrder, pi, new Kimura80Model(new double[] { rate0, rate1 }));
        }
        return new EvolutionaryModel(felsOrder, pi, new HKY85Model(new double[] { rate0, rate1, pi[0], pi[1], pi[2] }));
    }

    /**
     * Sums the log-probability of every flagged column: {@code logprob} for forward-flagged positions
     * and {@code logprobRC} for backward-flagged ones.
     */
    private static double totalLogProb(EvolutionaryModel candidate,
                                       List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data,
                                       List<boolean[]> flagsForward, List<boolean[]> flagsBackward) {
        double ret = 0;
        for (int seqNum = 0; seqNum < data.size(); seqNum++) {
            TrainingSequence<? extends MultipleAlignmentColumn> aln = data.get(seqNum);
            int len = aln.length();

            boolean[] ff = flagsForward.get(seqNum);
            Assert.a(ff.length == len);

            boolean[] fb = flagsBackward.get(seqNum);
            Assert.a(fb.length == len);

            for (int ix = 0; ix < len; ix++) {
                if (ff[ix]) {
                    ret += candidate.logprob(aln.getX(ix), true);
                }
                if (fb[ix]) {
                    ret += candidate.logprobRC(aln.getX(ix), true);
                }
            }
        }
        return ret;
    }

    /**
     * Fits an evolutionary model to the flagged columns by minimising the negative total
     * log-probability over two log-space rate parameters with Nelder-Mead.
     *
     * <p>Base frequencies are estimated once up front and held fixed during the search. A warning is
     * logged if the optimiser reports that it did not converge; the last parameter values are used
     * either way.
     *
     * @param felsOrder     tree ordering handed to every candidate model.
     * @param data          training sequences.
     * @param flagsForward  per-sequence masks of positions scored with {@code logprob}.
     * @param flagsBackward per-sequence masks of positions scored with {@code logprobRC}.
     * @param useK80        {@code true} to fit a Kimura80 model, {@code false} for HKY85.
     * @return the fitted model.
     */
    private EvolutionaryModel trainEvolutionaryModel(final PhylogeneticTreeFelsensteinOrder felsOrder,
                                                     final List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data,
                                                     final ArrayList<boolean[]> flagsForward, final ArrayList<boolean[]> flagsBackward,
                                                     final boolean useK80) {

        Assert.a(flagsForward.size() == data.size());
        Assert.a(flagsBackward.size() == data.size());

        // Estimate pi based on the nucleotide frequencies in the reference sequence
        final double[] pi = estimateBaseFrequencies(data, flagsForward, flagsBackward);

        MinimisationFunction mFunc = new MinimisationFunction() {
            public double function(double[] d) {
                EvolutionaryModel candidate = buildModel(felsOrder, pi, d, useK80);
                return -totalLogProb(candidate, data, flagsForward, flagsBackward);
            }
        };

        // The standard mantra for minimizing the function mFunc defined above
        final int nParm = 2;
        Minimisation m = new Minimisation();
        m.setNmax(maxIter);
        double[] starts = new double[nParm];
        Arrays.fill(starts, 0.1);
        double[] steps = new double[nParm];
        Arrays.fill(steps, 0.1);
        m.nelderMead(mFunc, starts, steps);
        if (!m.getConvStatus()) {
            log.warn("WARNING - Nelder-Mead routine says convergence was not reached");
        }

        return buildModel(felsOrder, pi, m.getParamValues(), useK80);
    }

    /**
     * @return a dense cache strategy specification.
     */
    @Override
    public CacheStrategySpec getCacheStrategy() {
        return new CacheStrategySpec(CacheStrategy.DENSE);
    }

    /**
     * @return Returns the multipleFeatures.
     */
    public boolean isMultipleFeatures() {
        return multipleFeatures;
    }

    /**
     * @param multipleFeatures The multipleFeatures to set.
     */
    public void setMultipleFeatures(boolean multipleFeatures) {
        this.multipleFeatures = multipleFeatures;
    }

    /**
     * @param emodelIntergenicParams parameters for the intergenic model, or {@code null} to train it.
     */
    public void setEmodelIntergenicParams(double[] emodelIntergenicParams) {
        this.emodelIntergenicParams = emodelIntergenicParams;
    }

    /**
     * @return the supplied intergenic parameters, or {@code null} if none were set.
     */
    public double[] getEmodelIntergenicParams() {
        return emodelIntergenicParams;
    }

    /**
     * @param emodelIntronicParams parameters for the intronic model, or {@code null} to train it.
     */
    public void setEmodelIntronicParams(double[] emodelIntronicParams) {
        this.emodelIntronicParams = emodelIntronicParams;
    }

    /**
     * @return the supplied intronic parameters, or {@code null} if none were set.
     */
    public double[] getEmodelIntronicParams() {
        return emodelIntronicParams;
    }

    /**
     * @param emodelExonic0Params parameters for the codon-position-0 exonic model, or {@code null} to train it.
     */
    public void setEmodelExonic0Params(double[] emodelExonic0Params) {
        this.emodelExonic0Params = emodelExonic0Params;
    }

    /**
     * @return the supplied codon-position-0 exonic parameters, or {@code null} if none were set.
     */
    public double[] getEmodelExonic0Params() {
        return emodelExonic0Params;
    }

    /**
     * @param emodelExonic1Params parameters for the codon-position-1 exonic model, or {@code null} to train it.
     */
    public void setEmodelExonic1Params(double[] emodelExonic1Params) {
        this.emodelExonic1Params = emodelExonic1Params;
    }

    /**
     * @return the supplied codon-position-1 exonic parameters, or {@code null} if none were set.
     */
    public double[] getEmodelExonic1Params() {
        return emodelExonic1Params;
    }

    /**
     * @param emodelExonic2Params parameters for the codon-position-2 exonic model, or {@code null} to train it.
     */
    public void setEmodelExonic2Params(double[] emodelExonic2Params) {
        this.emodelExonic2Params = emodelExonic2Params;
    }

    /**
     * @return the supplied codon-position-2 exonic parameters, or {@code null} if none were set.
     */
    public double[] getEmodelExonic2Params() {
        return emodelExonic2Params;
    }

    /**
     * @param k80model {@code true} to train Kimura80 models, {@code false} to train HKY85 models.
     */
    public void setK80Model(boolean k80model) {
        this.k80Model = k80model;
    }

    /**
     * @return whether trained models use Kimura80 instead of HKY85.
     */
    public boolean isK80Model() {
        return k80Model;
    }

    /**
     * @param maxIter value passed to {@code Minimisation.setNmax} for each fit.
     */
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /**
     * @return the value passed to {@code Minimisation.setNmax} for each fit.
     */
    public int getMaxIter() {
        return maxIter;
    }
}
