#include "em_alg.h"

#define THRSHLD 0.01

// Allocates the scratch buffers used by the EM steps:
//   - oddsratioMat: maxspan rows of data->numalphas doubles
//   - posOddsratio: data->maxSeqLen doubles
// The buffers are NOT initialized here.
//
// Returns the new EmStruct, or NULL if any allocation fails (everything
// allocated up to that point is released first). The caller owns the result
// and releases it with nilEmStruct().
EmStruct* makeEmStruct(int maxspan, Dataset *data) {
	int rows = 0; // number of oddsratioMat rows successfully allocated so far
	EmStruct* ems = (EmStruct*)malloc(sizeof *ems);
	if(ems == NULL) {
		return NULL;
	}
	ems->maxspan = maxspan;
	ems->posOddsratio = NULL;

	// malloc() may legitimately return NULL for a zero-byte request, so a NULL
	// result only counts as a failure when a non-empty buffer was requested.
	ems->oddsratioMat = (double**) malloc(ems->maxspan * sizeof(double*));
	if(ems->oddsratioMat == NULL && ems->maxspan > 0) {
		goto fail;
	}
	for( rows = 0; rows < ems->maxspan; rows++) {
		ems->oddsratioMat[rows] = (double*) malloc(data->numalphas * sizeof(double));
		if(ems->oddsratioMat[rows] == NULL && data->numalphas > 0) {
			goto fail;
		}
	}

	ems->posOddsratio = (double*) malloc(data->maxSeqLen * sizeof(double));
	if(ems->posOddsratio == NULL && data->maxSeqLen > 0) {
		goto fail;
	}

	return ems;

fail:
	// Unwind: free only the rows that were actually allocated.
	while(rows-- > 0) {
		free(ems->oddsratioMat[rows]);
	}
	free(ems->oddsratioMat);
	free(ems);
	return NULL;
}

// Releases an EmStruct and the buffers it owns (posOddsratio, every row of
// oddsratioMat, the row table itself, and the struct).
//
// Like free(), passing NULL is a no-op. A NULL row table is also tolerated.
// `data` is not used; it is kept so existing callers keep compiling.
void nilEmStruct(EmStruct* ems, Dataset *data) {
	int i;

	(void)data;
	if(ems == NULL) {
		return;
	}

	free(ems->posOddsratio);

	if(ems->oddsratioMat != NULL) {
		for(i = 0; i < ems->maxspan; i++) {
			free(ems->oddsratioMat[i]);
		}
	}
	free(ems->oddsratioMat);
	free(ems);
}

//--------------------------------------------------------------------------------------------
// EM algorithm 
//--------------------------------------------------------------------------------------------

// Performs one EM iteration on `pswm` over all sequences in `data`.
//
//   pswm         - profile to update in place; pswm->mat is read to build the
//                  odds-ratio table and then overwritten with the re-estimate
//   data         - sequences (data->seqs, data->seqLen, data->numseqs),
//                  alphabet size (data->numalphas) and background
//                  frequencies (data->bgfreq)
//   posOddsratio - scratch buffer, one entry per start position of a sequence
//   oddsratioMat - scratch table, at least pswm->span rows of
//                  data->numalphas entries
//
// Returns the sum over sequences of log(sum of positional odds ratios) minus
// log(number of start positions), computed with the profile as it was on
// entry.
//
// Sequences shorter than pswm->span have no valid start position and are
// skipped: they contribute nothing to the profile or to the returned value.
// (Without the skip, log() of a non-positive position count turns the whole
// return value into NaN.)
static
double EM_step(Profile *pswm, Dataset *data, double *posOddsratio, double **oddsratioMat)
{
	int i, x, y, pos;
	int numpos;
	double sum;
	double sumthisx;
	double ilrval = 0.0;
	double norm;

	// Cache profile / background odds ratios before the profile is reset.
	for(x = 0; x < pswm->span; x++) {
		for(y = 0; y < data->numalphas; y++) {
			oddsratioMat[x][y] = pswm->mat[x][y] / data->bgfreq[y];
		}
	}

	// Reset the profile so the M-step can accumulate into it
	// (presumably setProfile fills pswm->mat with the given value -- confirm).
	setProfile(pswm, 0.0);
	for(i=0; i< data->numseqs; i++) {
		// Number of start positions at which the whole span fits.
		numpos = data->seqLen[i] - pswm->span + 1;
		if(numpos <= 0) {
			continue;
		}

		//E-step: odds ratio of the word starting at each position
		sum = 0.0;
		for(pos=0; pos < numpos; pos++)  {
			posOddsratio[pos] = 1.0;
			for(x=0; x< pswm->span; x++) {
				// Gap columns are not special-cased here; they go through
				// the cached table like every other column.
				posOddsratio[pos] *= oddsratioMat[x][ ((int)data->seqs[i][pos+x]) ];
			}
			sum += posOddsratio[pos];
		}

		//M-step: add each position's normalized weight to the profile cells
		//of the symbols found in the word at that position
		for(pos=0; pos < numpos; pos++)  {
			//normalize positional odds-ratio
			norm = posOddsratio[pos] / sum;

			for(x=0; x < pswm->span; x++) {
				// Gap columns are accumulated too; they are overwritten
				// with 1.0 in the normalization pass below.
				pswm->mat[x][ ((int)data->seqs[i][pos+x]) ] += norm;
			}
		}

		//ilrval is not part of the EM, but can be easily computed from EM_step()
		ilrval += log(sum) - log((double)numpos);
	}

	//normalize pswm: each non-gap column is scaled to sum to 1,
	//each gap column is set to all 1.0
	for(x=0; x< pswm->span; x++)  {
		if(!pswm->isgap[x]) {
			sumthisx = 0.0;
			for(y=0; y < data->numalphas; y++) {
				sumthisx += pswm->mat[x][y];
			}
			for(y=0; y< data->numalphas; y++)   {
				pswm->mat[x][y] /= sumthisx;
			}
		}
		else {
			for(y=0; y < data->numalphas; y++) {
				pswm->mat[x][y] = 1.0;
			}
		}
	}
	return ilrval;
}

// Runs up to `steps` EM iterations on `pswm`, using the scratch buffers in
// `ems`.
//
// In fast mode the loop stops early once more than two iterations have
// passed without the score beating the best score seen so far by more than
// THRSHLD.
//
// Returns the score reported by the last EM_step() call, or 0 if no
// iteration was run. When DEBUG3 is set, the first score, the last score and
// the number of fully completed iterations are printed to stderr.
double runEmSteps(Profile *pswm, Dataset *data, EmStruct *ems, int steps, boolean fastMode) {
	const double minGain = THRSHLD; // required improvement; > 0
	double firstScore = 0;
	double lastScore = 0;
	double topScore = -DBL_MAX;
	int stalled = 0;   // iterations since the last sufficient improvement
	int completed = 0; // iterations that ran to the end of the loop body
	int iter = 0;

	while(iter < steps) {
		lastScore = EM_step(pswm, data, ems->posOddsratio, ems->oddsratioMat);

		if(iter == 0) {
			firstScore = lastScore;
		}

		if(fastMode) {
			if(lastScore > minGain + topScore) {
				topScore = lastScore;
				stalled = 0;
			}
			else if(++stalled > 2) {
				// plateau reached: give up before counting this iteration
				break;
			}
		}

		completed++;
		iter++;
	}

	if(DEBUG3) {
		fprintf(stderr, "Before EM: %.3lf. After EM: %.3lf. Steps: %d\n", 
			firstScore, lastScore, completed);
	}

	return lastScore;
}
				
