

//#include <vector>
//#include <fstream>
//#include <iostream>
#include <cstdlib>
#include <cmath>

//#include "data_type_api.h"
//#include "DenseVector.h"
//#include "Sparm.h"
//#include "Util.h"
#include "common.h"

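// Scale applied to the per-month bias (offset b) weight in compute_survival_pdf
// (w[bias] * B_CONSTANT); a large scale keeps the effective regularization on b small.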
#define B_CONSTANT 100



/**
 * Compute the predicted survival distribution of one example.
 *
 * For each interval i in [1, maxMonth] a linear score is formed from the
 * month-i block of w (starting at offset (i-1)*GetOriginalSizePsi(), via
 * sprod_ns, presumably a sparse dot product from that offset) plus the
 * month-i bias weight stored at maxMonth*GetOriginalSizePsi() + (i-1),
 * scaled by B_CONSTANT.
 *
 * Outcome j in [0, maxMonth] gets the unnormalized log-potential
 *     probs[j] = scores[j+1] + ... + scores[maxMonth]
 * (probs[0] is the sum of all scores, probs[maxMonth] is 0), and the
 * returned vector is the softmax of these potentials.
 *
 * @param w      weight vector: per-month feature blocks followed by maxMonth biases
 * @param x      sparse feature vector of the example
 * @param sparm  model parameters (GetMaxMonth, GetOriginalSizePsi)
 * @return       maxMonth+1 probabilities summing to 1
 */
vector<double> compute_survival_pdf(const DenseVector &w, const SparseVector &x, const Sparm &sparm)
{
	const int maxMonth = sparm.GetMaxMonth();
	const size_t blockSize = sparm.GetOriginalSizePsi();

	// scores[i]: linear score of interval i (index 0 is unused)
	vector<double> scores(maxMonth+1, 0);
	double total = 0;
	for (int i=1; i<maxMonth+1; i++)
	{
		size_t bias_offset = maxMonth*blockSize + (i-1);
		size_t offset = (i-1)*blockSize;
		scores[i] = sprod_ns(w, x, offset) + w[bias_offset]*B_CONSTANT;
		// total is the score of the all-zero sequence (outcome 0)
		total += scores[i];
	}

	// Log-potentials: probs[j] = sum of scores[j+1..maxMonth], obtained by
	// peeling one score off the total at a time.
	vector<double> probs(maxMonth+1, 0);
	probs[0] = total;
	double max_potential = total;
	double running_score = total;
	for (int i=1; i<maxMonth+1; i++)
	{
		running_score -= scores[i]; // flipping signs
		probs[i] = running_score;
		if (probs[i] > max_potential) max_potential = probs[i];
	}

	// log-sum-exp shifted by the largest potential: every exponent is <= 0,
	// so exp() cannot overflow even when cumulative scores are very large
	// in magnitude (the unshifted form could yield inf and all-zero probs).
	double z = 0;
	for (int i=0; i<maxMonth+1; i++)
	{
		z += exp(probs[i] - max_potential);
	}
	const double log_z = max_potential + log(z);

	for (int i=0; i<maxMonth+1; i++)
	{
		probs[i] = exp(probs[i] - log_z);
	}

	return probs;
}


/**
 * Relative absolute error |prediction - truth| / prediction, capped at 1.
 * A prediction that is not strictly positive (including NaN) is assigned
 * the maximal loss of 1.
 */
double rae_loss(double prediction, double truth)
{
	if (!(prediction > 0))
	{
		return 1.0;
	}

	const double relative_error = fabs((prediction - truth) / prediction);
	return (relative_error > 1) ? 1.0 : relative_error;
}


/** Squared error (prediction - truth)^2. */
double l2_loss(double prediction, double truth)
{
	const double diff = prediction - truth;
	return diff * diff;
}


/** Absolute error |prediction - truth|. */
double l1_loss(double prediction, double truth)
{
	const double diff = prediction - truth;
	return fabs(diff);
}


/**
 * Squared difference of log-times: (log(prediction) - log(truth))^2.
 * A value of exactly 0 is given the log-value -1 (i.e. treated as e^-1),
 * which is acceptable for survival time prediction and keeps the loss finite.
 */
double l2_log_loss(double prediction, double truth)
{
	const double log_pred = (prediction == 0) ? -1.0 : log(prediction);
	const double log_true = (truth == 0) ? -1.0 : log(truth);
	const double gap = log_pred - log_true;
	return gap * gap;
}


/**
 * Absolute difference of log-times: |log(prediction) - log(truth)|.
 * A value of exactly 0 is given the log-value -1 (i.e. treated as e^-1),
 * which is acceptable for survival time prediction and keeps the loss finite.
 */
double l1_log_loss(double prediction, double truth)
{
	const double log_pred = (prediction == 0) ? -1.0 : log(prediction);
	const double log_true = (truth == 0) ? -1.0 : log(truth);
	return fabs(log_pred - log_true);
}

/**
 * 0/1 classification loss around a threshold.
 * Returns 0 when prediction and truth fall on the same side of the threshold
 * (both strictly above, or both at/below it) and 1 otherwise. A NaN in either
 * argument satisfies neither side and therefore yields 1.
 */
double class_loss(double prediction, double truth, double threshold)
{
	if (prediction > threshold && truth > threshold)
	{
		return 0;
	}
	if (prediction <= threshold && truth <= threshold)
	{
		return 0;
	}
	return 1;
}


/**
 * Adapt a loaded weight vector to the layout expected by the current sparm.
 *
 * The stored model is read as maxMonth per-month blocks of sourceStride
 * feature weights followed by maxMonth bias weights, where
 * sourceStride = (w->dim() - maxMonth) / maxMonth.
 *  - If the stored blocks are wider than sparm.GetOriginalSizePsi(),
 *    sparm.SetSizePsi(sourceStride) is called and w is returned unchanged.
 *    NOTE(review): compute_survival_pdf indexes w with GetOriginalSizePsi();
 *    whether SetSizePsi keeps those consistent cannot be seen here — confirm.
 *  - Otherwise a new DenseVector of sparm.GetSizePsi() entries is allocated.
 *    Each stored block is copied to the start of the corresponding
 *    (wider or equal) target block, and each bias to the target bias tail.
 *    w is then deleted. Target entries that receive no copy keep whatever the
 *    DenseVector constructor sets (presumably 0 — TODO confirm).
 *
 * @param w      heap-allocated weight vector; ownership is taken (may be deleted)
 * @param sparm  parameters; may be modified
 * @return       the weight vector to use from now on (caller owns it)
 */
DenseVector* remap_weights(DenseVector *w, Sparm &sparm)
{
	const int maxMonth = sparm.GetMaxMonth();
	const size_t targetDim = sparm.GetSizePsi();
	const size_t targetStride = sparm.GetOriginalSizePsi();
	const size_t sourceStride = (w->dim()-maxMonth)/maxMonth;

	// stored model already has wider blocks than the data: keep it as-is
	if (sourceStride > targetStride)
	{
		sparm.SetSizePsi(sourceStride);
		return w;
	}

	// bias weights live after all maxMonth feature blocks
	const size_t sourceBiasBase = maxMonth*sourceStride;
	const size_t targetBiasBase = maxMonth*targetStride;

	DenseVector *remapped = new DenseVector(targetDim);
	for (int month=0; month<maxMonth; month++)
	{
		const size_t src = month*sourceStride;
		const size_t dst = month*targetStride;

		for (size_t j=0; j<sourceStride; j++)
		{
			(*remapped)[dst+j] = (*w)[src+j];
		}
		(*remapped)[targetBiasBase+month] = (*w)[sourceBiasBase+month];
	}

	delete w;
	return remapped;
}


/**
 * Prediction / evaluation driver.
 *
 * Reads parameters from the command line, loads the test examples and the
 * trained weight vector, then for every test example:
 *   1. computes its survival distribution (compute_survival_pdf);
 *   2. picks the quantized time with minimal expected loss under the
 *      configured loss type and prints it to stdout (for loss type "class":
 *      a class label and a score), optionally followed by the survival curve;
 *   3. accumulates evaluation losses against the example's observed time.
 * Summary lines prefixed with '#' are printed at the end.
 *
 * Exits with status 1 if the configured loss type is not recognised.
 */
int main(int argc, char* argv[])
{

	vector<Example> test_sample; 
	Sparm sparm; 

	sparm.ReadParam(argc, argv); 
	test_sample = read_input_examples(sparm.GetInputFile(), sparm); 

	// load weights (w is a raw owning pointer, deleted at the end of main)
	DenseVector *w = LoadWeights(sparm.GetModelFile()); 
	
	const int maxMonth = sparm.GetMaxMonth(); 

	// re-map weights when the stored dimension differs from the one implied by
	// the current parameters; remap_weights may delete w and return a new vector
	if (w->dim()!=sparm.GetSizePsi())
	{
		w = remap_weights(w, sparm); 
	}

	// evaluation accumulators: summed over examples, divided by the sample size at the end
	double avg_l1_loss = 0;
	double avg_l2_loss = 0;
	double avg_rae_loss = 0;
	double avg_l1_log_loss = 0;
	double avg_l2_log_loss = 0; 
	double avg_log_loss = 0; 

	// quantized time points; index j (1..maxMonth) represents time qt_time[j-1]
	vector<double>& qt_time = sparm.GetQuantTime(); 

	// pre-compute contingency table:
	// contingency_table[j][k] = loss of predicting qt_time[j-1] when the truth is qt_time[k-1]
	// (row 0 and column 0 stay 0)
	// NOTE(review): GetLossType() is re-compared for every (j,k) pair although it is
	// loop-invariant; it could be resolved once before the loops.
	vector<vector<double> > contingency_table(maxMonth+1, vector<double>(maxMonth+1,0)); 
	for (int j=1; j<maxMonth+1; j++)
	{
		for (int k=1; k<maxMonth+1; k++)
		{
			if (sparm.GetLossType()=="rae")	{
				contingency_table[j][k] = rae_loss(qt_time[j-1],qt_time[k-1]); 
			} else if (sparm.GetLossType()=="l1") {
				contingency_table[j][k] = l1_loss(qt_time[j-1],qt_time[k-1]); 
			} else if (sparm.GetLossType()=="l2") {
				contingency_table[j][k] = l2_loss(qt_time[j-1],qt_time[k-1]); 
			} else if (sparm.GetLossType()=="l1_log") {
				contingency_table[j][k] = l1_log_loss(qt_time[j-1],qt_time[k-1]); 
			} else if (sparm.GetLossType()=="l2_log") {
				contingency_table[j][k] = l2_log_loss(qt_time[j-1],qt_time[k-1]); 
			} else if (sparm.GetLossType()=="class") {
				double threshold = sparm.GetThreshold(); // classification
				contingency_table[j][k] = class_loss(qt_time[j-1],qt_time[k-1],threshold); 
			} else {
				cout << "Unknown loss type '" << sparm.GetLossType() << "'!" << endl; 
				exit(1); 
			}
		}
	}

	// accuracy and calibration (only used for loss type "class")
	double acc=0; 
	double mse=0; 
	long effective_sample_size=0; 

	// now classify the instances
	// first compute the probability
	for (size_t i=0; i<test_sample.size(); i++)
	{
		// survival_pdf[k], k in [0, maxMonth]: probability of outcome k
		vector<double> survival_pdf = compute_survival_pdf(*w, test_sample[i].m_x.features, sparm); 

		// then use the distribution to optimize for different loss:
		// expected_loss[j] = sum over k >= 1 of survival_pdf[k] * contingency_table[j][k]
		// (note: survival_pdf[0] does not contribute)
		vector<double> expected_loss(maxMonth+1, 0); 
		for (int j=1; j<maxMonth+1; j++)
		{
			for (int k=1; k<maxMonth+1; k++)
			{
				expected_loss[j] += survival_pdf[k]*contingency_table[j][k]; 
			}
		}
		// take min as prediction (strict '<': ties keep the smallest index)
		int best_prediction = 1; 
		for (int j=2; j<maxMonth+1; j++)
		{
			if (expected_loss[j]<expected_loss[best_prediction])
			{
				best_prediction=j; 
			}
		}
		// output prediction to std out: the predicted time, or for "class"
		// a label "0:" / "1:" followed by a score derived from expected_loss.
		// NOTE(review): the label uses '<' threshold, whereas class_loss groups a
		// value equal to the threshold with the lower class ('<='); a prediction
		// exactly at the threshold is therefore printed as 1 — confirm intent.
		// NOTE(review): the accuracy label below uses 1 for survival time > threshold,
		// so the inline "has not occurred" / "has occurred" wording looks inverted — confirm.
		if (sparm.GetLossType()!="class")
		{
			cout << qt_time[best_prediction-1]; 
		} else { // classification
			if (qt_time[best_prediction-1]<sparm.GetThreshold())
			{
				cout << 0 << ":" << 1-expected_loss[best_prediction]; // event (death) has not occurred yet
			} else {
				cout << 1 << ":" << expected_loss[best_prediction];  // event has occurred
			}
		}
		
		// print probabilities for plotting survival cdf: after each index j the
		// remaining mass 1 - (survival_pdf[0] + ... + survival_pdf[j]) (maxMonth+1 values)
		if (sparm.GetPrintProb()) 
		{
			double surv_prob = 1; 
			for (int j=0; j<maxMonth+1; j++) 
			{
				surv_prob -= survival_pdf[j];
				cout << ", " << surv_prob;
			}
		}
		cout << endl; 

		// accuracy and calibration
		double p=0; 
		if (sparm.GetLossType()=="class")
		{
			// p accumulates survival_pdf[0] plus the mass of every index whose
			// quantized time is strictly below the threshold
			int k=1; 
			p = survival_pdf[0]; 
			while ((k<sparm.GetMaxMonth()+1)&&(qt_time[k-1]<sparm.GetThreshold()))
			{
				p += survival_pdf[k]; 
				k++; 
			}
			// survival prob: remaining mass, i.e. predicted probability of surviving to the threshold
			p = 1-p; 
			// skip examples treated as censored (censored and GetTrainUncensored()==0)
			// whose observed time is below the threshold: their true class is unknown
			if (!(((test_sample[i].m_y.censored)&&(sparm.GetTrainUncensored()==0))&&(test_sample[i].m_y.original_survival_time<sparm.GetThreshold())))
			{
				effective_sample_size++; 
				// correct when truth and p lie on the same side of threshold / 0.5;
				// a product of exactly 0 (time == threshold or p == 0.5) counts as wrong
				if ((test_sample[i].m_y.original_survival_time-sparm.GetThreshold())*(p-0.5)>0)
				{
					acc+=1.0; 
				}
				// squared error between the 0/1 label and p (Brier-style calibration score)
				int label = (test_sample[i].m_y.original_survival_time>sparm.GetThreshold()) ? 1 : 0;
				mse += (label-p)*(label-p); 
			} 
		}

		// a: predicted time, b: observed (possibly censored) time.
		// Examples treated as censored are penalised only when the prediction is at
		// or before the censoring time; predicting later than it incurs no loss.
		double a = qt_time[best_prediction-1]; 
		double b = test_sample[i].m_y.original_survival_time;
		if ((a<=b)||(!((test_sample[i].m_y.censored)&&(sparm.GetTrainUncensored()==0)))) {
			avg_l1_loss += l1_loss(a,b); 
			avg_l2_loss += l2_loss(a,b); 
			avg_rae_loss += rae_loss(a,b); 
			avg_l1_log_loss += l1_log_loss(a,b); 
			avg_l2_log_loss += l2_log_loss(a,b); 
		} // else the loss is 0; do nothing

		// negative log-likelihood: censored -> -log P(outcome >= event_time),
		// uncensored -> -log P(outcome == event_time). Unlike the losses above this
		// uses m_y.censored alone (GetTrainUncensored() is not consulted).
		// NOTE(review): assumes event_time is a valid index in [0, maxMonth]; not checked here.
		if (test_sample[i].m_y.censored)
		{
			double sum=0; 
			for (int j=test_sample[i].m_y.event_time; j<maxMonth+1; j++)
			{
				sum += survival_pdf[j];
			}
			avg_log_loss -= log(sum); 
		} else {
			avg_log_loss -= log(survival_pdf[test_sample[i].m_y.event_time]); 
		}
	}

	// NOTE(review): with no effective samples (or an empty test set) the
	// floating-point divisions below print nan/inf.
	if (sparm.GetLossType()=="class")
	{
		cout << "#avg acc at threshold " << sparm.GetThreshold() << ": " << acc/effective_sample_size << endl; 
		cout << "#avg mse at threshold " << sparm.GetThreshold() << ": " << mse/effective_sample_size << endl; 
	}
	// output summary 
	cout << "#avg l1-loss: " << avg_l1_loss/test_sample.size() << endl; 
	cout << "#avg l2-loss: " << avg_l2_loss/test_sample.size() << endl; 
	cout << "#avg rae-loss: " << avg_rae_loss/test_sample.size() << endl; 
	cout << "#avg l1-log-loss: " << avg_l1_log_loss/test_sample.size() << endl; 
	cout << "#avg l2-log-loss: " << avg_l2_log_loss/test_sample.size() << endl; 
	cout << "#avg log-likelihood loss: " << avg_log_loss/test_sample.size() << endl; 


	delete w;

	return(0);
}
