package mlproject.io;

import java.io.Closeable;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;

import mlproject.hmm.StateModel;

/**
 * Writes GTF-style annotation lines derived from a sequence of HMM state ids.
 *
 * <p>Each state id is mapped to a category using the state lists exposed by the
 * {@link StateModel}. Whenever the category changes between two consecutive
 * positions, the matching {@code start_codon}, {@code stop_codon} and/or
 * {@code CDS} line is printed to the output file.
 *
 * <p>Instances hold an open file and must be closed; the class implements
 * {@link Closeable} so it can be used in a try-with-resources statement.
 */
public class GTFWriter implements Closeable {
    /*
     * State categories. The non-negative values double as row indices into the
     * table built in writeStates(); UNCLASSIFIED is what getStateType() returns
     * for a state id that appears in none of the rows.
     * NOTE(review): unclassified states are handled as the gap between two
     * exons of one gene (presumably intronic states) -- confirm against StateModel.
     */
    private static final int UNCLASSIFIED = -1;
    private static final int INTERGENIC = 0;
    private static final int PLUS_START_CODON = 1;
    private static final int PLUS_EXON = 2;
    private static final int PLUS_STOP_CODON = 3;
    private static final int MINUS_STOP_CODON = 4;
    private static final int MINUS_START_CODON = 5;
    private static final int MINUS_EXON = 6;
    private static final int STATE_TYPE_COUNT = 7;

    /** Destination of the generated lines. */
    PrintWriter pw;
    /** Source of the state-id lists used to classify each position. */
    StateModel stateModel;

    /**
     * Opens {@code outputFile} for writing.
     *
     * @param stateModel model providing the state-id lists for every category
     * @param outputFile path of the file to write to
     * @throws IOException if the file cannot be opened for writing
     */
    public GTFWriter(StateModel stateModel, String outputFile) throws IOException {
        this.stateModel = stateModel;
        pw = new PrintWriter(new FileWriter(outputFile));
    }

    /**
     * Converts a state path into annotation lines and prints them.
     *
     * <p>Lines are emitted only when the category changes from one position to
     * the next. The category of {@code states[0]} merely seeds the comparison,
     * and a feature still open at the end of the path is not written. Positions
     * are written as array index + 1, and codons span three positions. Gene
     * numbering restarts at 1 on every call.
     *
     * <p>A {@code null} or empty path produces no output.
     *
     * @param name   value of the first column; also the prefix of the gene and
     *               transcript ids
     * @param states state id for every position of the sequence
     */
    public void writeStates(String name, int[] states) {
        // Nothing to report; also avoids reading states[0] of an empty array.
        if (states == null || states.length == 0) {
            return;
        }

        int[][] stateTypes = new int[STATE_TYPE_COUNT][];
        stateTypes[INTERGENIC] = stateModel.getIntergenicStates();
        stateTypes[PLUS_START_CODON] = stateModel.getPlusStartCodonStates();
        stateTypes[PLUS_EXON] = stateModel.getPlusExonicStates();
        stateTypes[PLUS_STOP_CODON] = stateModel.getPlusStopCodonStates();
        stateTypes[MINUS_STOP_CODON] = stateModel.getMinusStopCodonStates();
        stateTypes[MINUS_START_CODON] = stateModel.getMinusStartCodonStates();
        stateTypes[MINUS_EXON] = stateModel.getMinusExonicStates();

        // NOTE(review): the frame arithmetic and the coordinates below are kept
        // exactly as they were; they have not been re-derived from the GTF
        // specification here.
        int start = 0;      // 0-based index where the current CDS segment begins
        int frame = 0;      // frame value carried from one CDS segment to the next
        int geneCount = 1;  // running gene number used in the id attributes
        int lastType = getStateType(stateTypes, states[0]);

        for (int i = 1; i < states.length; i++) {
            int currentType = getStateType(stateTypes, states[i]);
            if (currentType == lastType) {
                continue;
            }

            switch (currentType) {
                case INTERGENIC:
                    frame = 0;
                    break;

                case PLUS_START_CODON:
                    // The first plus-strand CDS segment begins at the start codon.
                    start = i;
                    pw.println(getGtfLine(name, "start_codon", i + 1, i + 3, "+", ".", geneCount));
                    break;

                case MINUS_STOP_CODON:
                    // The first minus-strand CDS segment begins after the stop codon.
                    start = i + 3;
                    pw.println(getGtfLine(name, "stop_codon", i + 1, i + 3, "-", ".", geneCount));
                    break;

                case MINUS_START_CODON:
                    // Closes a minus-strand gene; this CDS segment includes the start codon.
                    frame = (i - start + 1 - (3 - frame) % 3) % 3;
                    pw.println(getGtfLine(name, "CDS", start + 1, i + 3, "-", Integer.toString(frame), geneCount));
                    pw.println(getGtfLine(name, "start_codon", i + 1, i + 3, "-", ".", geneCount));
                    geneCount++;
                    break;

                case PLUS_STOP_CODON:
                    // Closes a plus-strand gene; this CDS segment ends right before the stop codon.
                    pw.println(getGtfLine(name, "CDS", start + 1, i, "+", Integer.toString(frame), geneCount));
                    pw.println(getGtfLine(name, "stop_codon", i + 1, i + 3, "+", ".", geneCount));
                    geneCount++;
                    break;

                case UNCLASSIFIED:
                    // Leaving an exon without reaching a codon state: flush the segment.
                    if (lastType == PLUS_EXON) {
                        pw.println(getGtfLine(name, "CDS", start + 1, i, "+", Integer.toString(frame), geneCount));
                        frame = (3 - (i - start + 1 - frame) % 3) % 3;
                    }
                    if (lastType == MINUS_EXON) {
                        frame = (i - start + 1 - (3 - frame) % 3) % 3;
                        pw.println(getGtfLine(name, "CDS", start + 1, i, "-", Integer.toString(frame), geneCount));
                    }
                    break;

                case PLUS_EXON:
                case MINUS_EXON:
                    // Re-entering an exon from an unclassified stretch opens a new segment.
                    if (lastType == UNCLASSIFIED) {
                        start = i;
                    }
                    break;

                default:
                    break;
            }

            lastType = currentType;
        }
    }

    /**
     * Builds one tab-separated line: name, ".", feature type, start, stop, ".",
     * strand, frame, and the {@code gene_id}/{@code transcript_id} attributes,
     * both set to {@code name_geneId}.
     */
    private String getGtfLine(String name, String type, int start, int stop, String strand, String frame, int geneId) {
        return name + "\t" + ".\t" + type + "\t" + start + "\t" + stop + "\t.\t" + strand + "\t"
                + frame + "\t" + "gene_id \"" + name + "_" + geneId + "\"; transcript_id \"" + name + "_" + geneId + "\";";
    }

    /**
     * Returns the index of the first row of {@code stateTypes} that contains
     * {@code state}, or {@code -1} if no row contains it.
     */
    private int getStateType(int[][] stateTypes, int state) {
        for (int i = 0; i < stateTypes.length; i++) {
            for (int j = 0; j < stateTypes[i].length; j++) {
                if (stateTypes[i][j] == state) {
                    return i;
                }
            }
        }
        return UNCLASSIFIED;
    }

    /** Closes the underlying writer. */
    @Override
    public void close() {
        pw.close();
    }
}
