package fmph.features.supporting.phylogenetic;

import calhoun.analysis.crf.features.supporting.phylogenetic.PhylogeneticTreeFelsensteinOrder;
import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;

import calhoun.seq.KmerHasher;

import calhoun.util.Assert;

import cern.colt.matrix.DoubleMatrix2D;

import fmph.seq.ExtendedKmerHasher;

import java.io.Serializable;

import java.util.Arrays;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class CodonEvolutionaryModel implements Serializable{
    private static final Log log = LogFactory.getLog(CodonEvolutionaryModel.class);

    // Fundamental information:
    PhylogeneticTreeFelsensteinOrder             T;
    double[]                     pi;
    CodonSubstitutionModel  R;
    int numSpecies;
    double edgeCoef = 1;

    
    // Derived information, or precomputed/reserved for efficiency:
    int[] ileft,iright;  // indices of left and right child nodes
    double[][] Tleft;  //Transition matrices for the branches going left
    double[][] Tright;  // transition matrices for branches going right
    double[][] P;   // space in which Felsenstein algorithm recursions will be performed.

    static ExtendedKmerHasher hforward = new ExtendedKmerHasher(ExtendedKmerHasher.TCAGother,1);
    static ExtendedKmerHasher hbackward = new ExtendedKmerHasher(ExtendedKmerHasher.TCAGotherRC,1); 
    
    public CodonEvolutionaryModel(PhylogeneticTreeFelsensteinOrder T,
                    double[] pi, CodonSubstitutionModel  R ) {
            this.T = T;
            this.pi = pi;
            this.R = R;
            numSpecies = T.numSpecies();
            setup();
    }
    
    public CodonEvolutionaryModel(PhylogeneticTreeFelsensteinOrder T,
                    double[] pi, CodonSubstitutionModel  R , double edgeCoef) {
            this.T = T;
            this.pi = pi;
            this.R = R;
            numSpecies = T.numSpecies();
            this.edgeCoef = edgeCoef;
            setup();
    }
    
    
    private void setup() {
        
            ileft = T.getileft();
            iright = T.getiright();
            double[] bleft = T.getbleft();
            double[] bright = T.getbright();
            
            Tleft  = new double[T.numSteps()][];
            Tright = new double[T.numSteps()][];
            for (int j=0; j<T.numSteps(); j++) {
                    Tleft[j] = createArrayFromTransitionMatrix(R.transitionMatrix(bleft[j]*edgeCoef));
                    Tright[j] = createArrayFromTransitionMatrix(R.transitionMatrix(bright[j]*edgeCoef));
            }
            Assert.a(ileft.length == T.numSteps());
            Assert.a(iright.length == T.numSteps());
            
            P = new double[T.numNodes()][64];
    }
    
    
    public double logprobRC(MultipleAlignmentColumn col0, MultipleAlignmentColumn col1 , MultipleAlignmentColumn col2, boolean conditionref) {
            return logprob(col0,col1,col2,conditionref,hbackward, true);
    }

    public double logprob(MultipleAlignmentColumn col0, MultipleAlignmentColumn col1 , MultipleAlignmentColumn col2, boolean conditionref) {
            return logprob(col0,col1,col2,conditionref,hforward, false);
    }
    
    private double logprob(MultipleAlignmentColumn C0, MultipleAlignmentColumn C1 , MultipleAlignmentColumn C2,boolean conditionref, KmerHasher h, boolean rc) {
        // condition ref - calculate unconditional probability  P(C) = P(C|T)/P(T)      
        if (C0 != null && C0.numSpecies() != numSpecies ) {
                Assert.a(false,"C.numspecies is " + C0.numSpecies() + "  and numSpecies is " + numSpecies);
        }
            
        if (C1 != null && C1.numSpecies() != numSpecies ) {
                Assert.a(false,"C.numspecies is " + C1.numSpecies() + "  and numSpecies is " + numSpecies);
        }
        
        if (C2 != null && C2.numSpecies() != numSpecies ) {
                Assert.a(false,"C.numspecies is " + C2.numSpecies() + "  and numSpecies is " + numSpecies);
        }
        for (int i=0; i<numSpecies; i++) {
            
            int n0 = rc?h.hash(C2==null?'-':C2.nucleotide(i)):h.hash(C0==null?'-':C0.nucleotide(i));
            int n1 = h.hash(C1==null?'-':C1.nucleotide(i));
            int n2 = rc?h.hash(C0==null?'-':C0.nucleotide(i)):h.hash(C2==null?'-':C2.nucleotide(i));
            
            
            for (int j=0; j<64; j++) { 
                P[i][j] = initialprob(n0,n1,n2,j);
            }
        }
        
        
        
        
        for (int step=0; step<T.numSteps(); step++) {
                int node = step + numSpecies;                   
                felsenstein(P[ileft[step]],Tleft[step],P[iright[step]],Tright[step],P[node]);
        }
        double prob = 0;
        for (int i=0; i<64; i++) {
                prob += pi[i] * P[T.numNodes()-1][i];
        }
        
        if (conditionref) {
                for (int i=1; i<T.numSpecies(); i++) {
                        Arrays.fill(P[i], 1.0);
                }
                for (int step=0; step<T.numSteps(); step++) {
                        int node = step + numSpecies;                   
                        felsenstein(P[ileft[step]],Tleft[step],P[iright[step]],Tright[step],P[node]);
                }
                double denom = 0;
                for (int i=0; i<64; i++) {
                        denom += pi[i] * P[T.numNodes()-1][i];
                }
                if ( !(prob/denom < 1.00000001) ) {
                        Assert.a(false , "prob=" + prob + "  denom="+denom);
                }
                
                prob = prob/denom;
        }
        
        if (!(prob > 0)) {
                Assert.a(false,"prob="+prob);
        }
        if ( !(prob < 1.00000001) ) {
                Assert.a(false , "prob=" + prob );
        }
        
        double result = Math.log(prob);  
        return result;
    }

    private static double[] createArrayFromTransitionMatrix(DoubleMatrix2D R) {
            double[] ret = new double[64*64];
            for(int i = 0; i<64; ++i) {
                    for(int j = 0; j<64; ++j) {
                            ret[i*64+j] = R.getQuick(i,j);
                    }                       
            }
            return ret;
    }
    
    private static void felsenstein(double[] lp, double[] lT,double[] rp, double[] rT,double[] pp) {
            for (int i=0; i<64; i++) { 
                    double leftprob=0.0,  rightprob=0.0;
                    for (int j=0; j<64; j++) {
                            leftprob += lT[i*64 + j]*lp[j];
                            rightprob += rT[i*64 + j]*rp[j];
                    }
                    
                    pp[i] = leftprob*rightprob;
            }
            return;
    }
    
    private static double initialprob(int n0,int n1, int n2, int codon){
        int c0 = codon / 16;
        int c1 = (codon % 16) / 4;
        int c2 = (codon % 4);
        if((n0>=4 || n0 == c0) && (n1>=4 || n1 == c1) && (n2>=4 || n2 == c2)) return 1;
        else return 0;
    }

    public void summarize() {
            log.debug("Codon Evolutionary model");
            R.summarize();
    }
}
