#include "dataset.h"

/*
 * print_error(str): write str followed by a newline to stderr, then terminate
 * the process with exit status 1.
 *
 * Wrapped in do { ... } while (0) so the expansion is exactly one statement:
 * it needs its trailing ';', is safe in an unbraced if/else, and cannot
 * capture the statement that follows it.
 */
#define print_error(str) \
	do { \
		fprintf(stderr, "%s\n", (str)); \
		exit(1); \
	} while (0)

/*
 * Opens filename for reading and returns the stream.
 * If the file cannot be opened, reports the name on stderr and terminates
 * the program through print_error().
 */
static
FILE *openFile(char *filename)
{
	FILE *stream = fopen(filename, "r");

	if (stream != NULL) {
		return stream;
	}

	fprintf(stderr,"Could not open file \"%s\"\n",filename);
	print_error("File does not exist!\n");
	return stream;
}

/*
 * Maps a nucleotide letter to its numeric code: A=0, C=1, G=2, T=3.
 * Any other character is reported through print_error(), which terminates
 * the program.
 */
static
char alphaToNum(char alpha) {
	static const char letters[] = { 'A', 'C', 'G', 'T' };
	int code;

	for (code = 0; code < 4; code++) {
		if (letters[code] == alpha) {
			return (char) code;
		}
	}
	print_error("ERROR: alphaToNum has incorrect input\n");
}

/*
 * Maps a numeric nucleotide code back to its letter: 0=A, 1=C, 2=G, 3=T.
 * Any other value is reported through print_error(), which terminates
 * the program.
 */
char numToAlpha(char num) {
	static const char letters[] = { 'A', 'C', 'G', 'T' };
	int code = num;

	if (code >= 0 && code < 4) {
		return letters[code];
	}
	print_error("ERROR: numToAlpha has incorrect input\n");
}

/*
 * Tells whether alpha is one of the four upper-case nucleotide letters
 * A, C, G or T. Yields 1 when it is and 0 otherwise.
 */
static
boolean isAlpha(char alpha) {
	static const char letters[] = { 'A', 'C', 'G', 'T' };
	int found = 0;
	int k;

	for (k = 0; k < 4 && !found; k++) {
		found = (alpha == letters[k]);
	}
	return found;
}

/*
 * Reads the residues of one sequence record from fptr into
 * data->seqs[seqInd].
 *
 * Characters are consumed until the next '>' or end of file; '\n' and '\r'
 * are skipped. Every other character is converted with alphaToNum() (which
 * terminates the program on anything other than A, C, G, T), stored, and
 * counted in data->count[] and data->total. A terminating '>' is pushed back
 * onto the stream so the caller reads it next.
 *
 * The program exits with an error if the record holds more than
 * data->seqLen[seqInd] residues, so data->seqs[seqInd] is never written
 * beyond that length.
 *
 * Returns the number of residues stored.
 */
static
int readSingleSeq(FILE *fptr, Dataset *data, int seqInd) {
	int siteInd = 0;
	int num;
	/* int, not char: fgetc() returns an int so that EOF stays distinct from
	 * every valid byte value (a char would confuse byte 0xFF with EOF, or
	 * never match EOF where char is unsigned). */
	int c;
	while((c = fgetc(fptr)) != '>' && c != EOF) {
		if( c == '\n' || c == '\r') {
			continue;
		}
		if(siteInd >= data->seqLen[seqInd]) {
			print_error("ERROR: exceeds seqlen\n");
		}
		num = alphaToNum((char) c);
		data->seqs[seqInd][siteInd] = num;
		data->count[num]++;
		data->total++;
		siteInd++;
	}
	/* Hand the '>' back to the caller; there is nothing to push back at EOF. */
	if(c != EOF) {
		ungetc(c, fptr);
	}

	if(DEBUG0) {
		/* Sanity check: the four per-letter counts must add up to the total. */
		int i;
		int sum = 0;
		for(i = 0; i < 4; i++) {
			sum+=data->count[i];
		}
		if(data->total != sum) {
			printf("Error: data->total != sum in readSingleSeq()\n");
		}
	}

	return siteInd;
}

/*
 * Reads every sequence record from fptr into data->seqs[].
 *
 * A record starts with a '>' header line; the residues that follow the
 * header's '\n' are loaded by readSingleSeq(). Expects data->numseqs,
 * data->numalphas, data->seqLen[], data->seqs[] and data->count to be set up
 * already. Resets data->total and data->count[] before reading, so both end
 * up describing exactly the residues read here.
 *
 * Exits with an error if there are more headers than data->numseqs or if a
 * record's length differs from data->seqLen[].
 */
static 
void readSeqs(FILE *fptr, Dataset *data) {
	/* int, not char: fgetc() returns an int so that EOF stays distinct from
	 * every valid byte value. */
	int c;
	int seqInd = 0;
	int siteInd;
	int a;
	boolean inHeader = FALSE;

	/* readSingleSeq() only increments these, so start both from zero. */
	data->total = 0;
	for(a = 0; a < data->numalphas; a++) {
		data->count[a] = 0;
	}

	while ((c = fgetc(fptr)) != EOF) {
		if( c == '>' ) {
			inHeader = TRUE;
			if(seqInd >= data->numseqs) {
				print_error("ERROR in readSeqs: number of seqs exceed max.");
			}
		}
		else if ( inHeader && c == '\n') {
			/* End of the header line: the residues start here. */
			inHeader = FALSE;
			siteInd = readSingleSeq(fptr, data, seqInd);
			if(siteInd != data->seqLen[seqInd]) {
				print_error("ERROR: length of seqs does not match");
			}
			seqInd++;
		}
	}
}


/*
 * Counts the '>' characters in fptr from the current position to end of
 * file and returns that count, which is used as the number of sequence
 * records. The stream is left at end of file; the caller rewinds it.
 */
static
int getNumOfSeqs(FILE *fptr) {
	/* int, not char: fgetc() returns an int so that EOF stays distinct from
	 * every valid byte value (a char would stop early on byte 0xFF, or never
	 * match EOF where char is unsigned). */
	int c;
	int numHeaders = 0;
	while ((c = fgetc(fptr)) != EOF) {
		if( c == '>' ) {
			numHeaders++;
		}
	}
	return numHeaders;
}

/*
 * Scans fptr and stores the residue count of each record in data->seqLen[].
 *
 * Each '>' starts a new record whose length is reset to 0; after the '\n'
 * that ends the header line, every character accepted by isAlpha()
 * (A, C, G, T) adds one to that record's length. All other characters are
 * ignored.
 *
 * Expects data->numseqs and data->seqLen to be set up already; exits with an
 * error if the input holds more '>' characters than data->numseqs.
 */
static
void getLenOfSeqs(FILE *fptr, Dataset *data) {
	/* int, not char: fgetc() returns an int so that EOF stays distinct from
	 * every valid byte value. */
	int c;
	int seqInd = -1;	/* becomes 0 at the first header */
	boolean inHeader = FALSE;
	boolean inSeq = FALSE;
	while ((c = fgetc(fptr)) != EOF) {
		if( c == '>' ) {
			inHeader = TRUE;
			inSeq = FALSE;
			seqInd++;
			if(seqInd >= data->numseqs) {
				print_error("ERROR in getLenOfSeqs: number of seqs exceed max.");
			}
			data->seqLen[seqInd] = 0;		
		}
		else if ( inHeader && c == '\n') {
			/* Header line finished; what follows belongs to the sequence. */
			inHeader = FALSE;
			inSeq = TRUE;
		}
		else if (inSeq && isAlpha((char) c)) {
			data->seqLen[seqInd]++;	
		}
	}
}

/*
 * Allocates data->pseudocount and fills it from the background frequencies:
 *
 *     pseudocount[a] = numseqs * pseudoweight * bgfreq[a]
 *
 * Also records pseudoweight in data->pseudoweight and the sum of all
 * pseudocounts in data->sumOfPseudocount. Exits with an error if the array
 * cannot be allocated.
 *
 * NOTE: any array already held in data->pseudocount is not released here, so
 * calling this more than once on the same Dataset leaks the earlier array.
 */
void initPseudocount(Dataset *data, double pseudoweight) {
	int a;
	data->pseudocount = malloc(data->numalphas * sizeof *data->pseudocount);
	/* A zero-sized request may legitimately yield NULL; only a failed
	 * non-empty request is an error. */
	if(data->pseudocount == NULL && data->numalphas > 0) {
		print_error("ERROR: out of memory in initPseudocount");
	}
	data->pseudoweight = pseudoweight;
	data->sumOfPseudocount = 0;
	for(a = 0; a<data->numalphas; a++) {
		data->pseudocount[a] = data->numseqs * pseudoweight * data->bgfreq[a];
		data->sumOfPseudocount += data->pseudocount[a];
	}
}


Dataset *openDataset(char *filename, int numalphas)
{
	Dataset *data;
	FILE *fptr;
	int i;

	fptr = openFile(filename);
	data = (Dataset*) malloc(sizeof(Dataset));;

	//number of sequence
	data->numseqs = getNumOfSeqs(fptr);
	rewind(fptr);

	data->numalphas = numalphas;
	data->count = (int*) malloc(data->numalphas * sizeof(int));
	data->seqLen = (int*) malloc( data->numseqs * sizeof(int));
	data->seqs = (char**) malloc(data->numseqs * sizeof(char*));

	//len per seq
	getLenOfSeqs(fptr, data);
	rewind(fptr);
	for(i = 0; i < data->numseqs; i++) {
		data->seqs[i] = (char*) malloc(data->seqLen[i] * sizeof(char));
	}
	
	readSeqs(fptr, data);
	fclose(fptr);

	data->maxSeqLen = 0;
	for(i = 0; i < data->numseqs; i++) {
		if(data->maxSeqLen < data->seqLen[i]) {
			data->maxSeqLen = data->seqLen[i];
		}
	}

	if(DEBUG2) {
		int j;
		for(i = 0; i < data->numseqs; i++) {
			for(j = 0; j < data->seqLen[i]; j++) {
				fprintf(stderr, "%d", data->seqs[i][j]);
			}
			fprintf(stderr, "\n");
		}

	}

	int a;
	data->bgfreq = (double*) malloc(sizeof(double) * data->numalphas);
	for(a = 0; a< data->numalphas; a++) {
		data->bgfreq[a] = data->count[a]/ ((double)data->total);
	}

	return data;
}

/*
 * Writes one "Background freq of <letter>: <value>" line to fptr for each of
 * the data->numalphas entries of data->bgfreq, followed by an empty line.
 * The letter for each index comes from numToAlpha().
 */
void printBackgroundFreq(FILE *fptr, Dataset *data) {
	int alphaInd = 0;

	while (alphaInd < data->numalphas) {
		fprintf(fptr, "Background freq of %c: %.4lf\n", numToAlpha(alphaInd),
			data->bgfreq[alphaInd]);
		alphaInd++;
	}
	fprintf(fptr, "\n");
}

/*
 * Releases a Dataset and every array it owns: pseudocount, count, seqLen,
 * bgfreq, each row of seqs, seqs itself, and finally the struct.
 *
 * Passing NULL is a no-op, mirroring free(). data->pseudocount is freed
 * unconditionally, so it must be either NULL or a pointer obtained from
 * initPseudocount().
 */
void nilDataset(Dataset *data) {
	int i;

	if(data == NULL) {
		return;
	}

	free(data->pseudocount);
	free(data->count);
	free(data->seqLen);
	free(data->bgfreq);
	
	for(i = 0; i< data->numseqs; i++) {
		free(data->seqs[i]);
	}
	free(data->seqs);
	free(data);
}

