/*******************************************************************************
 * Denoise.cpp
 *
 * Denoise using gradient descent (in particular L-BFGS)
 * - Based on the 3-clique MRF with coupled derivatives
 *
 * Part of SLRF denoising demo, which implements the model described in:
 * - Yunpeng Li & Daniel P. Huttenlocher, Sparse Long-range Random Field and 
 *   its Application to Image Denoising, ECCV 2008.
 */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <math.h>
#include <string.h>
#include <string>

#include "lbfgs/libs/ap.h"
#include "lbfgs/lbfgs.h"

#include "Utils.h"
#include "SLRF.h"


/* Prior energy of the current image estimate (SLRF model)
   - Forwards the image buffer and its dimensions to SLRFenergy() and 
     returns that function's result
   - imGrad.arr is handed to SLRFenergy() as the gradient buffer 
     (presumably accumulated into, not overwritten -- see SLRF.h)
   - NOTE(review): 1.0 and 255.0 are passed through unchanged; their meaning
     is defined by SLRFenergy()
 */
static double prior_eng(ArrImage<double> &im, ArrImage<double> &imGrad) 
{
    double *gradBuf = imGrad.arr;
    const double energy = SLRFenergy(im.arr, im.width, im.height, im.nBands, 
        1.0, 255.0, gradBuf);
    return energy;
}


/* Data (fidelity) energy and its gradient

   - Gaussian noise model: each pixel contributes
         cc * data_weight * (x - y)^2 / (2 * sigma^2)
     where x is the current estimate (im), y the observed value (band 0 of
     imData) and cc the pixel's entry in clqCounts (number of cliques the 
     pixel belongs to), which scales the data weight
   - The derivative of each term w.r.t. x is ADDED to imGrad, so gradient
     contributions already stored there are kept
   - Returns the energy summed over all pixels (accumulated row by row)
 */
static double data_eng(ArrImage<double> &im, CImageOf<float> &imData,
                       double sigma, double data_weight, 
                       CImageOf<int> &clqCounts, ArrImage<double> &imGrad)
{
    const int nCols = im.width;
    const int nRows = im.height;
    const double base_w = data_weight / (2 * sq(sigma));

    double total = 0;
    for(int row = 0; row < nRows; ++row) {
        // Keep a per-row partial sum and fold it into the total afterwards
        double rowEnergy = 0;
        for(int col = 0; col < nCols; ++col) {
            const double resid = im.Pixel(col, row) - imData.Pixel(col, row, 0);
            const int nClq = clqCounts.Pixel(col, row, 0);
            const double pixWeight = nClq * base_w;
            rowEnergy += pixWeight * sq(resid);
            imGrad.Pixel(col, row) += pixWeight * 2 * resid;
        }
        total += rowEnergy;
    }
    return total;
}


/* Initialize the cliques-each-pixel-belongs-to counters

   - clqCounts is reallocated to width x height x 1 and zeroed
   - A clique is three collinear pixels: a middle pixel and the two pixels
     at distance rad on either side of it, horizontally or vertically
   - rad starts at 1 and doubles at each of the nLevels levels
   - Only cliques lying entirely inside the image are counted; every such
     clique increments the counter of each of its three pixels
 */
static void init_clq_counts(CImageOf<int> &clqCounts, int width, int height,
							int nLevels)
{
    clqCounts.ReAllocate(CShape(width, height, 1));
    clqCounts.FillPixels(0);

    int radius = 1;
    for(int lvl = 0; lvl < nLevels; ++lvl) {
        // orientation 0: horizontal clique, orientation 1: vertical clique
        for(int orient = 0; orient < 2; ++orient) {
            const int offX = (orient == 0) ? radius : 0;
            const int offY = (orient == 0) ? 0 : radius;
            // (cx, cy) walks over every admissible middle pixel
            for(int cy = offY; cy < height - offY; ++cy) {
                for(int cx = offX; cx < width - offX; ++cx) {
                    clqCounts.Pixel(cx, cy, 0)++;
                    clqCounts.Pixel(cx - offX, cy - offY, 0)++;
                    clqCounts.Pixel(cx + offX, cy + offY, 0)++;
                }
            }
        }
        radius *= 2;
    }
}


/* Constants -- fixed for this demo
   - denoise_n_params: not referenced anywhere in the visible part of this 
     file
   - denoise_inital_gs_coeff: multiplied by the noise sigma to obtain the
     parameter given to gaussianSmooth() when denoise() builds the L-BFGS 
     starting point
 */
static const int denoise_n_params = 3;
static const double denoise_inital_gs_coeff = 0.05;


/* Variables -- initialized later
   - Both are assigned by denoise() before the optimizer is started; -1 
     marks "not set yet" (the energy callback asserts denoise_sigma > 0)
 */
static double denoise_sigma = -1;
static double denoise_data_weight = -1;


/* Counter
   - denoise_iter_counter: reset by denoise(), incremented by the 
     per-iteration callback
   - suppress_interloop_output: when true, the per-iteration callback skips
     its progress print
 */
static int denoise_iter_counter;
static bool suppress_interloop_output = false;


/* Pointers to the data that cannot be passed by the L-BFGS interface
   - The energy/gradient callback only receives the parameter vector, so 
     denoise() publishes its other inputs here before calling the optimizer
   - denoise_imOrig_ptr is stored but not read in the visible part of this 
     file
 */
static CImageOf<float> *denoise_imData_ptr = NULL;  // input (noisy) image
static CImageOf<int> *denoise_clqCounts_ptr = NULL;
static CImage *denoise_imOrig_ptr = NULL;


/* Energy/gradient callback handed to the lbfgs lib

   - x: current parameter vector, viewed as a single-band image with the 
     shape of the noisy input (*denoise_imData_ptr)
   - f: receives prior energy + data energy
   - g: zeroed first, then both energy terms add their gradients into it
   - Reads the file-scope denoise_* variables, which must have been set up
     before the optimizer is started
 */
static void denoise_lbfgs_funcgrad(ap::real_1d_array &x, double &f, 
                                   ap::real_1d_array &g)
{
    const CShape shape = denoise_imData_ptr->Shape();
    const int nCols = shape.width;
    const int nRows = shape.height;

    // Image-shaped views over the optimizer's buffers
    ArrImage<double> estimate(nCols, nRows, 1, x.getcontent());
    ArrImage<double> gradient(nCols, nRows, 1, g.getcontent());

    // Both energy terms accumulate into the gradient, so clear it first
    for(int k = 0; k < gradient.length; ++k) {
        gradient.arr[k] = 0;
    }

    assert(denoise_sigma > 0);

    double total = 0;
    total += prior_eng(estimate, gradient);
    total += data_eng(estimate, *denoise_imData_ptr, denoise_sigma, 
        denoise_data_weight, *denoise_clqCounts_ptr, gradient);
    f = total;
}


/* Per-iteration callback for L-BFGS

   - Always advances denoise_iter_counter
   - Unless suppress_interloop_output is set, prints ".<count>" to stdout 
     and flushes so progress is visible immediately
   - The three arguments supplied by the optimizer are ignored
 */
static void denoise_lbfgs_interloop(const ap::real_1d_array &, double, 
                             const ap::real_1d_array &)
{
    denoise_iter_counter += 1;
    if(!suppress_interloop_output) {
        printf(".%d", denoise_iter_counter);
        fflush(stdout);
    }
}


/* Report PSNR/RMSE of a denoised image against a reference 
   (helper for denoise)

   - Converts imOut to pixel type T, compares it with *imOrig using
     signal2noise(), and prints one "PSNR: ... (RMSE: ...)" line
   - Returns the PSNR, or 0 (without printing) if imOrig is NULL
 */
template <class T>
static double denoise_report_psnr(CImageOf<float> &imOut, CImageOf<T> *imOrig)
{
    if(!imOrig) {
        return 0;
    }
    CImageOf<T> imT;
    convertImageType(imOut, imT);
    double rmse = 0;
    const double psnr = signal2noise(*imOrig, imT, &rmse);
    printf("PSNR: %.2f (RMSE: %.2f)\n", psnr, rmse);
    return psnr;
}


/* Energy minimization using gradient descent (L-BFGS)

   - im: noisy input.  If it has more than one band it is converted to 
     YCbCr, each channel is denoised separately (with the noise s.d. scaled
     per channel) and the result is converted back to RGB
   - noise_sigma: standard deviation of the noise
   - imOut: receives the denoised image (may be the same object as im)
   - max_iters: unlimited iterations if == 0
   - verbose: 0 silences the per-iteration progress output
   - imOrig: optional clean reference image, may be NULL
   - Return PSNR if imOrig is given, 0 otherwise
 */
template <class T>
static double denoise(CImageOf<float> &im, double noise_sigma, 
                      CImageOf<float> &imOut, int max_iters, int verbose, 
                      CImageOf<T> *imOrig)
{
    if(im.Shape().nBands != 1) {
        CImageOf<float> imYCbCr[3];
        // scaling factor of s.d. for transformation from RGB to YCbCr
        double sigma_mult[3] = {0.668555159, 0.623063139, 0.657199576};
        convertRGBtoYCbCr(im, imYCbCr[0], imYCbCr[1], imYCbCr[2]);
        const char *chs[3] = {"Y:  ", "Cb: ", "Cr: "};
        for(int i=0; i<3; i++) {
            if(verbose) {
                printf("%s\t", chs[i]);
            }
            // Per-channel pass: no reference image, stats are generated 
            // once for the recombined result below
            denoise(imYCbCr[i], noise_sigma * sigma_mult[i], imYCbCr[i], 
                max_iters, verbose, (CImageOf<T>*)NULL);
        }
        convertYCbCrtoRGB(imYCbCr[0], imYCbCr[1], imYCbCr[2], imOut);
        return denoise_report_psnr(imOut, imOrig);
    }

    // Rough guideline for data term weight: (0.5, 0.06) corresponds to the 
    // results in the paper.  The values are approximately fitted on a subset
    // of the training images.  Smaller values e.g. (0.4, 0.055) makes output 
    // smoother but slightly lower in PSNR.
    const double b = 0.5, k = 0.06;  
    denoise_data_weight = b + k * noise_sigma;

    int width = im.Shape().width, height = im.Shape().height;
    int n = width * height;

    CImageOf<int> clqCounts;
    init_clq_counts(clqCounts, width, height, SLRF_LEVELS);

    // Initialize the vector for LBFGS
    ap::real_1d_array imX_vec;
    imX_vec.setbounds(1, n);
    double *imX_arr = imX_vec.getcontent();
    ArrImage<double> imX(width, height, 1, imX_arr);

    { // low-pass filter input (pre-process)
        CImageOf<float> imGS;
        gaussianSmooth(im, imGS, denoise_inital_gs_coeff * noise_sigma);
        convertCImageToArrImage(imGS, imX);  // LBFGS starting point
    }

    // Set global vars/ptrs
    // Assign the flag in both directions: setting it only when verbose == 0
    // would leave every later verbose call silent after one quiet call.
    suppress_interloop_output = (verbose == 0);
    denoise_imData_ptr = &im;
    denoise_clqCounts_ptr = &clqCounts;
    denoise_imOrig_ptr = imOrig;
    denoise_sigma = noise_sigma;
    denoise_iter_counter = 0;

    // Call optimization routine
    int info = 0;  // return value from optimizer (not inspected here)
    int m = 5;  // LBFGS memory option (3 -- 7 are recommended values)
    double epsg = n / 3e6;  // terminate when gradient magnitude is too small
    double epsf = 0;
    double epsx = 0;
    int maxit = max_iters;
    lbfgsminimize(denoise_lbfgs_funcgrad, n, m, imX_vec, epsg, epsf, epsx, 
        maxit, info, denoise_lbfgs_interloop);
    if(verbose) {
        printf("\n");
    }

    // The callbacks are no longer running; drop the pointers so none of 
    // them is left referring to an object that is about to go out of scope
    // (clqCounts here, or the per-channel images of the multi-band path).
    denoise_imData_ptr = NULL;
    denoise_clqCounts_ptr = NULL;
    denoise_imOrig_ptr = NULL;

    // Generate output and stats
    convertArrImageToCImage(imX, imOut);
    return denoise_report_psnr(imOut, imOrig);
}


/* Main entry point

   - Converts im to a float image, runs denoise() on it and converts the 
     result back to pixel type T in imOut
   - Prints a one-line banner with max_iters first when verbose is non-zero
   - noise_sigma, max_iters, verbose and imOrig are forwarded to denoise() 
     unchanged
   - Returns the value returned by denoise()
 */
template <class T>
double slrf_denoise(CImageOf<T> &im, double noise_sigma, CImageOf<T> &imOut, 
                    int max_iters, int verbose, CImageOf<T> *imOrig)
{
    if(verbose != 0) {
        printf("Processing (max %d iterations)\n", max_iters);
    }

    // Work in float internally
    CImageOf<float> noisyF;
    CImageOf<float> cleanF;
    convertImageType(im, noisyF);

    const double psnr = denoise(noisyF, noise_sigma, cleanF, max_iters, 
        verbose, imOrig);

    convertImageType(cleanF, imOut);
    return psnr;
}




//---------- Instantiation -----------

/* Mentions slrf_denoise<T> so that it gets instantiated for the pixel type
   of the argument.  The same image is used as input, output and reference.
 */
template <class T>
static void __instantiate(CImageOf<T> im) {
    CImageOf<T> *ref = &im;
    slrf_denoise(im, 0.0, im, 0, 0, ref);
}

void ____denoise_gd__instantiate() {
    __instantiate(CByteImage());
}

