#include "hmm.h"

/**
 * Creates an empty model that is not usable until LoadHMM() succeeds.
 *
 * The stream handles are reset to NULL here because LoadHints(),
 * PrintAnnotation() and PrintHints() compare them against NULL before use;
 * without an explicit reset those checks would read indeterminate values.
 */
HMM::HMM()
{
    usable = false;

    seq_input = NULL;
    hint_input = NULL;
    output = NULL;
    resp_hints = NULL;
}

/**
 * Frees every emission matrix stored in emissMatrix.
 *
 * The matrices are allocated with new in LoadHMM() and kept as raw
 * pointers, so each one is deleted here.
 */
HMM::~HMM()
{
    for (size_t gc = 0; gc < emissMatrix.size(); ++gc) {
        for (size_t st = 0; st < emissMatrix[gc].size(); ++st) {
            delete emissMatrix[gc][st];
        }
    }
}

// Closes the model file and yields LoadHMM's generic failure code.
static int hmmCloseAndFail(FILE* input)
{
    fclose(input);
    return -1;
}

/**
 * Reads a model description from a text file.
 *
 * The file is consumed in this order:
 *   1. three integers: number of states, alphabet size and a third value
 *      that is read but not checked;
 *   2. one start probability per state;
 *   3. a stateNum x stateNum transition matrix (only positive entries are
 *      kept, stored as logarithms);
 *   4. per state: its grade followed by the emission data, which
 *      DoubleMatrix::readRecompute() reads from the same stream;
 *   5. per state: a line "state <index> <name> <label-char>".
 *
 * @param file path of the model file
 * @return 0 on success (usable is set to true), -1 if the file cannot be
 *         opened, is malformed, or contains an out-of-range count, grade
 *         or state index
 */
int HMM::LoadHMM(char* file)
{
    FILE* input = fopen(file, "r");
    if (input == NULL) {
        return -1;
    }

    // first states and alphabet
    int hasToBeOne;   // consumed to advance the stream; the value is not validated
    int ret = fscanf(input, "%d %d %d", &stateNum, &alphabetNum, &hasToBeOne);
    if (ret != 3) return hmmCloseAndFail(input);
    // negative counts would be converted to huge sizes by resize() below
    if (stateNum < 0 || alphabetNum < 0) return hmmCloseAndFail(input);

    stateNames.resize(stateNum);
    stateGrade.resize(stateNum);
    stateLabel.resize(stateNum);
    start.resize(stateNum);
    for (int i = 0; i < stateNum; i++) {
        ret = fscanf(input, "%lf", &start[i]);
        if (ret != 1) return hmmCloseAndFail(input);
    }

    // transition matrix
    int k;
    double d;
    transMatrix.resize(stateNum);
    for (int i = 0; i < stateNum; i++) {
        for (int j = 0; j < stateNum; j++) {
            ret = fscanf(input, "%lf", &d);
            if (ret != 1) return hmmCloseAndFail(input);
            // zero-probability edges are dropped so the matrix stays sparse
            if (d > 0) transMatrix[i].push_back(make_pair(j, log(d)));
        }
    }

    // emission matrix
    emissMatrix.resize(1); // gccontent = 0
    for (int i = 0; i < stateNum; i++) {
        ret = fscanf(input, "%d", &k);
        // a negative grade would give a negative array length below
        if (ret != 1 || k < 0) return hmmCloseAndFail(input);

        // write down grade of state...
        stateGrade[i] = k;

        // every dimension has alphabetNum+1 entries
        // NOTE(review): ndim is never deleted in this function; whether
        // DoubleMatrix takes ownership of it is not visible here - confirm.
        int *ndim = new int[k+2];
        for (int j = 0; j < k+2; j++) {
            ndim[j] = alphabetNum+1;
        }

        // we create emissmatrix and fill it from file +
        // recompute alphabetNum+1 positions
        // normalize and log it...
        DoubleMatrix* emiss = new DoubleMatrix(k+1, ndim, 0.0);
        emiss->readRecompute(input);
        emiss->normalize();
        emiss->logAll();

        emissMatrix[0].push_back(emiss);
    }

    // states
    char str[100];
    char ch;
    for (int i = 0; i < stateNum; i++) {
        ret = fscanf(input, "state %d ", &k);
        if (ret != 1) return hmmCloseAndFail(input);
        // width-limited so an over-long name cannot overflow str
        ret = fscanf(input, "%99s", str);
        if (ret != 1) return hmmCloseAndFail(input);
        ret = fscanf(input, " %c\n", &ch);
        if (ret != 1) return hmmCloseAndFail(input);

        // the index comes from the file; reject it before it is used as a subscript
        if (k < 0 || k >= stateNum) return hmmCloseAndFail(input);

        stateNames[k] = str;
        stateLabel[k] = ch;
        // operator[] creates the (empty) list on first use of a label
        stateLabelMap[ch].push_back(k);
        stateMap.insert(make_pair(str, k));
    }

    fclose(input);
    usable = true;
    return 0;
}

/**
 * Opens a sequence file for reading and loads its first record.
 *
 * @param file path of the sequence file
 * @return -1 if the file cannot be opened, otherwise the result of LoadSeqFA()
 */
int HMM::SetNLoadSeqFA(char* file)
{
    seq_input = fopen(file, "r");
    return (seq_input == NULL) ? -1 : LoadSeqFA();
}

int HMM::LoadSeqFA()
{
	bool done = false;

	// N's in the begining
	dnaSeq.push_back(alphabetNum); gcContent.push_back(0);
	dnaSeq.push_back(alphabetNum); gcContent.push_back(0);
	dnaSeq.push_back(alphabetNum); gcContent.push_back(0);
	dnaSeq.push_back(alphabetNum); gcContent.push_back(0);
	startSeq = 4;

	char ch;
	char str[100];

    while (!feof(seq_input)) {
        int ret = fscanf(seq_input, "%c", &ch);
        if (ret != 1) {
			continue;
		}
		if (ch == '>' || ch == ';') {
			// next chromosome...
			if (done) {
				fseek(seq_input, -1, SEEK_CUR);
				break;
			}
			ret = fscanf(seq_input, "%s", str);
			if (ret != 1) chromosome = "__noname__";
			else chromosome = str;
			done = true;
			continue;
		}
		bool written = true;
		switch (ch) {
			case 'a':
			case 'A': dnaSeq.push_back(0); break;
			case 'c':
			case 'C': dnaSeq.push_back(1); break;
			case 'g':
			case 'G': dnaSeq.push_back(2); break;
			case 't':
			case 'T': dnaSeq.push_back(3); break;
			case 'n':
			case 'N': dnaSeq.push_back(4); break;
			default: written = false; break;
		}
		if (written) gcContent.push_back(0);
    }
	
	// read nothing
	if (dnaSeq.size() == startSeq) return -1;

	// make empty states in viterbi
	viterbi.resize(dnaSeq.size(), State(stateNum));

	// set the beginning 
	for (int i=0; i<stateNum; i++) {
		viterbi[startSeq-1].getStat(i) = log(start[i]);
	}

    return 0;
}

bool compareHints (const Hint &i, const Hint &j) 
{ 
	return (i.begin()<j.begin()); 
}

bool compareHintsEnd (Hint *i, Hint *j)
{
	return (i->end()<j->end());
}

/**
 * Opens the hint file, records the bonus-scaling switches, optionally reads
 * a table of named bonuses, and loads the first block of hints.
 *
 * @param file        path of the hint file
 * @param bonus_chars path of a file with "<name> <value>" pairs; may be
 *                    NULL or unreadable, in which case no names are defined
 * @param multiply    scale each bonus by its number of intervals
 * @param mult_sqrt   scale each bonus by the square root of its number of
 *                    intervals; ignored (with a warning) if multiply is set
 * @return -1 if the hint file cannot be opened, otherwise the result of
 *         LoadHints()
 */
int HMM::SetNLoadHints(char* file, char* bonus_chars, bool multiply, bool mult_sqrt)
{
    hint_input = fopen(file, "r");
    if (hint_input == NULL) {
        return -1;
    }

    if (multiply && mult_sqrt) {
        printf("WARNING: incompatible switch: -int && -sqrt, using -int only\n");
        mult_sqrt = false;
    }

    this->multiply = multiply;
    this->mult_sqrt = mult_sqrt;

    // passing NULL to fopen is undefined, so treat it as "no bonus file"
    FILE *bonus_file = (bonus_chars != NULL) ? fopen(bonus_chars, "r") : NULL;
    if (bonus_file != NULL) {
        char type[100];
        double bonus;
        // stops at end of file or at the first malformed pair;
        // the width limit keeps a long name inside the buffer
        while (fscanf(bonus_file, "%99s %lf", type, &bonus) == 2) {
            string typestr = type;
            // insert() keeps the first value if a name is repeated
            bonuses.insert(make_pair(typestr, bonus));
        }
        fclose(bonus_file);
    }

    return LoadHints();
}

/**
 * Reads one block of hints (a ">name" header followed by hint records)
 * from hint_input, keeps those that fit into the current sequence, sorts
 * them by beginning, links compatible contained hints as subsets and
 * finalizes them.
 *
 * Each record starts with "<number> <bonus> <intervals>". The bonus is
 * either a name defined in bonuses or a number. If <intervals> is positive
 * the record continues with " - <pos>" and that many " <label> <pos>"
 * pairs; otherwise with " - <pos> <label>". Positions are shifted by
 * startSeq. The block ends at the first record header that does not parse
 * (end of file or the next ">" header).
 *
 * @return 0 when the block was read to its end, -2 if hint_input is NULL
 *         or the file ended inside a record, -4 if the ">name" header is
 *         missing, -1 if a record is malformed
 */
int HMM::LoadHints()
{
    int num, start, stop;
    int interva;
    char bonus[50];
    char label;

    int ret;

    if (hint_input == NULL) {
        return -2;
    }

    char chrom[100];
    // width-limited so a long name cannot overflow chrom
    ret = fscanf(hint_input, ">%99s", chrom);
    if (ret != 1) {
        return -4;
    }

    printf("Hints from chromosome part: %s\n", chrom);


    while (true) {
        ret = fscanf(hint_input, "%d %49s %d", &num, bonus, &interva);
        if (ret != 3) {
            // regular end of this block
            ret = 1;
            break;
        }
        // evaluated only after the read is known to have filled interva
        const bool interval = (interva > 0);

        // retrieve bonus from category or convert str to float
        double bon = 0.0;
        string bonustype = bonus;
        if (bonuses.count(bonustype) > 0) {
            bon = bonuses[bonustype];
        }
        else if (sscanf(bonus, "%lf", &bon) != 1) {
            // neither a known name nor a number: use a defined value
            // instead of whatever happened to be in the variable
            bon = 0.0;
            printf("unknown bonus '%s' for hint %d, using 0\n", bonus, num);
        }

        Hint hint(num, bon, interval);

        if (interval) {
            ret = fscanf(hint_input, " - %d", &stop);
            if (ret != 1) break;
            for (int i=0; i<interva; i++) {
                start = stop;
                ret = fscanf(hint_input, " %c %d", &label, &stop);
                if (ret != 2) break;
                hint.AddInterval(start + startSeq, stop + startSeq, label);
            }
            // the break above only leaves the for loop; a truncated hint
            // must not be processed any further
            if (ret != 2) {
                if (ret > 0) ret = 0;   // partial read is an error, not success
                break;
            }
        } else {
            ret = fscanf(hint_input, " - %d %c", &start, &label);
            if (ret != 2) {
                if (ret > 0) ret = 0;   // partial read is an error, not success
                break;
            }
            hint.AddInterval(start + startSeq, start + startSeq + 1, label);
        }

        // multiplying bonus
        if (multiply)  hint.bonus *= hint.intervals.size();
        if (mult_sqrt) hint.bonus *= sqrt((double)hint.intervals.size());

        // does hint fit in sequence???
        if (hint.intervals.back().end >= (int)hmm.viterbi.size()) {
            printf ("bad hint: %d\n", hint.number);
        } else {
            hints.push_back(hint);
        }
    }

    // sort according to beginnings
    sort(hints.begin(), hints.end(), compareHints);

    // subsets add to local bonus...
    // (pointers into hints are taken only after the vector has stopped growing)
    for (unsigned int i=0; i<hints.size(); i++) {
        unsigned int j = i+1;
        // hint to be considered
        while (j<hints.size() && (hints[j].begin() <= hints[i].end())) {
            if ((hints[j].end() <= hints[i].end()) && compatible(hints[i], hints[j], true)) {
                // add hint
                hints[i].subsets.push_back(&hints[j]);
                // sort hints according to ends (needed in subset bonuses)
                // NOTE(review): re-sorting after every insertion is
                // redundant; one sort after the loop would be enough if the
                // order of equal ends does not matter - confirm.
                sort(hints[i].subsets.begin(), hints[i].subsets.end(), compareHintsEnd);
            }
            j++;
        }
    }

    // finalize (create states)
    for (unsigned int i=0; i<hints.size(); i++) {
        hints[i].finalize(i, stateNum);
    }

    return ret-1;
}

int HMM::LoadLabels(char* file)
{
	FILE* input;
    input = fopen(file, "r");
    if (input==NULL) {
        return -1;
    }
	int ret = 1;

	stateNumToHint.resize(stateNum);
		
	while (!feof(input)) {
		char hint;
		char state;
		ret = fscanf(input, "%c - ", &hint);
		if (ret != 1) break;
		if (hintToStateLabel.find(hint) == hintToStateLabel.end()) {
			hintToStateLabel.insert(make_pair(hint, vector<char>()));
			hintToState.insert(make_pair(hint, vector<int>()));
			allowedState.insert(make_pair(hint, vector<bool>(stateNum, false)));
			hintToStateFull.insert(make_pair(hint, vector<int>(stateNum, -1)));
		}
		while (true) {
			state = fgetc(input);
			if (state == '\n' || state == EOF) break;
			if (state == ' ') continue;
			hintToStateLabel[hint].push_back(state);
			for (unsigned int i=0; i<stateLabelMap[state].size(); i++) {
				hintToState[hint].push_back(stateLabelMap[state][i]);
				allowedState[hint][stateLabelMap[state][i]] = true;
				hintToStateFull[hint][stateLabelMap[state][i]] = hintToState[hint].size()-1;
				stateNumToHint[stateLabelMap[state][i]] = hint;
			}
			stateToHint.insert(make_pair(state, hint));
		}
	}

	fclose(input);
    return ret-1;
}

/**
 * Looks up the emission value of a state at a sequence position.
 *
 * The lookup key is the window of grade+1 symbols ending at pos; it is
 * written into the dim buffer and used to index the state's emission
 * matrix selected by gcContent[pos].
 *
 * @param state state index
 * @param pos   position in dnaSeq; the grade preceding symbols are read too
 * @return the matrix entry for that window
 */
double HMM::emit(int state, int pos)
{
    const int order = stateGrade[state];
    const int first = pos - order;

    int slot = 0;
    while (slot <= order) {
        dim[slot] = dnaSeq[first + slot];
        ++slot;
    }

    return (*emissMatrix[gcContent[pos]][state])[dim];
}

void HMM::transition(int pos)
{
	// pos is already computed, compute viterbi[pos+1] from viterbi[pos]
	// = solve one edge from position pos.

	vector<pair<int, double> >::iterator it2;

	for (int j=0; j<viterbi[pos].getSize(); j++) {
		if (viterbi[pos].getStat(j) <= MININFTY) continue;
		for (it2=transMatrix[j].begin(); it2!=transMatrix[j].end(); it2++) {
			// newval = oldval + transition + emission
			double newVal = viterbi[pos].getStat(j) + it2->second + emit(it2->first, pos+1);
			if (viterbi[pos+1].getStat(it2->first) < newVal) {
				viterbi[pos+1].getStat(it2->first) = newVal;
				// remember the path...
				viterbi[pos+1].getLastStat(it2->first).back = &(viterbi[pos].getLastStat(j));
			}
		}
	}
}

void HMM::transition(int pos, Hint &hint)
{
	// pos is already computed, compute next(beginning of hint) from viterbi[pos]
	// = solve one edge from position pos.

	vector<bool> *allowed = &(allowedState[hint.intervals[0].label]);

	vector<pair<int, double> >::iterator it2;

	for (int j=0; j<viterbi[pos].getSize(); j++) {
		if (viterbi[pos].getStat(j) <= MININFTY) continue;
		for (it2=transMatrix[j].begin(); it2!=transMatrix[j].end(); it2++) {
			if (allowed->at(it2->first)) {
				// newval = oldval + transition + emission + bonus
				double newVal = viterbi[pos].getStat(j) + it2->second + emit(it2->first, pos+1) + hint.bonus;
				if (hint.st[0].getStat(it2->first) < newVal) {
					hint.st[0].getStat(it2->first) = newVal;
					// remember the path...
					hint.st[0].getLastStat(it2->first).back = &(viterbi[pos].getLastStat(j));
				}
			}
		}
	}
}

/**
 * Tells whether the decoded state path agrees with a hint.
 *
 * @param hint hint to check against resStates
 * @return true if, for every interval, every position from beg to end
 *         (both included) was decoded to a state whose hint label equals
 *         the interval's label
 */
bool HMM::respected(Hint &hint)
{
    for (size_t seg = 0; seg < hint.intervals.size(); ++seg) {
        const SimpleInt &piece = hint.intervals[seg];
        int pos = piece.beg;
        while (pos <= piece.end) {
            // interval coordinates include the startSeq offset, resStates does not
            if (stateNumToHint[resStates[pos - startSeq]] != piece.label) {
                return false;
            }
            ++pos;
        }
    }
    return true;
}

/**
 * Runs the Viterbi algorithm over the loaded sequence and hints.
 *
 * Forward pass: columns are processed left to right starting at
 * startSeq-1. At each column the hints beginning there are advanced and
 * released, the ordinary edges into the next column are relaxed, edges
 * into hints beginning at the next column are relaxed, and the finished
 * column is released. Backward pass: starting from the best state of the
 * last column the back references are followed to fill resStates.
 * Finally the numbers of all respected hints are collected, sorted, in
 * respectedHints.
 *
 * @return the value of the best final state; MININFTY if no sequence is loaded
 */
double HMM::ComputeViterbi()
{
    // size()-1 below would wrap around on an empty lattice
    if (viterbi.empty()) return MININFTY;

    unsigned int progress = startSeq-1;
    unsigned int pivotHint = 0;

    unsigned int percent = 10;
    // forward pass
    while (progress < viterbi.size()-1) {
        if (progress == viterbi.size()*percent/100) {
            printf("%u%% done\n", percent);
            percent += 10;
        }
        // first whole hints
        while ((pivotHint < hints.size()) && (hints[pivotHint].begin() == progress)) {
            hints[pivotHint].transition();
            // clear space
            hints[pivotHint].deinit();

            pivotHint++;
        }

        // now edges
            // first classic viterbi
        transition(progress);

            // now hints
        unsigned int pivot2 = pivotHint;
        while ((pivot2 < hints.size()) && (hints[pivot2].begin() == progress+1)) {
            transition(progress, hints[pivot2]);
            pivot2++;
        }

        // delete unwanted space
        viterbi[progress].deinit();

        progress++;
    }

    // backward pass
        // get maximal value
    Reference *ref = NULL;
    double maxVal = MININFTY;
    for (int i=0; i<viterbi[progress].getSize(); i++) {
        if (viterbi[progress].getStat(i) > maxVal) {
            maxVal = viterbi[progress].getStat(i);
            ref = &(viterbi[progress].getLastStat(i));
        }
    }

    resStates.resize(dnaSeq.size() - startSeq, -1);

        // trace back
    while (progress >= startSeq) {
        // ref stays NULL when no final state is reachable; check before
        // it is dereferenced
        assert(ref != NULL);
        resStates[progress - startSeq] = ref->state;
        ref = ref->back;

        progress--;
    }

    // respected hints
    for (unsigned int i=0; i<hints.size(); i++) {
        if (respected(hints[i])) respectedHints.push_back(hints[i].number);
    }
    sort(respectedHints.begin(), respectedHints.end());

    return maxVal;
}

/**
 * Selects where PrintAnnotation() writes.
 *
 * @param file path of the file to create, or NULL for standard output
 * @return 0 on success, -1 if the file cannot be opened for writing
 */
int HMM::SetAnnoOutput(char* file)
{
    if (file == NULL) {
        output = stdout;
        return 0;
    }

    output = fopen(file, "w");
    return (output == NULL) ? -1 : 0;
}

/**
 * Writes the decoded path as a FASTA-like record: a ">name" header
 * followed by one state label per position, 80 labels per line.
 *
 * @return 0 on success, -1 if no output stream is set
 */
int HMM::PrintAnnotation()
{
    if (output == NULL) {
        printf("unable to write... \n");
        return -1;
    }

    fprintf(output, ">%s", chromosome.c_str());

    unsigned int idx = 0;
    while (idx < resStates.size()) {
        // the line break precedes each group of 80, so the header line ends first
        if (idx % 80 == 0) {
            fprintf(output, "\n");
        }
        fprintf(output, "%c", stateLabel[resStates[idx]]);
        ++idx;
    }

    fprintf(output, "\n");
    return 0;
}

/**
 * Writes the decoded path as tab-separated feature lines (start/stop
 * codons and CDS segments), driven by the label of the state at each
 * position.
 *
 * The frame column of a CDS line is the label of a neighbouring state if
 * that label is a digit '0'..'5', otherwise '?'. The neighbour is looked
 * up only when its index lies inside the path.
 *
 * @param file path of the file to create, or NULL for standard output
 * @return 0 on success, -1 if the file cannot be opened for writing
 */
int HMM::PrintAnnotationGTF(char* file)
{
    // local stream; deliberately independent of the annotation output member
    FILE* output;
    if (file == NULL) output = stdout;
    else {
        output = fopen(file, "w");
        if (output==NULL) {
            return -1;
        }
    }

    // signed arithmetic throughout: i-1 and i-2 are formed at the first
    // positions and must not wrap around
    const int total = (int)resStates.size();
    int ifrom = 0;
    int ito = 0;
    for (int i=0; i<total; i++) {
        const char current = stateLabel[resStates[i]];
        bool to_print = false;
        bool to_printc = false;
        char strand = '.';
        string type = "__none__";
        switch (current) {
            case 'E': type = "stop_codon"; strand = '-'; to_printc = true; ifrom = i+4; break;
            case 'e': type = "stop_codon"; strand = '+'; to_printc = true; to_print = true; ito = i-2; break;
            case 'B': type = "start_codon"; strand = '-'; to_printc = true; to_print = true; ito = i+1; break;
            case 'b': type = "start_codon"; strand = '+'; to_printc = true; ifrom = i+1; break;

            case 'A': type = "CDS"; strand = '-'; to_print = true; ito = i; break;
            case 'a': ifrom = i+2; break;
            case 'D': ifrom = i+2; break;
            case 'd': type = "CDS"; strand = '+'; to_print = true; ito = i; break;
            default: break;
        }
        if (to_printc) {
            const int j = (current == 'E' || current == 'b') ? i+1 : i-1;
            fprintf(output, "%s\t%s\t%s\t%d\t%d\t%c\t%c\t%c\n",
                chromosome.c_str(), "grapHMM", type.c_str(), j, j+2, '.', strand, '0');
        }
        if (to_print) {
            // forward strand: state just before this one; reverse strand:
            // state at the remembered segment start
            const int frameIdx = (strand == '+') ? i-1 : ifrom;
            char frame = '?';
            if (frameIdx >= 0 && frameIdx < total) {
                const char c = stateLabel[resStates[frameIdx]];
                if (c >= '0' && c <= '5') frame = c;
            }
            fprintf(output, "%s\t%s\t%s\t%d\t%d\t%c\t%c\t%c\n",
                chromosome.c_str(), "grapHMM", "CDS", ifrom, ito, '.', strand, frame);
        }
    }

    if (file != NULL) fclose(output);
    return 0;
}

/**
 * Selects where PrintHints() writes.
 *
 * @param file path of the file to create, or NULL for standard output
 * @return 0 on success, -1 if the file cannot be opened for writing
 */
int HMM::SetHintOutput(char* file)
{
    if (file == NULL) {
        resp_hints = stdout;
        return 0;
    }

    resp_hints = fopen(file, "w");
    return (resp_hints == NULL) ? -1 : 0;
}

/**
 * Writes the numbers of the respected hints: a header line with the
 * chromosome name and the count, then the numbers separated by blanks.
 *
 * @return 0 on success, -1 if no hint output stream is set
 */
int HMM::PrintHints()
{
    if (resp_hints == NULL) {
        printf("unable to write resp. hints... \n");
        return -1;
    }

    const int howMany = (int)respectedHints.size();
    fprintf(resp_hints, ">%s Respected hints (%d):\n", chromosome.c_str(), howMany);

    for (unsigned int k = 0; k != respectedHints.size(); ++k) {
        fprintf(resp_hints, "%d ", respectedHints[k]);
    }

    fprintf(resp_hints, "\n");
    return 0;
}

/**
 * Processes every chromosome part of the open sequence file: decodes it,
 * prints the annotation and the respected hints, then loads the next part
 * until NextChromosome() reports there is none.
 *
 * Afterwards all streams are released and their handles reset to NULL.
 * A handle that refers to standard output is flushed rather than closed,
 * so the process can keep printing and stdout is never closed twice when
 * both outputs point to it.
 *
 * @return number of chromosome parts processed
 */
int HMM::ComputeAll()
{
    int count = 0;
    bool next = true;

    while (next) {
        // compute
        const clock_t started = clock();   // clock() yields clock_t, not time_t
        printf ("Starting chromosome part:   %s\n", chromosome.c_str());
        double res = hmm.ComputeViterbi();
        printf("%s done in %.2f sec. (end probability = %.3f)\n", chromosome.c_str(), (clock() - started)/(double)CLOCKS_PER_SEC, res);

        // print results
        PrintAnnotation();
        PrintHints();

        printf("\n");

        // Load next chromosome part
        next = NextChromosome();

        // count the chromosome parts
        count++;
    }

    // close files
    if (output != NULL) {
        if (output == stdout) fflush(stdout);
        else fclose(output);
        output = NULL;
    }
    if (resp_hints != NULL) {
        if (resp_hints == stdout) fflush(stdout);
        else fclose(resp_hints);
        resp_hints = NULL;
    }
    // fclose(NULL) is undefined; the inputs may never have been opened
    if (seq_input != NULL) {
        fclose(seq_input);
        seq_input = NULL;
    }
    if (hint_input != NULL) {
        fclose(hint_input);
        hint_input = NULL;
    }

    return count;
}


/**
 * Discards the current chromosome part and loads the next one.
 *
 * @return false if no further sequence could be loaded (hints and results
 *         of the previous part are left untouched in that case), true
 *         otherwise - even when no hints could be loaded for the new part
 */
bool HMM::NextChromosome()
{
    // drop the old sequence and lattice, then read the next record
    dnaSeq.clear();
    gcContent.clear();
    viterbi.clear();
    if (LoadSeqFA() != 0) {
        return false;
    }

    // hints are optional: a failure to load them is deliberately ignored
    hints.clear();
    LoadHints();

    // forget the results of the previous part
    resStates.clear();
    respectedHints.clear();

    return true;
}


