package mlproject.phylo;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;

/**
 * Evaluates the log-probability of a single alignment column on a fixed
 * binary tree by filling a per-node, per-nucleotide table bottom-up.
 *
 * <p>The tree topology (child indices, branch lengths, leaf indices and
 * evaluation order) is read once from the {@code PhylogeneticTree} passed to
 * the constructor, and the per-edge substitution probabilities are
 * precomputed there. Results of {@link #getColumnLogProbability} are
 * memoized per column string and strand.
 *
 * <p>Nucleotides are indexed as A=0, C=1, G=2, T=3 throughout.
 *
 * <p>Every call to {@link #getColumnLogProbability} overwrites the shared
 * {@code probTable} scratch array and the caches, so concurrent calls on the
 * same instance would interfere with each other.
 */
public class EvolutionaryModel {
    double[] pi; //nucleotide probs, indexed A=0, C=1, G=2, T=3
    int[] felsensteinOrder; //node indices in the order in which they are evaluated
    int[] l; //left child index per node; -1 is treated as "leaf"
    int[] r; //right child index per node
    double[] d; //per-node value passed to the substitution model (presumably the branch length to the parent)
    int[] leaves; //node index of the leaf for each position in a column string
    double[][] substitutionProbsL; //precomputed left child probs, [parent][x*4 + y]
    double[][] substitutionProbsR; //precomputed right child probs, [parent][x*4 + y]
    double[][] probTable; //dynamic programming space, [node][nucleotide]

    /* Separate memo tables per strand: the same column string yields a
     * different result depending on the complement flag. */
    private HashMap<String,Double> computedColumns = new HashMap<String,Double>();
    private HashMap<String,Double> computedComplementColumns = new HashMap<String,Double>();

    /**
     * Captures the tree layout and precomputes, for every non-root node, the
     * 4x4 substitution matrix of the edge to its parent.
     *
     * <p>For a node {@code i} with parent {@code p[i]}, the matrix returned by
     * {@code model.getSubstitutionProbs(d[i])} is flattened row-major into
     * {@code substitutionProbsL[p[i]]} when {@code i} is the parent's left
     * child, and into {@code substitutionProbsR[p[i]]} otherwise.
     *
     * @param tree  source of the evaluation order, child/parent indices,
     *              {@code d} values and leaf indices
     * @param pi    nucleotide probabilities used to weight the root table
     *              entries; stored by reference, not copied
     * @param model supplies the 4x4 substitution probabilities for a given
     *              {@code d} value
     */
    public EvolutionaryModel(PhylogeneticTree tree, double[] pi, NucleotideSubstitutionModel model){
        this.pi = pi;
        felsensteinOrder = tree.getFelsensteinOrder();
        l = tree.getL();
        r = tree.getR();
        d = tree.getD();
        leaves = tree.getLeaves();
        probTable = new double[felsensteinOrder.length][4]; //dynamic programming space
        substitutionProbsL = new double[felsensteinOrder.length][16];
        substitutionProbsR = new double[felsensteinOrder.length][16];
        int[] p = tree.getP();
        /* precompute edge probs */
        for(int i=0;i<p.length;i++){
            int parent = p[i];
            if(parent == -1) continue; // no parent: nothing to precompute for this node
            double[][] probs = model.getSubstitutionProbs(d[i]);
            // Any node that is not its parent's left child is stored as the right child.
            double[] target = (i == l[parent]) ? substitutionProbsL[parent] : substitutionProbsR[parent];
            for(int x=0;x<4;x++)
                for(int y=0;y<4;y++)
                    target[x*4 + y] = probs[x][y];
        }
    }

    /**
     * Returns the natural log of the probability of one alignment column.
     *
     * <p>The result is memoized per {@code (column, complement)} pair. The
     * length check is only reached on a cache miss.
     *
     * @param column     one character per leaf, in the order of the
     *                   {@code leaves} array; {@code 'A'}, {@code 'C'},
     *                   {@code 'G'} and {@code 'T'} are observed bases, any
     *                   other character (including lower case) is treated as
     *                   unknown
     * @param complement if {@code true}, each observed base is replaced by
     *                   its complement (A&lt;-&gt;T, C&lt;-&gt;G) before evaluation
     * @return the log of the weighted root sum divided by 4
     * @throws IllegalArgumentException if {@code column} does not have exactly
     *                                  one character per leaf
     */
    public double getColumnLogProbability(String column, boolean complement) {
        /* Don't calculate in log space, the tree shouldn't be large hopefully :-)*/
        HashMap<String,Double> cache = complement ? computedComplementColumns : computedColumns;
        Double cached = cache.get(column); // first try to find the column in hashtable
        if(cached != null) return cached;

        if (column.length() != leaves.length)
            throw new IllegalArgumentException("Incorrect size of column");

        /* calculate initial probabilities */
        initLeafProbabilities(column, complement);

        /* calculate table */
        // NOTE(review): iteration starts at index 2, presumably because the first
        // two entries of the order are always leaves - confirm against
        // PhylogeneticTree.getFelsensteinOrder().
        for (int i = 2; i < felsensteinOrder.length; i++) {
            int currentNode = felsensteinOrder[i]; //get index of next node to be calculated
            if(l[currentNode] == -1) continue; //leaf
            double[] leftChild = probTable[l[currentNode]];
            double[] rightChild = probTable[r[currentNode]];
            for (int j = 0; j < 4; j++) {
                double probLeft = 0;
                double probRight = 0;
                for (int k = 0; k < 4; k++) {
                    probLeft += substitutionProbsL[currentNode][j*4+k] * leftChild[k];
                    probRight += substitutionProbsR[currentNode][j*4+k] * rightChild[k];
                }
                // Both subtrees contribute independently given state j at this node.
                probTable[currentNode][j] = probLeft * probRight;
            }
        }
        // Node 0 is read as the root here.
        // NOTE(review): the weighted sum is additionally divided by 4; kept
        // exactly as in the original - confirm that this scaling is intended.
        double logprob = Math.log((pi[0]*probTable[0][0] + pi[1]*probTable[0][1] + pi[2]*probTable[0][2] + pi[3]*probTable[0][3]) / 4);
        cache.put(column, logprob); // hash current value
        return logprob;
    }

    /**
     * Fills the table row of every leaf from the observed column: a known
     * base gets 1 at its (possibly complemented) index and 0 elsewhere, an
     * unknown character gets 1 at all four indices.
     */
    private void initLeafProbabilities(String column, boolean complement) {
        for (int i = 0; i < leaves.length; i++) {
            double[] leafRow = probTable[leaves[i]];
            int base = nucleotideIndex(column.charAt(i));
            if (base < 0) {
                Arrays.fill(leafRow, 1); // set all 1
            } else {
                Arrays.fill(leafRow, 0);
                // With A=0, C=1, G=2, T=3 the complement of index b is 3 - b.
                leafRow[complement ? 3 - base : base] = 1;
            }
        }
    }

    /** Maps 'A','C','G','T' to 0..3; returns -1 for any other character. */
    private static int nucleotideIndex(char symbol) {
        switch (symbol) {
        case 'A':
            return 0;
        case 'C':
            return 1;
        case 'G':
            return 2;
        case 'T':
            return 3;
        default:
            return -1;
        }
    }

}
