#include "hints.h"
	
// Creates a score column for one position.
//
// stateNum  number of slots the column will hold once allocated
// conv      optional table applied to incoming state indices (may be NULL)
// backConv  optional table giving the state stored for each slot (may be NULL)
// index     accepted for interface compatibility; not stored
//
// No score storage is allocated here; init() does that on first access.
State::State(int stateNum, vector<int> *conv, vector<int> *backConv, int index) 
{
	// Nothing allocated yet.
	states = NULL;

	this->stateNum = stateNum;
	backConversion = backConv;
	conversion = conv;

	// The hint number is deliberately not recorded (was: hint = index).
	(void)index;
}

// Frees the score storage, if init() ever allocated it.
State::~State()
{
	if (states == NULL) return;

	deinit();
}

// Allocates the score vector (stateNum entries, each MININFTY) and resets the
// back-tracking records in lastState.
//
// Every record gets back = NULL.  Its state field is taken from
// backConversion when both conversion tables are present, otherwise it is
// simply the slot number.
void State::init()
{
	states = new vector<double>();
	states->resize(stateNum, MININFTY);
	lastState.resize(stateNum);

	// The tables do not change while the records are filled, so test once.
	const bool remapped = (conversion != NULL) && (backConversion != NULL);

	for (unsigned int slot = 0; slot < lastState.size(); ++slot) {
		lastState[slot].back = NULL;
		if (remapped) {
			lastState[slot].state = backConversion->at(slot);
		} else {
			lastState[slot].state = slot;
		}
	}
}

void State::deinit()
{
	delete states;
	states = NULL;

	// something with lastStates?
}

// Returns a writable reference to the score stored for `index`.
//
// The column is allocated on first use.  When a conversion table is present
// the index is translated through it first; a translated value of -1 trips
// the assert.  The final lookup is bounds-checked (vector::at).
double &State::getStat(int index)
{
	if (states == NULL) {
		init();
	}

	const int slot = (conversion != NULL) ? conversion->at(index) : index;
	assert(slot != -1);

	return states->at(slot);
}

// Returns a writable reference to the back-tracking record stored for `index`.
//
// The column is allocated on first use.  When a conversion table is present
// the index is translated through it first; a translated value of -1 trips
// the assert.
//
// The lookup is bounds-checked with vector::at, exactly like getStat(): when
// asserts are compiled out, a bad index now raises std::out_of_range instead
// of reading outside lastState.
Reference &State::getLastStat(int index)
{
	if (states == NULL) init();
	if (conversion != NULL) index = conversion->at(index);
	assert(index != -1);

	return lastState.at(index);
}

// Creates a hint with no intervals.
//
// number    stored in this->number (later handed to the State constructor by
//           finalize())
// bonus     stored in this->bonus
// interval  stored in this->interval
Hint::Hint(int number, double bonus, bool interval)
{
	intervals.clear();

	this->interval = interval;
	this->number = number;
	this->bonus = bonus;
}

void Hint::AddInterval(int start, int stop, char label)
{
	SimpleInt s; 
	s.beg = start;
	s.end = stop-1;
	s.label = label;

	if (intervals.size() != 0) {
		assert(start == intervals[intervals.size()-1].end+1);
	}

	intervals.push_back(s);
}

void Hint::finalize(int index, int stateNum) 
{
	this->index = index;
	st.clear();
	for (unsigned int i=0; i<intervals.size(); i++) {
		st.resize(st.size() + intervals[i].end - intervals[i].beg +1, 
			State(hmm.hintToState[intervals[i].label].size(), 
			&(hmm.hintToStateFull[intervals[i].label]),
			&(hmm.hintToState[intervals[i].label]), number));
	}
	
}

// Releases the score storage of every column and forgets the registered
// subsets.  The interval list is kept.
void Hint::deinit()
{
	unsigned int col = 0;
	while (col < st.size()) {
		st[col].deinit();
		++col;
	}

	subsets.clear();
}

// Propagates scores through every interval of this hint and then publishes
// the final column.
//
// Steps, in order:
//   1. For each interval, transition(i, subIndex) fills the columns inside it.
//   2. Between consecutive intervals, one step is taken from the column at the
//      end of interval i (source states: hmm.hintToState of its label) into
//      the following column, keeping only target states allowed by the label
//      of interval i+1.
//   3. The last column of st is merged (keep the larger score) into
//      hmm.viterbi[end()].
//   4. The last column is merged into column end() - begin() of every later
//      hint that starts before end() and passes compatible(); that hint's
//      bonus is added to the copied score.
//
// subsets is consumed front to back through subIndex.
// NOTE(review): this only works if subsets is ordered by end(); the ordering
// is not established in this function -- confirm.
// NOTE(review): intervals and st are indexed with size()-1 without an
// emptiness check, so this presumably requires finalize() to have run on a
// hint with at least one interval -- confirm against callers.
void Hint::transition()
{
	unsigned int subIndex = 0;

	// for each interval
		// first interval has prefilled start state, others not
	for (unsigned int i=0; i<intervals.size(); i++) {
		// whole interval transition
		transition(i, subIndex);

		// transition to another interval
		if (i != intervals.size()-1) {
			// targets are filtered by the next interval's label,
			// sources are enumerated from the current interval's label
			vector<bool> *allowed = &(hmm.allowedState[intervals[i+1].label]);
			vector<int>  *allowedPrev = &(hmm.hintToState[intervals[i].label]);

			unsigned int prevState = intervals[i].end;  // end of previous hint

			// count if any subsets are here
			// (sums the bonus of every pending subset whose end() is the
			// position being entered, prevState+1)
			double subsetBonus = 0.0;
			while ((subIndex < subsets.size()) && (subsets[subIndex]->end() == prevState+1)) { 
				subsetBonus += subsets[subIndex]->bonus; 
				subIndex++; 
			}

			vector<int>::iterator it;
			vector<pair<int, double> >::iterator it2;

			// from last computed st[] to next with respect to restriction
			for (it=allowedPrev->begin(); it!=allowedPrev->end(); it++) {
				// column index of prevState relative to the start of this hint
				int ps = prevState - begin();
				// source state still holds the initial (unreached) score: skip it
				if (st[ps].getStat(*it) <= MININFTY) continue;
				for (it2=hmm.transMatrix[*it].begin(); it2!=hmm.transMatrix[*it].end(); it2++) {
					// if state in next node is allowed
					if (allowed->at(it2->first)) {
						// newval = oldval + transition + emission + subset bonus
						double newVal = st[ps].getStat(*it) + it2->second + hmm.emit(it2->first, prevState+1) + subsetBonus;
						if (st[ps+1].getStat(it2->first) < newVal) {
							st[ps+1].getStat(it2->first) = newVal;
							// remember the path...
							st[ps+1].getLastStat(it2->first).back = &(st[ps].getLastStat(*it));
						}
					}
				}
			}
		}
	}

	// write result to right place
	// (only the states listed for the last interval's label are published)
	char restriction = intervals[intervals.size()-1].label;
	vector<int>::iterator it;

	// -- to position in viterbi
	for (it=hmm.hintToState[restriction].begin(); it!=hmm.hintToState[restriction].end(); it++) {
		if (st[st.size()-1].getStat(*it) > hmm.viterbi[end()].getStat(*it)) {
			hmm.viterbi[end()].getStat(*it) = st[st.size()-1].getStat(*it);
			// remember path - go double back cause of backward pass
			// (the predecessor link of this hint's last record is copied, so
			// the viterbi record does not point at the hint's last record itself)
			hmm.viterbi[end()].getLastStat(*it).back = st[st.size()-1].getLastStat(*it).back;
		}
	}
		
	// -- to compatible hints in right place
	// (scans the hints stored after this one in hmm.hints, stopping at the
	// first whose begin() is not below end())
	unsigned int j = index +1;
	while ((j < hmm.hints.size()) && hmm.hints[j].begin() < end()) {
		
		// if incompatible do nothing
		if (!compatible(*this, hmm.hints[j])) { j++; continue; }

		// pos in the 2nd hint
		unsigned int pos = end() - hmm.hints[j].begin();
		// NOTE(review): pos is unsigned, so pos>=0 always holds; only the
		// upper bound is effectively checked here.
		assert(pos>=0 && pos <hmm.hints[j].st.size());

		//State &stat = hmm.hints[j].st[pos];

		for (it=hmm.hintToState[restriction].begin(); it!=hmm.hintToState[restriction].end(); it++) {
			// NOTE(review): the comparison uses the score without the bonus,
			// while the stored value below includes it -- confirm intended.
			if (st[st.size()-1].getStat(*it) > hmm.hints[j].st[pos].getStat(*it)) {
				// newval + bonus
				hmm.hints[j].st[pos].getStat(*it) = st[st.size()-1].getStat(*it) + hmm.hints[j].bonus;
				// remember path - go double back cause of backward pass
				// (same predecessor-link copy as for the viterbi column above)
				hmm.hints[j].st[pos].getLastStat(*it).back = st[st.size()-1].getLastStat(*it).back;
			}
		}
		j++;
	}
}

void Hint::transition(int interval, unsigned int &subIndex) 
{
	int begin = intervals[interval].beg;	// begin of interval
	int end = intervals[interval].end;		// end of interval
	char restriction = intervals[interval].label;	

	int length = end - this->begin() + 1;	// length from beginning of hint to end of interval
	int i = begin - this->begin() + 1;		// iteration for s[i]

	vector<int>::iterator it;
	vector<pair<int, double> >::iterator it2;

	vector<bool> *allowed = &(hmm.allowedState[restriction]);

	while (i < length) {
		// count if any subsets are here
		double subsetBonus = 0.0;
		while ((subIndex < subsets.size()) && (subsets[subIndex]->end() == i + this->begin())) { 
			subsetBonus += subsets[subIndex]->bonus; 
			subIndex++; 
		}

		// only restricted states
		for (it=hmm.hintToState[restriction].begin(); it!=hmm.hintToState[restriction].end(); it++) {
			if (st[i-1].getStat(*it) <= MININFTY) continue;
			for (it2=hmm.transMatrix[*it].begin(); it2!=hmm.transMatrix[*it].end(); it2++) {
				// if state in next node is allowed
				if (allowed->at(it2->first)) {
					// newval = oldval + transition + emission + subset bonus
					double newVal = st[i-1].getStat(*it) + it2->second + hmm.emit(it2->first, i + this->begin()) + subsetBonus;
					if (st[i].getStat(it2->first) < newVal) {
						st[i].getStat(it2->first) = newVal;
						// remember the path...
						st[i].getLastStat(it2->first).back = &(st[i-1].getLastStat(*it));
					}
				}
			}
		}
		i++;
	}
}

// Decides whether the hint `second` can continue a path that runs through the
// hint `first`.
//
// first, second  hints to compare; their interval lists are only read
// allowSubset    when false, a second hint that ends no later than first is
//                rejected
//
// Returns false when
//   - either hint has no intervals,
//   - second starts before first,
//   - second starts after first has ended,
//   - second ends no later than first and allowSubset is false,
//   - no interval of first reaches the start of second's first interval,
//   - a pair of intervals compared during the sweep has different labels.
// Returns true otherwise.
bool compatible(Hint &first, Hint &second, bool allowSubset)
{
	// Everything below indexes intervals[0]; a hint without intervals has
	// nothing to compare, so reject it instead of reading out of range.
	if (first.intervals.empty() || second.intervals.empty()) return false;

	// wrong order
	if (second.begin() < first.begin()) return false;
	// not overlapping
	if (second.begin() > first.end()) return false;
	// subset
	if (!allowSubset && (second.end() <= first.end())) return false;

	unsigned int i = 0;
	unsigned int j = 0;

	// Skip the intervals of first that end before second starts.  The scan is
	// bounded so it cannot run past the last interval of first.
	while (i < first.intervals.size() && first.intervals[i].end < second.intervals[j].beg) {
		i++;
	}
	// No interval of first reaches second: there is no shared region.
	if (i == first.intervals.size()) return false;

	// compatible simple intervals: walk both lists in step, always advancing
	// the one whose current interval ends first (both on a tie)
	while (i<first.intervals.size() && j<second.intervals.size()) {
		if (first.intervals[i].label != second.intervals[j].label) return false;
		if (first.intervals[i].end == second.intervals[j].end) { i++; j++; }
		else {
			if (first.intervals[i].end > second.intervals[j].end) j++;
			else i++;
		}
	}

	return true;
}


