/*******************************************************************************
 * SLRF.cpp
 *
 * An implementation of the Sparse Long-range Random Field (SLRF) prior.
 * The model is described in:
 *
 *   "Sparse Long-range Random Field and its Application to Image Denoising"
 *   by Yunpeng Li and Daniel P. Huttenlocher, in ECCV 2008
 *
 * The functions provided by this interface compute the value and the gradient
 * (i.e. derivative) of the energy function corresponding to SLRF, which can
 * then be used with one's favorite gradient-descent-based optimizer.
 *
 * Yunpeng Li
 */

#include <math.h>
#include <cstddef>
#include <vector>

// Square of x. The argument is evaluated twice, so it must not have side
// effects.
#define SQ(x)	((x) * (x))
// Pixel accessor for an image given as an array of row pointers with
// interleaved channels: the element at column x, row y, channel b, where each
// pixel occupies nb consecutive elements. Expands to an lvalue, so it can be
// read, assigned or used with '+='. Every argument is parenthesized so that
// arbitrary expressions (e.g. a conditional for b) index the intended element.
#define PV(imgrows, nb, x, y, b)	((imgrows)[(y)][(x) * (nb) + (b)])

using namespace std;


// Model parameters of the SLRF prior, stored as one flat array holding a
// (beta_1, beta_2, alpha) triple per level; the triple for level l starts at
// index 3*l. SLRFenergy_nb rescales beta_1 and beta_2 by 255/range and
// multiplies alpha by the caller-supplied weight. The clique radius starts at
// 1 for level 0 and doubles with each level.
// The following values are based on the range of byte-images, i.e. 0..255
// 3 levels
static const int n_levels = 3;
static const double parameters[3*n_levels] = {
	// level 0
    9.747903e-03,  // beta_1
    8.318796e-02,  // beta_2
    8.077487e+00,  // alpha
	// level 1 (same order: beta_1, beta_2, alpha)
    1.562282e-01,
    7.557423e-01,
    5.443404e-01,
	// level 2 (same order: beta_1, beta_2, alpha)
    3.505201e-01,
    4.385837e-01,
    5.828941e-01
};


// Lookup table for a fast approximation of log(1+0.5*x).
// Entry i holds log(1 + 0.5 * i / log1p_table_res), i.e. the table samples
// x in [0, log1p_table_range] on a grid of spacing 1/log1p_table_res.
static const int log1p_table_res = 1000;  // # values within a unit interval
static const int log1p_table_range = 1000;  // range of x
static double *log1p_table = (double*)0;  // initially, not yet instantiated

// Build the lookup table on first use; subsequent calls return immediately.
// The table is never freed. There is no locking here, so the first call must
// not run concurrently with other callers or with quick_log1p().
static void init_log1p_lookup_table() 
{
	if(log1p_table) {
		return;  // assume already initialized if not NULL
	}

	const int n_entries = log1p_table_res*log1p_table_range + 1;
	double *table = new double[n_entries];
	for(int i=0; i<n_entries; i++) {
		table[i] = log(1 + 0.5 * (double)i / log1p_table_res);
	}
	log1p_table = table;  // publish only once every entry is filled in
}

// Approximate log(1+0.5*x) by rounding x to the nearest table grid point.
// Requires init_log1p_lookup_table() to have been called beforehand.
// Anything that does not map to a valid table slot (x beyond the table range,
// negative x, NaN, or values too large to convert to int) is evaluated exactly
// with log() instead. The range check is done in floating point *before* the
// conversion to int, because converting an out-of-range double to int is
// undefined behavior.
inline static double quick_log1p(double x) 
{
	const int last = log1p_table_res * log1p_table_range;  // highest valid index
	const double pos = x * log1p_table_res + 0.5;

	// (int) truncates toward zero, so pos in (-1, last+1) maps to [0, last].
	if(pos > -1.0 && pos < (double)last + 1.0)
		return log1p_table[(int)pos];

	return log(1 + 0.5 * x);
}


/* Channel-count tag types for SLRFenergy_nb, which computes the SLRF energy
   and its gradient (derivative) of a given image.
   The computation is made specific to the number of channels, so as to force
   inlining and loop unrolling/elimination: C1..C4 carry the channel count as
   a compile-time constant, while CV carries it in a run-time static that the
   dispatcher sets before use.

   These are named structs on purpose: standard C++ does not allow static data
   members in unnamed classes, so the 'typedef struct { ... } Name;' form is
   rejected by conforming compilers.
 */
struct C1 { static const int nb = 1; };
struct C2 { static const int nb = 2; };
struct C3 { static const int nb = 3; };
struct C4 { static const int nb = 4; };
struct CV { static int nb; };
int CV::nb;  // shared mutable state: written by SLRFenergy for nb > 4 (or nb < 1)

/* SLRF energy/gradient kernel; the channel count is read from C::nb.

   For every level (clique radius 1, 2, 4, ...) and both orientations
   (horizontal, vertical), each 3-pixel clique (p, q, r) -- q in the middle,
   p and r 'radius' pixels to either side -- contributes, per channel,
       alpha * weight * log(1 + 0.5 * val),
       val = beta_1^2 * (p-r)^2 + beta_2^2 * (p-2q+r)^2
   to the energy (the log is evaluated via the lookup table), and its exact
   partial derivatives w.r.t. p, q and r are *added* to 'grad'.

   Returns the summed energy.
 */
template <class T, class C>
double SLRFenergy_nb(const T **imgrows, int w, int h, double weight, 
					 double range, T **grad)
{
	const int channels = C::nb;
	double total = 0;

	init_log1p_lookup_table();

	int reach = 1;  // distance from the middle pixel to either end of a clique
	for(int level = 0; level < n_levels; ++level, reach *= 2)
	{
		const double *lvl = &parameters[3 * level];
		const double beta1 = lvl[0] * 255 / range;
		const double beta2 = lvl[1] * 255 / range;
		const double beta1sq = SQ(beta1);
		const double beta2sq = SQ(beta2);
		const double scale = lvl[2] * weight;  // alpha * weight

		// orientation 0: horizontal cliques, orientation 1: vertical cliques
		for(int orient = 0; orient < 2; ++orient)
		{
			const int offx = (orient == 0) ? reach : 0;
			const int offy = (orient == 0) ? 0 : reach;

			for(int y = offy; y < h - offy; ++y)
			{
				double rowsum = 0;  // energy is accumulated per row first

				for(int x = offx; x < w - offx; ++x)
				{
					for(int b = 0; b < channels; ++b)
					{
						const double lo  = PV(imgrows, channels, x - offx, y - offy, b);
						const double mid = PV(imgrows, channels, x, y, b);
						const double hi  = PV(imgrows, channels, x + offx, y + offy, b);

						const double first = lo - hi;  // first-difference term
						const double val = beta1sq * SQ(first) + beta2sq * SQ(lo-2*mid+hi);

						rowsum += quick_log1p(val) * scale;

						// d/dval of scale*log(1+0.5*val) is 0.5*scale/(1+0.5*val);
						// the 0.5 cancels against the 2 from differentiating the squares.
						const double second = lo + hi - 2*mid;
						const double damp = scale / (1+0.5*val);

						PV(grad, channels, x, y, b) += (T)((-2 * beta2sq * second) * damp);
						PV(grad, channels, x - offx, y - offy, b) += (T)((beta2sq * second + beta1sq * first) * damp);
						PV(grad, channels, x + offx, y + offy, b) += (T)((beta2sq * second - beta1sq * first) * damp);
					}
				}

				total += rowsum;
			}
		}
	}

	return total;
}


/* Compute the SLRF energy and its gradient (derivative) for an image given
   as an array of row pointers.
   - Image layout is row-major with interleaved channels: row-column-channel
   - 'grad' must already be allocated with the same dimensions as the image

   Input: 
     imgrows -  the image data (array of pointers to the rows)
     w -        image width
     h -        image height
     nb -       number of channels (bands)
     weight -   weighting factor of the SLRF prior, application dependent
     range -    the range of pixel values for the image (e.g. 255 for 8-bit RGB)
   Output: 
     grad -     the gradient of SLRF's energy function is *accumulated* into it
                (i.e. '+='); existing contents are not cleared
     Return value: the value of the energy function

   Channel counts 1..4 are dispatched to kernels specialized at compile time;
   any other count goes through the generic kernel, which reads the count from
   the static CV::nb that is written here.
 */
template <class T>
double SLRFenergy(const T **imgrows, int w, int h, int nb, double weight, 
				  double range, T **grad)
{
	if(nb == 1)
		return SLRFenergy_nb<T,C1>(imgrows, w, h, weight, range, grad);
	if(nb == 2)
		return SLRFenergy_nb<T,C2>(imgrows, w, h, weight, range, grad);
	if(nb == 3)
		return SLRFenergy_nb<T,C3>(imgrows, w, h, weight, range, grad);
	if(nb == 4)
		return SLRFenergy_nb<T,C4>(imgrows, w, h, weight, range, grad);

	// Generic fallback: hand the run-time channel count to the CV tag.
	CV::nb = nb;
	return SLRFenergy_nb<T,CV>(imgrows, w, h, weight, range, grad);
}


/* Call this version if the image data is given as a contiguous array
   (row-major, interleaved channels, w * nb elements per row, no padding).

   Builds temporary row-pointer tables for 'imgarr' and 'grad' and forwards to
   the row-pointer overload; parameters and return value are the same as
   there, and the gradient is likewise *accumulated* into 'grad'.

   The row tables are held in std::vector so they are released even if the
   callee throws, and row offsets are computed in std::ptrdiff_t so that
   y * w * nb cannot overflow 'int' on large images. A non-positive 'h' yields
   empty tables (no allocation of a negative size).
 */
template <class T>
double SLRFenergy(const T *imgarr, int w, int h, int nb, double weight, 
				  double range, T *grad)
{
	const std::size_t nrows = (h > 0) ? static_cast<std::size_t>(h) : 0;
	const std::ptrdiff_t stride = static_cast<std::ptrdiff_t>(w) * nb;

	std::vector<const T*> imgrows(nrows);
	std::vector<T*> gradrows(nrows);

	for(std::size_t y = 0; y < nrows; ++y) {
		const std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(y) * stride;
		imgrows[y] = imgarr + offset;
		gradrows[y] = grad + offset;
	}

	// An empty vector has no element 0 to take the address of.
	const T **imgptr = nrows ? &imgrows[0] : static_cast<const T**>(0);
	T **gradptr = nrows ? &gradrows[0] : static_cast<T**>(0);

	return SLRFenergy(imgptr, w, h, nb, weight, range, gradptr);
}



//---------- Instantiation ------------
// The templates above are defined in this .cpp file, so the element types
// that should be available to other translation units have to be
// instantiated here. Referencing SLRFenergy for a given T is enough for that.

// Reference the contiguous-array overload of SLRFenergy for element type T
// (which in turn references the row-pointer overload and the kernels).
// The call uses an empty image: null buffers and zero width/height/channels.
template <class T>
void __instantiate()
{
	const T *no_image = (const T*)0;
	T *no_grad = (T*)0;
	SLRFenergy(no_image, 0, 0, 0, 0.0, 0.0, no_grad);
}

// Instantiate SLRFenergy for the supported element types: float and double.
void ____SLRF_cpp__instantiate_all()
{
	__instantiate<float>();
	__instantiate<double>();
}
