#include "gibbs_util.h"


//----------------------------------------------------------------------
// RunNode/RunSet
//----------------------------------------------------------------------
/**
 * Allocate a RunNode for one sampler run.
 *
 * The node gets a count matrix and a PSWM, both created with
 * initProfile(initspan, maxspan), and a sites array with one int per sequence
 * in `data`. The node is not linked anywhere (next == NULL). Release it with
 * nilRunNode().
 *
 * param:
 *   runId    - identifier stored in the node
 *   data     - dataset; only data->numseqs is read (sizes the sites array)
 *   initspan - initial motif span passed to initProfile()
 *   maxspan  - maximum motif span passed to initProfile()
 *
 * If the node or its sites array cannot be allocated, an error is printed on
 * stderr and the program exits, matching the error handling used elsewhere
 * in this file.
 */
RunNode* createRunNode(int runId, Dataset *data, int initspan, int maxspan) {
	RunNode *rnode = malloc(sizeof *rnode);
	if (rnode == NULL) {
		fprintf(stderr, "Error: out of memory at createRunNode().\n");
		exit(1);
	}
	rnode->countmat = initProfile(initspan, maxspan);
	rnode->pswm = initProfile(initspan, maxspan);
	rnode->next = NULL;
	rnode->runId = runId;
	rnode->sites = malloc(data->numseqs * sizeof *rnode->sites);
	//malloc(0) may legitimately return NULL, so only treat NULL as failure
	//when something was actually requested
	if (rnode->sites == NULL && data->numseqs > 0) {
		fprintf(stderr, "Error: out of memory at createRunNode().\n");
		exit(1);
	}
	return rnode;
}

/**
 * Free a RunNode created by createRunNode(): its sites array, its count
 * matrix, its PSWM and the node itself. Passing NULL is a no-op.
 *
 * rnode->next is not followed; unlinking is the caller's job.
 * NOTE(review): assumes no caller frees or replaces rnode->pswm on its own;
 * confirm this, otherwise this would double-free.
 */
void nilRunNode(RunNode *rnode) {
	if (rnode == NULL) {
		return;
	}
	free(rnode->sites);
	nilProfile(rnode->countmat);
	//createRunNode() allocates pswm alongside countmat, so it must be
	//released here as well
	nilProfile(rnode->pswm);
	free(rnode);
}

/**
 * Allocate an empty RunSet (head == NULL, len == 0). Release it with
 * nilRunSet().
 *
 * If the allocation fails, an error is printed on stderr and the program
 * exits, matching the error handling used elsewhere in this file.
 */
RunSet* createRunSet(void) {
	RunSet *rset = malloc(sizeof *rset);
	if (rset == NULL) {
		fprintf(stderr, "Error: out of memory at createRunSet().\n");
		exit(1);
	}
	rset->head = NULL;
	rset->len = 0;
	return rset;
}

/**
 * Free a RunSet and every RunNode reachable from its head.
 * Nodes are unlinked from the front one at a time and released with
 * nilRunNode(); the set itself is freed last.
 */
void nilRunSet(RunSet *rset) {
	RunNode *node;
	for (node = rset->head; node != NULL; node = rset->head) {
		rset->head = node->next; //unlink before the node is freed
		nilRunNode(node);
	}
	free(rset);
}


//----------------------------------------------------------------------
// Scoring functions
//----------------------------------------------------------------------

/**
 * Compute the ILR score of a position-specific weight matrix.
 * This scoring method appears in Ng P., et al. "Apples to apples: ..."
 *
 * For each sequence i there are W_i = seqLen[i] - span + 1 windows. For every
 * window the product of odds ratios pswm[m][x] / bgfreq[x] is taken over the
 * span columns, and the log of the mean product is added to the total:
 *   ILR = sum_i [ log( sum_j prod_m pswm[m][s_i(j+m)] / bg[s_i(j+m)] ) - log(W_i) ]
 *
 * param:
 *   pswm - probability matrix (span columns x numalphas letters)
 *   data - sequences, their lengths and background frequencies
 */
double computeIlrFromPswm(Profile *pswm, Dataset *data) {
	double ratio[MAX_MOTIF_WIDTH][NUMALPHAS];
	double total = 0.0;
	int width = pswm->span;
	int col, sym, seq, pos;

	//odds ratio of each letter at each motif column vs. the background
	for (col = 0; col < width; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			ratio[col][sym] = pswm->mat[col][sym] / data->bgfreq[sym];
		}
	}

	for (seq = 0; seq < data->numseqs; seq++) {
		int windows = data->seqLen[seq] - width + 1;
		double windowSum = 0.0;
		for (pos = 0; pos < windows; pos++) {
			double product = 1.0;
			for (col = 0; col < width; col++) {
				product *= ratio[col][(int)data->seqs[seq][pos + col]];
			}
			windowSum += product;
		}
		total += log(windowSum) - log(windows);
	}

	return total;
}

/**
 * Compute the ILR score of a count matrix (WITHOUT pseudocount).
 * This scoring method appears in Ng P., et al. "Apples to apples: ..."
 *
 * This is the same computation as computeIlrFromPswm(), except that the odds
 * ratio of letter x at column m is countmat[m][x] / (numseqs * bgfreq[x]).
 * In other words, the raw counts are normalized by the number of sequences
 * and no pseudocount is added.
 *
 * param:
 *   countmat - count matrix (span columns x numalphas letters)
 *   data     - sequences, their lengths and background frequencies
 */
double computeIlrFromCount(Profile *countmat, Dataset *data) {
	double ratio[MAX_MOTIF_WIDTH][NUMALPHAS];
	double total = 0.0;
	int width = countmat->span;
	int col, sym, seq, pos;

	for (col = 0; col < width; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			ratio[col][sym] = countmat->mat[col][sym] / (data->numseqs * data->bgfreq[sym]);
		}
	}

	for (seq = 0; seq < data->numseqs; seq++) {
		int windows = data->seqLen[seq] - width + 1;
		double windowSum = 0.0;
		for (pos = 0; pos < windows; pos++) {
			double product = 1.0;
			for (col = 0; col < width; col++) {
				product *= ratio[col][(int)data->seqs[seq][pos + col]];
			}
			windowSum += product;
		}
		total += log(windowSum) - log(windows);
	}

	return total;
}



/**
 * Compute the "real" entropy from the motif-finding literature for a count
 * matrix:
 *   sum over columns m and letters a with c = countmat[m][a] > 0.5 of
 *   c * log(c / (numseqs * bgfreq[a]))
 * Counts are compared against 0.5 rather than 0 to mean "non-zero".
 */
double computeEntropyFromCount(Profile *countmat, Dataset *data) {
	double total = 0.0;
	int col, sym;

	for (col = 0; col < countmat->span; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			double c = countmat->mat[col][sym];
			if (c > 0.5) {
				total += c * log(c / (data->numseqs * data->bgfreq[sym]));
			}
		}
	}
	return total;
}

/**
 * Compute the KL-divergence (relative entropy) of a PSWM from the background:
 *   sum over columns m and letters a of p * log(p / bgfreq[a]),  p = pswm[m][a]
 *
 * Entries with p == 0 contribute 0, following the usual 0 * log 0 = 0
 * convention. Without this guard, a PSWM built without pseudocounts (see
 * updatePswmFromSites() with usePseudocount == FALSE) gives
 * 0 * log(0) = 0 * -inf = NaN, and the whole score becomes NaN.
 */
double computeEntropyFromPswm(Profile *pswm, Dataset *data) {
	int i,j;
	double entropy = 0.0;
	for (i = 0; i < pswm->span; i++) {
		for (j = 0; j < data->numalphas; j++) {
			double p = pswm->mat[i][j];
			if (p > 0.0) {
				entropy += p * log(p / data->bgfreq[j]);
			}
		}
	}
	return entropy;
}


/**
 * Entropy contribution of a single column of letter counts, using the same
 * formula as computeEntropyFromCount():
 *   sum over letters a with count[a] > 0.5 of count[a] * log(count[a] / (numseqs * bgfreq[a]))
 *
 * param:
 *   count - letter counts for one column (at least data->numalphas entries)
 *   data  - provides numalphas, numseqs and the background frequencies
 */
double computeEntropyOneColumn(double count[NUMALPHAS], Dataset *data) {
	double total = 0.0;
	int sym = 0;

	while (sym < data->numalphas) {
		//counts in this file are built from += 1.0 steps, so > 0.5 means "non-zero"
		if (count[sym] > 0.5) {
			total += count[sym] * log(count[sym] / (data->numseqs * data->bgfreq[sym]));
		}
		sym++;
	}
	return total;
}


/**
 * Printable name of a score metric: "ILR" for ILR, "Entropy" for ENTROPY,
 * and "ERROR" for any other value. The returned pointer refers to a string
 * literal and must not be modified or freed.
 */
char* scoreMetricToStr(enum ScoreMetric metric) {
	switch (metric) {
	case ILR:
		return "ILR";
	case ENTROPY:
		return "Entropy";
	default:
		return "ERROR";
	}
}

//----------------------------------------------------------------------
// Matrix/Profile updates
//----------------------------------------------------------------------

/**
 * Rebuild countmat from scratch so that countmat->mat[m][a] is the number of
 * sequences i whose letter at position sites[i] + m is a.
 *
 * When DEBUG0 is set, the function checks that every site lies fully inside
 * its sequence before counting. After counting, it checks that each column
 * sums to numseqs and re-checks the sites. Any failure prints an error and
 * exits.
 *
 * Warning: Not designed for profiles with gaps
 */
void updateCountmatFromSites(Profile *countmat, int *sites, Dataset *data) {
	int seq, col, sym;
	int width = countmat->span;

	if (DEBUG0) {
		for (seq = 0; seq < data->numseqs; seq++) {
			if (sites[seq] >= 0 && sites[seq] + width - 1 < data->seqLen[seq]) {
				continue;
			}
			fprintf(stderr, "Error: sites out of range at updateCountMat.\n");
			fprintf(stderr, "sites[%d] %d\n", seq, sites[seq]);
			exit(1);
		}
	}

	//clear the active span x numalphas region, then tally one letter per
	//sequence per column
	for (col = 0; col < width; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			countmat->mat[col][sym] = 0.0;
		}
	}
	for (seq = 0; seq < data->numseqs; seq++) {
		for (col = 0; col < width; col++) {
			int letter = (int)data->seqs[seq][sites[seq] + col];
			countmat->mat[col][letter] += 1.0;
		}
	}

	if (DEBUG0) {
		//every column must account for exactly one letter per sequence
		for (col = 0; col < width; col++) {
			double colTotal = 0.0;
			for (sym = 0; sym < data->numalphas; sym++) {
				colTotal += countmat->mat[col][sym];
			}
			if (fabs(colTotal - data->numseqs) > 0.000001) {
				fprintf(stderr, "Error: countmat inconsistent at updateCountmat().\n");
				exit(1);
			}
		}

		//same range condition as the pre-check at the top of this function
		for (seq = 0; seq < data->numseqs; seq++) {
			if (!(sites[seq] >= 0 && sites[seq] + width - 1 < data->seqLen[seq])) {
				fprintf(stderr, "Error: sites out of range at updateCountMat.\n");
				exit(1);
			}
		}
	}
}

/**
 * Fill pswm with the position-specific letter frequencies of the aligned
 * sites:
 *   pswm[m][a] = (count[m][a] + pc[a]) / (numseqs + sumOfPc)
 * Here count comes from updateCountmatFromSites(). pc / sumOfPc are
 * data->pseudocount / data->sumOfPseudocount when usePseudocount is set, and
 * zero otherwise.
 *
 * May need to run CopyProfile() before using this function to get the correct span for pswm
 * Warning: Not designed for profiles with gaps
 */
void updatePswmFromSites(Profile *pswm, int *sites, Dataset *data, boolean usePseudocount) {
	double pseudocount[NUMALPHAS];
	double sumOfPseudocount;
	double denominator;
	int col, sym;

	//raw counts are written into pswm->mat first, then normalized in place
	updateCountmatFromSites(pswm, sites, data);

	for (sym = 0; sym < data->numalphas; sym++) {
		pseudocount[sym] = usePseudocount ? data->pseudocount[sym] : 0.0;
	}
	sumOfPseudocount = usePseudocount ? data->sumOfPseudocount : 0.0;
	denominator = data->numseqs + sumOfPseudocount;

	for (col = 0; col < pswm->span; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			pswm->mat[col][sym] = (pswm->mat[col][sym] + pseudocount[sym]) / denominator;
		}
	}
}

/**
 * Check that countmat matches the letter counts implied by sites.
 * Returns TRUE when every entry in the span x numalphas region agrees within
 * 0.0001, and FALSE otherwise. The sites are not range-checked.
 *
 * Warning: Not designed for profiles with gaps
 */
boolean validCountmatWithSites(Profile *countmat, int *sites, Dataset *data) {
	int expected[MAX_MOTIF_WIDTH][NUMALPHAS] = {{0}};
	int seq, col, sym;

	for (seq = 0; seq < data->numseqs; seq++) {
		for (col = 0; col < countmat->span; col++) {
			expected[col][(int)data->seqs[seq][sites[seq] + col]]++;
		}
	}

	for (col = 0; col < countmat->span; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			double diff = (double)expected[col][sym] - countmat->mat[col][sym];
			if (fabs(diff) > 0.0001) {
				return FALSE;
			}
		}
	}
	return TRUE;
}

//----------------------------------------------------------------------
// add/remove/set sites
//----------------------------------------------------------------------

/**
 * Draw a random initial site for every sequence.
 *
 * sites[i] = (int)(Random() * range), where range = seqLen[i] - initspan + 1
 * is the number of start offsets at which a motif of width initspan fits
 * inside sequence i.
 * NOTE(review): assumes Random() returns values in [0, 1). When DEBUG0 is
 * set, a draw outside [0, range) prints an error and exits.
 */
void setRandomSites(int *sites, int initspan, Dataset *data) {
	int i;
	for (i = 0; i < data->numseqs; i++) {
		int range = (data->seqLen[i] - initspan + 1);
		sites[i] = (int)(Random() * range);

		if (DEBUG0) {
			if (sites[i] < 0 || sites[i] >= range) {
				//newline-terminated like every other error message in this file
				fprintf(stderr, "Error: random generator error at setRandomSites()\n");
				exit(1);
			}
		}
	}
}

/**
 * Place the motif for sequence seqind at offset newsite. The letters at
 * newsite .. newsite + span - 1 are added to countmat (one count per column),
 * and newsite is recorded in sites[seqind].
 * When DEBUG0 is set, asserts that the sequence currently has no site
 * (EMPTY_SITE).
 */
void addSite(int newsite, int seqind, int *sites, Profile *countmat, Dataset *data) {
	int col;

	if (DEBUG0) {
		assert(sites[seqind] == EMPTY_SITE);
	}

	for (col = 0; col < countmat->span; col++) {
		int letter = (int)data->seqs[seqind][newsite + col];
		countmat->mat[col][letter] += 1.0;
	}
	sites[seqind] = newsite;
}

/**
 * Remove the current motif of sequence seqind. Its letters are subtracted
 * from countmat (one count per column), and sites[seqind] is reset to
 * EMPTY_SITE.
 * When DEBUG0 is set, asserts that the sequence currently has a site.
 */
void removeSite(int seqind, int *sites, Profile *countmat, Dataset *data) {
	int col;
	int start = sites[seqind];

	if (DEBUG0) {
		assert(start != EMPTY_SITE);
	}

	for (col = 0; col < countmat->span; col++) {
		countmat->mat[col][(int)data->seqs[seqind][start + col]] -= 1.0;
	}
	sites[seqind] = EMPTY_SITE;
}

/**
 * For each sequence, set sites[i] to the start offset whose window of width
 * pswm->span has the largest odds-ratio product prod_m pswm[m][x] / bgfreq[x].
 * Ties keep the leftmost offset. A sequence with no window (shorter than the
 * motif) gets -1.
 */
void findBestSitesFromPswm(Profile *pswm, int *sites, Dataset *data) {
	double ratio[MAX_MOTIF_WIDTH][NUMALPHAS];
	int width = pswm->span;
	int seq, pos, col, sym;

	for (col = 0; col < width; col++) {
		for (sym = 0; sym < data->numalphas; sym++) {
			ratio[col][sym] = pswm->mat[col][sym] / data->bgfreq[sym];
		}
	}

	for (seq = 0; seq < data->numseqs; seq++) {
		int bestPos = -1;
		double bestScore = -DBL_MAX;
		int lastStart = data->seqLen[seq] - width;

		for (pos = 0; pos <= lastStart; pos++) {
			double score = 1.0;
			for (col = 0; col < width; col++) {
				score *= ratio[col][(int)data->seqs[seq][pos + col]];
			}
			//strict comparison: an equal score later on does not replace the leftmost best
			if (score > bestScore) {
				bestScore = score;
				bestPos = pos;
			}
		}
		sites[seq] = bestPos;
	}
}


//----------------------------------------------------------------------
// Phase Shifting
//----------------------------------------------------------------------
//Direction of a proposed phase shift of the motif window
enum PhaseShift {
	NO_SHIFT, //proposal rejected: nothing moves
	LEFT,     //every site moves one position toward the sequence start
	RIGHT     //every site moves one position toward the sequence end
};

/**
* Uses Metropolis algorithm on entropy of the shifted column
*
* A direction (left or right) is picked with equal probability. The column
* that would enter the motif is then compared with the column that would
* leave it. The move is accepted when Random() < exp(newEntropy - oldEntropy).
* If the entering column would fall outside any sequence, the function returns
* without changing anything.
*
* param:
*   countmat - count matrix to apply phase shift (updated with
*              shiftProfileLeft()/shiftProfileRight() on acceptance)
*   sites    - sites to apply phase shift (all -1 for left, +1 for right
*              on acceptance)
*   data     - sequences, lengths and background used for the entropy
**/
void attemptPhaseShift(Profile *countmat, int *sites, Dataset *data) {
	double entering[NUMALPHAS]; //letter counts of the column that would be added
	double *leaving;            //existing countmat column that would be dropped
	double enterEnt, leaveEnt;
	int seq, sym;
	enum PhaseShift shift;

	//equal chance for left or right
	shift = (Random() < 0.5) ? LEFT : RIGHT;

	for (sym = 0; sym < NUMALPHAS; sym++) {
		entering[sym] = 0.0;
	}

	for (seq = 0; seq < data->numseqs; seq++) {
		int pos;
		int fits;
		if (shift == LEFT) {
			pos = sites[seq] - 1;                    //position just before the motif
			fits = (pos >= 0);
		} else {
			pos = sites[seq] + countmat->span;       //position just after the motif
			fits = (pos < data->seqLen[seq]);
		}
		if (!fits) {
			//out of bound -- cannot shift in this direction
			return;
		}
		entering[(int)data->seqs[seq][pos]] += 1.0;
	}

	//a left shift drops the last column; a right shift drops the first
	leaving = (shift == LEFT) ? countmat->mat[countmat->span - 1] : countmat->mat[0];

	enterEnt = computeEntropyOneColumn(entering, data);
	leaveEnt = computeEntropyOneColumn(leaving, data);

	//Metropolis acceptance: when enterEnt >= leaveEnt, exp(...) >= 1, so the
	//move is always taken (assuming Random() < 1) [could introduce a temperature here]
	if (!(Random() < exp(enterEnt - leaveEnt))) {
		shift = NO_SHIFT;
	}

	if (shift == LEFT) {
		shiftProfileLeft(countmat, entering);
	}
	else if (shift == RIGHT) {
		shiftProfileRight(countmat, entering);
	}

	if (shift != NO_SHIFT) {
		int step = (shift == LEFT) ? -1 : +1;
		for (seq = 0; seq < data->numseqs; seq++) {
			sites[seq] += step;
		}
	}

	if (DEBUG1) {
		fprintf(stderr, "shift with probability: %lf\n", exp(enterEnt - leaveEnt));
	}
}


