
#include <vector>
#include <string>
#include <list>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <cstdlib>
#include <cmath>
#include <cassert>

//#include "data_type_api.h"
//#include "Sparm.h"
//#include "DenseVector.h"
//#include "Util.h"
#include "common.h"

// Sufficient-decrease constant: a trial step alpha is rejected when
// f(alpha) > f(0) + LINE_SEARCH_C1*alpha*f'(0) (see wolfe_line_search() and zoom()).
#define LINE_SEARCH_C1 1E-4
// Curvature constant: a trial step is accepted when
// |f'(alpha)| <= -LINE_SEARCH_C2*f'(0) (see wolfe_line_search() and zoom()).
#define LINE_SEARCH_C2 0.9
// Upper bound on the trial step length while wolfe_line_search() keeps doubling it.
#define ALPHA_MAX 1024
// Maximum number of outer optimization iterations performed by main().
#define MAX_ITER 1000

// small amount of regularization for offset b:
// every bias weight is multiplied by B_CONSTANT when it enters a score, and its
// gradient entry is scaled by the same factor (see compute_gradient_obj()).
#define B_CONSTANT 100

using namespace std;

// Forward declarations: both functions are defined after main() in this file.
// compute_gradient_obj: returns the objective value at w and accumulates its gradient into g.
double compute_gradient_obj(DenseVector &w, const vector<Example> &sample, DenseVector &g, const Sparm &sparm); 
// wolfe_line_search: returns a step length along direction r, starting from w
// where the objective is v and the gradient is g.
double wolfe_line_search(const DenseVector &w, double v, const DenseVector &g, const DenseVector &r, const vector<Example> &sample, const Sparm &sparm);



// Add to g the gradient of the multi-task smoothing term
//     0.5 * C2 * sum_{i=0}^{maxMonth-2} || w_i - w_{i+1} ||^2
// where w_i denotes the block of GetOriginalSizePsi() consecutive weights that
// starts at index i*GetOriginalSizePsi().
//
// w     : current weight vector (only read)
// g     : gradient vector; contributions are added to its first
//         maxMonth*GetOriginalSizePsi() entries
// sparm : supplies GetOriginalSizePsi(), GetMaxMonth() and GetC2()
//
// With fewer than two months there is no pair of neighbouring blocks, so the
// term is identically zero and g is left untouched. (Without this guard the
// boundary code below would index past the month blocks.)
void smoothing_grad(DenseVector &w, DenseVector &g, const Sparm &sparm)
{
	const size_t oriSizePsi = sparm.GetOriginalSizePsi();
	const double C2 = sparm.GetC2(); 
	const int maxMonth = sparm.GetMaxMonth();

	if (maxMonth < 2)
	{
		return;
	}

	// interior months: each block is pulled towards both of its neighbours
	for (int i=1; i<maxMonth-1; i++)
	{
		const size_t cur = i*oriSizePsi; 
		const size_t next = (i+1)*oriSizePsi; 
		const size_t prev = (i-1)*oriSizePsi; 

		for (size_t j=0; j<oriSizePsi; j++)
		{
			g[cur+j] += C2*(w[cur+j] - w[next+j]);
			g[cur+j] += C2*(w[cur+j] - w[prev+j]);
		}
	}

	// boundary case: the first month only has a right-hand neighbour
	const size_t first = 0;
	const size_t second = oriSizePsi;
	for (size_t j=0; j<oriSizePsi; j++)
	{
		g[first+j] += C2*(w[first+j] - w[second+j]);
	}

	// boundary case: the last month only has a left-hand neighbour
	const size_t last = static_cast<size_t>(maxMonth-1)*oriSizePsi;
	const size_t beforeLast = static_cast<size_t>(maxMonth-2)*oriSizePsi;
	for (size_t j=0; j<oriSizePsi; j++)
	{
		g[last+j] += C2*(w[last+j] - w[beforeLast+j]);
	}
}


// Multi-task smoothing penalty (unscaled).
// Walks over every pair of consecutive months and sums the squared differences
// between corresponding entries of their weight blocks; each block holds
// GetOriginalSizePsi() entries. Returns that sum.
double smoothing_obj(DenseVector &w, const Sparm &sparm)
{
	const size_t blockLen = sparm.GetOriginalSizePsi();
	double total = 0;

	for (int month=0; month<sparm.GetMaxMonth()-1; ++month)
	{
		const size_t here = month*blockLen;
		const size_t there = (month+1)*blockLen;

		for (size_t k=0; k<blockLen; ++k)
		{
			total += (w[here+k] - w[there+k])*(w[here+k] - w[there+k]);
		}
	}

	return total;
}



int WriteWeights(const DenseVector &w, string filename)
{
	ofstream outfile(filename.c_str());
	if (outfile) {
		outfile << "DIM:" << w.dim() << std::endl; 
		for (size_t i=1; i<w.dim(); i++)
		{
			outfile << i << ":" << setprecision(8) << w[i] << endl; 
		}
		outfile.close(); 
		return 0;
	} else {
		cerr << "Unable to open model file " << filename << " for output!" << endl;  
		return -1;
	}
}





int main(int argc, char* argv[])
{
	vector<Example> sample; 
	Sparm sparm; 

	sparm.ReadParam(argc, argv); 
	sample = read_input_examples(sparm.GetInputFile(), sparm); 
	
	size_t sizePsi = sparm.GetSizePsi(); 

	// variables for L-BFGS
	const size_t bundle_size = sparm.GetBundleSize(); 
	list<DenseVector*> s_list;
	list<DenseVector*> y_list;
	list<double> rho_list; 
	DenseVector *last_g=NULL, *last_w=NULL; 
	DenseVector *g=NULL, *w=NULL;

	bool use_lbfgs = true; 


	if ((sparm.GetTrainUncensored()==0)&&(sparm.GetInitialWeightFile()!=""))
	{
		w = LoadWeights(sparm.GetInitialWeightFile()); 
		assert(w->dim()==sizePsi); 
	} else {
		w = new DenseVector(sizePsi); 
	}

	int outer_iter=0; 
	double v=1;
	double last_v=1;
	double relative_decrease=1;
	while ((outer_iter<MAX_ITER)&&((outer_iter<1)||(relative_decrease>1E-6)))
	{
		cout << "Iteration: " << outer_iter << endl; 

		delete last_g;
		last_g = g;
		last_v = v; 

		g = new DenseVector(sizePsi); 

		v = compute_gradient_obj(*w, sample, *g, sparm);
		cout << "objective (with regularizer): " << v << endl; 

		if (outer_iter>0) {
			relative_decrease = fabs((v-last_v)/last_v);
		} else {
			relative_decrease = 1; 
		}


		DenseVector *r = new DenseVector(sizePsi);

		if ((!use_lbfgs)||(outer_iter==0))
		{
			// search direction: r = -g, steepest descent
			multadd_nn(*r, *g, -1); 
		} else {
			// update s, y, rho
			s_list.push_back(new DenseVector(*w));
			y_list.push_back(new DenseVector(*g));
			multadd_nn(*(s_list.back()), *last_w, -1.0f); 
			multadd_nn(*(y_list.back()), *last_g, -1.0f);
			double sTy = sprod_nn(*(s_list.back()), *(y_list.back())); 
			double yTy = sprod_nn(*(y_list.back()), *(y_list.back())); 
			rho_list.push_back(1.0f/sTy); 
			double gamma_k = sTy/yTy;  
			if (rho_list.size()>bundle_size)
			{
				rho_list.erase(rho_list.begin());
				delete *(s_list.begin()); 
				s_list.erase(s_list.begin());
				delete *(y_list.begin()); 
				y_list.erase(y_list.begin()); 
			}

			list<double> a; 
			list<double>::const_iterator iter_rho = rho_list.end(); 
			list<DenseVector*>::const_iterator iter_s = s_list.end(); 
			list<DenseVector*>::const_iterator iter_y = y_list.end(); 

			multadd_nn(*r, *g, 1); 
			while (iter_rho!=rho_list.begin())
			{
				iter_rho--;
				iter_s--;
				iter_y--;
  
				double a_i = (*iter_rho)*sprod_nn(*(*iter_s), *r); 
				a.push_front(a_i); 
				multadd_nn(*r, *(*iter_y), -a_i);
			}

			smult_n(*r, 1.0f/gamma_k); 

			iter_rho = rho_list.begin(); 
			iter_s = s_list.begin(); 
			iter_y = y_list.begin(); 
			list<double>::const_iterator iter_a = a.begin(); 

			while (iter_rho!=rho_list.end())
			{
				double b = (*iter_rho)*sprod_nn(*r, *(*iter_y)); 
				multadd_nn(*r, *(*iter_s), (*iter_a)-b); 

				iter_rho++;
				iter_s++;
				iter_y++;
				iter_a++;  
			}

			smult_n(*r, -1.0f); 
		}

		double alpha = wolfe_line_search((const DenseVector &)*w, v, *g, *r, sample, sparm); 

		delete last_w;
		last_w = new DenseVector(*w); // copy before increment
		multadd_nn(*w, *r, alpha); 

		delete r; 

		outer_iter++; 
	}

	// write weights to file
	WriteWeights(*w, sparm.GetModelFile()); 

	// CLEANUP
	for (list<DenseVector*>::iterator iter_s = s_list.begin(); iter_s!=s_list.end(); ++iter_s)
	{
		delete *iter_s;
	}
	for (list<DenseVector*>::iterator iter_y = y_list.begin(); iter_y!=y_list.end(); ++iter_y)
	{
		delete *iter_y;
	}
	delete last_w; 
	delete last_g;
	delete g; 
	delete w; 

}


// Evaluate the regularized training objective at w and accumulate its gradient in g.
//
// Per example, a score is computed for each month 1..maxMonth (sparse feature
// product against that month's weight block plus the month's bias weight times
// B_CONSTANT). From these, maxMonth+1 sequence scores are formed, and the
// objective receives their log-sum-exp minus
//   - the score of sequence `event_time`, or
//   - for censored examples when GetTrainUncensored()==0, the log-sum-exp of
//     the sequence scores from `event_time` onwards.
// The data term is then scaled by GetC1()/sample.size(), after which
// 0.5*<w,w> and 0.5*GetC2()*smoothing_obj() are added (with matching gradients).
//
// w      : weight vector (only read here)
// sample : training examples
// g      : gradient accumulator; it is added to and then rescaled, so it is
//          expected to hold zeros on entry
// sparm  : model / training parameters
// Returns the objective value.
double compute_gradient_obj(DenseVector &w, const vector<Example> &sample, DenseVector &g, const Sparm &sparm)
{
	const int maxMonth = sparm.GetMaxMonth(); 
	const size_t blockLen = sparm.GetOriginalSizePsi();
	const size_t biasBase = maxMonth*blockLen;   // bias weights follow the month blocks
	const bool censoringEnabled = (sparm.GetTrainUncensored()==0);
	double objective = 0;

	for (size_t n=0; n<sample.size(); ++n)
	{
		const Example &ex = sample[n];
		// 0 <= t <= maxMonth; t==maxMonth means no event during the observed period
		const int event_time = ex.m_y.event_time;
		const bool censoredCase = (ex.m_y.censored)&&censoringEnabled;

		// per-month scores; their total is the score of the all-ones sequence
		vector<double> monthScore(maxMonth+1, 0);
		double allOnes = 0; 
		for (int m=1; m<=maxMonth; ++m)
		{
			const size_t featOffset = (m-1)*blockLen; 
			monthScore[m] = sprod_ns(w,ex.m_x.features, featOffset) + w[biasBase + (m-1)]*B_CONSTANT;
			allOnes += monthScore[m]; 
		}

		// there are maxMonth+1 possible sequences; sequence k drops months 1..k
		vector<double> seqScore(maxMonth+1,0); 
		seqScore[0] = allOnes; 
		double peak = seqScore[0]; 
		for (int k=1; k<=maxMonth; ++k)
		{
			seqScore[k] = seqScore[k-1] - monthScore[k]; // flipping signs
			if (seqScore[k]>peak) peak = seqScore[k]; 
		}

		// running (unnormalized) cumulative sums, shifted by the maximum for stability
		double z = 0; 
		vector<double> cumulative(maxMonth+1, 0);
		for (int k=0; k<=maxMonth; ++k)
		{
			z += exp(seqScore[k]-peak); 
			cumulative[k] = z; 
		}
		const double log_z = log(z) + peak; 
		objective += log_z; 

		// for censored targets: same quantities restricted to sequences >= event_time
		double z_h = 0; 
		double log_z_h = 0; 
		vector<double> cumulative_h(maxMonth+1, 0);
		double peak_h = 0; 
		if (censoredCase)
		{
			peak_h = seqScore[event_time]; 
			for (int k=event_time+1; k<=maxMonth; ++k)
			{
				if (seqScore[k]>peak_h) peak_h = seqScore[k]; 
			}
			for (int k=event_time; k<=maxMonth; ++k)
			{
				z_h += exp(seqScore[k]-peak_h); 
				cumulative_h[k] = z_h; 
			}
			log_z_h = log(z_h) + peak_h; 
			objective -= log_z_h; 
		} else {
			objective -= seqScore[event_time]; 
		}

		// gradient with respect to each month's score, pushed onto features and bias
		for (int m=1; m<=maxMonth; ++m) 
		{
			double factor = exp(log(cumulative[m-1])+peak-log_z);
			if (event_time<m)
			{
				if (censoredCase) {
					factor -= exp(log(cumulative_h[m-1])+peak_h-log_z_h);
				} else {
					factor -= 1; 
				}
			}
			const size_t featOffset = (m-1)*blockLen; 
			multadd_ns(g, ex.m_x.features, factor, featOffset);
			g[biasBase + (m-1)] += factor*B_CONSTANT; 
		}
	}

	// average the data term and weight it by C1
	objective = objective*sparm.GetC1()/sample.size(); 
	smult_n(g, sparm.GetC1()/sample.size()); 

	// quadratic regularizer
	objective += 0.5*sprod_nn(w,w); 
	multadd_nn(g, w, 1.0); 

	// smoothing across neighbouring months
	objective += 0.5*sparm.GetC2()*smoothing_obj(w, sparm);
	smoothing_grad(w, g, sparm); 

	return objective; 
}


// Minimizer of the cubic polynomial that interpolates the function values
// f1, f2 and the derivatives g1, g2 at the points a1 and a2.
//
// a1, a2 : the two step lengths (either order)
// f1, f2 : objective values at a1 and a2
// g1, g2 : directional derivatives at a1 and a2
//
// Returns the interpolated step length. When the cubic has no usable minimizer
// (a1 == a2, negative or NaN discriminant, or a zero denominator) the midpoint
// of a1 and a2 is returned instead, so the caller never receives NaN from
// these degenerate configurations.
double cubic_interpolation(double a1, double a2, double f1, double f2, double g1, double g2)
{
	// degenerate interval: nothing to interpolate
	if (a1 == a2) {
		return a1;
	}
	const double midpoint = 0.5*(a1+a2);

	const double d1 = g1 + g2 - 3*(f1-f2)/(a1-a2); 
	const double discriminant = d1*d1-g1*g2;
	// written as !(x >= 0) so that a NaN discriminant also takes the fallback
	if (!(discriminant >= 0)) {
		return midpoint;
	}

	double d2 = std::sqrt(discriminant); 
	if (a1>a2) {
		// d2 carries the sign of (a2 - a1)
		d2 *= -1;
	}

	const double denominator = g2-g1+2*d2;
	if (denominator == 0) {
		return midpoint;
	}
	return (a2 - (a2-a1)*(g2+d2-d1)/denominator);
}


double zoom(double a_lo, double a_hi, double f_lo, double f_hi, double g_lo, double g_hi, const DenseVector &w, const DenseVector &r,
		   const vector<Example> &sample, double v, double suff_decrease_value, const Sparm &sparm)
{
	double alpha_lo = a_lo;
	double alpha_hi = a_hi; 

	int iter=0;
	while (iter<100)
	{
		iter++;

		double alpha_j = cubic_interpolation(alpha_lo, alpha_hi, f_lo, f_hi, g_lo, g_hi); 

		DenseVector trial_pt(w); 
		multadd_nn(trial_pt, r, alpha_j);

		DenseVector trial_gradient(w.dim());
		double trial_obj = compute_gradient_obj(trial_pt, sample, trial_gradient, sparm); 
		double trial_decrease_value = sprod_nn(r, trial_gradient); 

		cout << "alpha_lo: " << alpha_lo << ", alpha_hi: " << alpha_hi << ", alpha_j: " << alpha_j << ", trial obj: " << trial_obj << endl; 

		if ((trial_obj>v+LINE_SEARCH_C1*alpha_j*suff_decrease_value)||(trial_obj>=f_lo))
		{
			alpha_hi = alpha_j; 
			f_hi = trial_obj; 
			g_hi = trial_decrease_value; 
		} else {
			if (fabs(trial_decrease_value)<=-LINE_SEARCH_C2*suff_decrease_value)
			{
				return alpha_j;
			}
			if (trial_decrease_value*(alpha_hi-alpha_lo)>=0) 
			{
				alpha_hi = alpha_lo;
				f_hi = f_lo;
				g_hi = g_lo;
			}
			alpha_lo = alpha_j;
			f_lo = trial_obj;
			g_lo = trial_decrease_value;

		} // end else

	} // end while 

	cout << "zoom() failed!" << endl; 
	exit(1); 
}

// Line search for a step length satisfying the strong Wolfe conditions.
//
// w      : current point
// v      : objective value at w
// g      : gradient at w
// r      : search direction
// sample, sparm : forwarded to compute_gradient_obj()
//
// Starts with step 1 and doubles it (capped at ALPHA_MAX) until a step is
// accepted directly or a bracket is found and handed to zoom().
// Returns the accepted step length. After 100 unsuccessful trials a message is
// printed and the process exits with status 1.
double wolfe_line_search(const DenseVector &w, double v, const DenseVector &g, const DenseVector &r, const vector<Example> &sample, const Sparm &sparm)
{
	// directional derivative at step 0
	const double suff_decrease_value = sprod_nn(r,g); 
	const size_t n = w.dim(); 

	// previous trial (initially step 0)
	double prev_alpha = 0; 
	double prev_obj = v;
	double prev_slope = suff_decrease_value; 

	double cur_alpha = 1;
	for (int iter=1; iter<=100; ++iter)
	{
		DenseVector trial_pt(w); // copy from w
		multadd_nn(trial_pt, r, cur_alpha); 

		DenseVector trial_gradient(n); 

		const double cur_obj = compute_gradient_obj(trial_pt, sample, trial_gradient, sparm); 
		const double cur_slope = sprod_nn(r, trial_gradient); 

		cout << "in wolfe, v: " << v << ", alpha: " << cur_alpha << ", trial_obj: " << cur_obj << endl; 

		// sufficient decrease violated, or the objective went up again: minimum is bracketed
		const bool decrease_violated = (cur_obj>v+LINE_SEARCH_C1*cur_alpha*suff_decrease_value);
		const bool rose_again = (iter>1)&&(cur_obj>prev_obj);
		if (decrease_violated||rose_again)
		{
			return zoom(prev_alpha,cur_alpha,prev_obj,cur_obj,prev_slope,cur_slope,w,r,sample,v,suff_decrease_value,sparm); 
		}

		// curvature condition met: done
		if (fabs(cur_slope)<=-LINE_SEARCH_C2*suff_decrease_value)
		{
			return cur_alpha; 
		}

		// slope turned non-negative: minimum lies between the two trials
		if (cur_slope>=0)
		{
			return zoom(cur_alpha,prev_alpha,cur_obj,prev_obj,cur_slope,prev_slope,w,r,sample,v,suff_decrease_value,sparm);
		}

		prev_alpha = cur_alpha;
		prev_obj = cur_obj; 
		prev_slope = cur_slope;

		cur_alpha = min(2*cur_alpha, (double) ALPHA_MAX);
	}

	cout << "wolfe_line_search() failed! " << endl; 
	exit(1); 
}
