#include "gibbsilr.h"

#include <stdlib.h>

// Sampling mode passed to updateOneIter:
//   TRIAL  - temperature is never applied and no entropy score is computed
//   NORMAL - temperature is applied when gibbs->useTemperature is set and the
//            entropy score is computed
enum IterMode { TRIAL, NORMAL }; //see updateOneIter


//----------------------------------------------------------------------
// Core functions
//----------------------------------------------------------------------
/**
* One iteration of Gibbs-sampling. This samples a new site for each sequence:
* the sequence's current site is removed from the count matrix, every
* candidate start position is scored by its model-over-background odds ratio,
* one position is drawn with probability proportional to that score, and the
* drawn site is added back.
* 
* param: 
*   countmat - count matrix at the beginning of this iteration; passed to
*              removeSite/addSite for every sequence
*   sites - sites at the beginning of this iteration; passed to
*           removeSite/addSite for every sequence
*   gibbs - sampler state and parameters; gibbs->posScore is overwritten with
*           the per-position scores of the last sequence processed
*   mode - TRIAL is for quick sample: temperature is never applied and the 
*          entropy score is not computed; NORMAL otherwise
* return: 
*   NORMAL: the entropy score computed from countmat after all sequences 
*           have been resampled.
*   TRIAL: -DBL_MAX (no score is computed).
**/
static
double updateOneIter(Profile *countmat, int *sites, Gibbs *gibbs, enum IterMode mode)
{
	int i, a;
	int j, m;
	int numpos; //number of candidate start positions in the current sequence
	int newsite; //new site chosen
	double entropy = -DBL_MAX; //entropy of this iteration
	double sum;
	double drand; //double random number
	double *posScore = gibbs->posScore; //positional odds-ratio
	//current model over random model
	//NOTE(review): assumes countmat->span <= MAX_MOTIF_WIDTH and
	//gibbs->numalphas <= NUMALPHAS -- confirm against the callers
	double scoremat[MAX_MOTIF_WIDTH][NUMALPHAS];

	//does not use temperature under TRIAL mode
	boolean useTemperature = (mode == TRIAL ? FALSE : gibbs->useTemperature);

	for(i = 0; i < gibbs->numseqs; i++) {
		//remove site and remove site count from count-matrix
		removeSite(i, sites, countmat, gibbs->data);

		//numsites = numseqs-1 because one site is removed
		for(m = 0; m < countmat->span; m++) {
			for(a = 0; a < gibbs->numalphas; a++) {
				scoremat[m][a] = ((countmat->mat[m][a] + gibbs->data->pseudocount[a])
					/ (gibbs->numseqs-1 + gibbs->data->sumOfPseudocount)) 
					/ gibbs->data->bgfreq[a];
			}
		}

		if(useTemperature) {
			for(m = 0; m < countmat->span; m++) {
				for(a = 0; a < gibbs->numalphas; a++) {
					scoremat[m][a] = pow(scoremat[m][a], gibbs->recipTemp);
				}
			}
		}

		//Computed once per sequence instead of in every loop condition.
		//NOTE(review): if a sequence is shorter than the span, numpos <= 0 and
		//the fallback site below is negative -- presumably the dataset
		//guarantees seqLen[i] >= span; confirm.
		numpos = gibbs->data->seqLen[i] - countmat->span + 1;

		//score of a position = product of the per-column odds ratios
		sum = 0.0;
		for(j = 0; j < numpos; j++) {
			//NOTE(review): assumes every symbol code is a valid column index
			//of scoremat -- confirm
			posScore[j] = 1.0;
			for(m = 0; m < countmat->span; m++) {
				int alpha = (int)gibbs->data->seqs[i][j+m];
				posScore[j] *= scoremat[m][alpha];
			}
			sum += posScore[j];
		}

		//Sampling step: draw a point in [0, sum] and pick the first position
		//whose cumulative score reaches it
		drand = Random() * sum;
		sum = 0.0;
		//if drand is a epsilon bigger than sum, then uses the last site
		newsite = numpos - 1; 
		for(j = 0; j < numpos; j++) {
			sum += posScore[j];
			if(drand <= sum) {
				newsite = j;
				break;
			}
		}

		//add new site and add site count to count-matrix
		addSite(newsite, i, sites, countmat, gibbs->data);
	}

	if(mode == NORMAL) {
		entropy = computeEntropyFromCount(countmat, gibbs->data);
	}
	return entropy;
}

/**
* One run of Gibbs-sampling. A run is given a set of starting positions (sites).
* It performs gibbs->trialIters TRIAL iterations, then NORMAL iterations until
* the entropy has failed to improve for gibbs->iterPlateauLen consecutive
* iterations.
* 
* param: 
*   rnode - run-node that contains the new starting positions (sites). 
*           Whenever a NORMAL iteration improves the entropy, the score, 
*           count matrix and sites of that iteration are stored in rnode.
*   initspan - initial span of the run
*   gibbs - sampler state and parameters; gibbs->totalIters is increased by
*           the number of iterations (trial + normal) performed
*
* Terminates the program with EXIT_FAILURE if the working copy of the sites
* cannot be allocated.
**/
static
void runIters(RunNode *rnode, int initspan, Gibbs *gibbs) {
	int iterplat = 0; //consecutive NORMAL iterations without improvement
	int iters = 0; //total iterations performed in this run
	int i;
	double entropy;
	double maxval = -DBL_MAX; //best entropy seen so far in this run

	Profile *countmat = initProfile(initspan, gibbs->span);

	//working copy of the sites; rnode->sites is reserved for the best ones
	int *sites = malloc(gibbs->numseqs * sizeof *sites);
	//malloc(0) may legitimately return NULL, so only a non-empty request
	//that fails counts as out-of-memory
	if(sites == NULL && gibbs->numseqs > 0) {
		fprintf(stderr, "runIters: out of memory\n");
		nilProfile(countmat);
		exit(EXIT_FAILURE);
	}
	for(i = 0; i < gibbs->numseqs; i++) {
		sites[i] = rnode->sites[i];
	}

	//the best are stored in rnode
	Profile *bestcountmat = rnode->countmat;
	int *bestsites = rnode->sites;

	updateCountmatFromSites(countmat, sites, gibbs->data);

	//Trial samplings
	for(i = 0; i < gibbs->trialIters; i++) {
		updateOneIter(countmat, sites, gibbs, TRIAL);
		iters++;
	}

	//Normal samplings (the "real" samplings)
	while(iterplat < gibbs->iterPlateauLen) {
		if(DEBUG1) {
			fprintf(stderr, "\nRun %d, iter %d, span %d\n", rnode->runId, 
				iters, countmat->span);
		}

		//sampling section
		entropy = updateOneIter(countmat, sites, gibbs, NORMAL);
		if(entropy > maxval) {
			//new best: snapshot it and restart the plateau counter
			maxval = entropy;
			rnode->score = entropy;
			copyProfile(bestcountmat, countmat);
			for(i = 0; i < gibbs->numseqs; i++) {
				bestsites[i] = sites[i];
			}
			iterplat = 0;
		}
		else {
			iterplat++;
		}

		if(DEBUG1) {
			fprintf(stderr, "entropy %.4lf\n", entropy);
			for(i = 0; i < gibbs->numseqs; i++) {
				fprintf(stderr, "sites[%d] %d\n", i, sites[i]);
			}
			printProfile(stderr, countmat);
		}

		//phase shift (applied to the working state only, after the snapshot)
		if(Random() < gibbs->phaseShiftFreq) {
			attemptPhaseShift(countmat, sites, gibbs->data);
		}

		iters++; 
	}

	gibbs->totalIters += iters;

	free(sites);
	nilProfile(countmat);
}

/**
* Main engine of Gibbs-sampling. Performs gibbs->numruns independent runs from
* random starting sites, prepends each run-node to gibbs->runset, refines
* every run with EM and rescores it with computeIlrFromCount, then picks the
* run with the largest score.
* 
* param:
*   gibbs - sampler state and parameters; gibbs->runset receives the run-nodes
* return: 
*   returns the run-node of the best run. The node stays linked in 
*   gibbs->runset. Returns NULL when no node has a score above -DBL_MAX 
*   (e.g. the run list is empty).
**/
RunNode* runGibbs(Gibbs *gibbs) {
	int r;

	for(r = 0; r < gibbs->numruns; r++) {
		//initspan and maxspan are the same in GibbsILR 
		int initspan = gibbs->span;
		int maxspan = gibbs->span;
		RunNode *rnode = createRunNode(r, gibbs->data, initspan, maxspan); 

		//new set of positions (sites) by random
		setRandomSites(rnode->sites, initspan, gibbs->data);
		runIters(rnode, initspan, gibbs);
		
		//add to the list of gibbs->runset (prepend)
		//NOTE(review): the first node keeps whatever `next` createRunNode
		//gave it -- presumably NULL; confirm
		if(gibbs->runset->len > 0) {
			rnode->next = gibbs->runset->head;
		}
		gibbs->runset->head = rnode;
		gibbs->runset->len++;
	}

	if(DEBUG0) {
		assert(gibbs->runset->len == gibbs->numruns);
	}
	
	RunNode *node; 

	//EM 
	for(node = gibbs->runset->head; node!= NULL; node = node->next) {
		if(DEBUG3) {
			fprintf(stderr, "runid: %03d\n", node->runId);
			fprintf(stderr,"Before EM:\n");
			printCountmat(stderr, node->countmat);
		}

		copyProfile(node->pswm, node->countmat); //copy the gaps
		//update the pswm with pseudocount (frequency weight matrix)
		updatePswmFromSites(node->pswm, node->sites, gibbs->data, FALSE);

		runEmSteps(node->pswm, gibbs->data, gibbs->ems, gibbs->emStep, TRUE); 
		findBestSitesFromPswm(node->pswm, node->sites, gibbs->data);
		updateCountmatFromSites(node->countmat, node->sites, gibbs->data);

		//replaces the entropy score stored by runIters
		node->score = computeIlrFromCount(node->countmat, gibbs->data);

		if(DEBUG3) {
			fprintf(stderr,"After EM:\n");
			int i;
			for(i = 0; i < gibbs->numseqs; i++) {
				fprintf(stderr, "sites[%2d] %4d\n", i, node->sites[i]);
			}
			fprintf(stderr,"Score: %.2lf\n", node->score);
			printCountmat(stderr, node->countmat);
			fprintf(stderr,"\n");
		}
	}

	RunNode *bestnode = NULL;
	double maxval = -DBL_MAX; //favors largest values 
	for(node = gibbs->runset->head; node!= NULL; node = node->next) {
		if(maxval < node->score) {
			maxval = node->score;
			bestnode = node;
		}
	}

	if(DEBUG3) {
		fprintf(stderr, "\nBest among all runs:\n");
		//bestnode is NULL when no run scored above -DBL_MAX
		if(bestnode != NULL) {
			fprintf(stderr, "runid: %03d\n", bestnode->runId);
			printCountmat(stderr, bestnode->countmat);
		}
	}

	return bestnode;
}
/**
* Releases a Gibbs structure together with its position-score buffer, run
* set, EM structure and dataset. Passing NULL is a no-op.
*
* param:
*   gibbs - structure to release; must not be used afterwards
**/
void nilGibbs(Gibbs *gibbs ) {
	if(gibbs == NULL) {
		return;
	}

	free(gibbs->posScore);
	nilRunSet(gibbs->runset) ;
	//the EM structure is torn down with the help of the dataset, so it must
	//go before the dataset itself is released
	nilEmStruct(gibbs->ems, gibbs->data);
	nilDataset(gibbs->data);
	free(gibbs);
}
