/* 
 *  Copyright (c) 2008  Noah Snavely (snavely (at) cs.washington.edu)
 *    and the University of Washington
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 */

/* colorcorrect.c */
/* Color-correct one image to match another, using pixel correspondences */

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "defines.h"
#include "dmap.h"
#include "image.h"
#include "matrix.h"

/*
 * img_color_correct_cspnd -- return a color-corrected version of b that
 * matches the colors of a.
 *
 * For every entry (x, y) of the correspondence map whose distance is not
 * DBL_MAX, pixel a(x, y) is paired with pixel b(nx, ny).  Here (nx, ny) is
 * the position stored in cspnd->nns, rounded to the nearest integer.
 *
 * An independent affine model  a = scale * b + offset  is fitted to each of
 * the red, green and blue channels by linear least squares (dgelsy_driver).
 * The three models are then applied to every valid pixel of b, and the
 * results are clamped to [0, 255].
 *
 * a      reference image, indexed with the correspondence-map coordinates
 * b      image to correct; it is only read
 * cspnd  correspondence map; entries with dists == DBL_MAX are ignored
 *
 * Returns the image created by img_new(b->w, b->h) (presumably owned by the
 * caller) holding the corrected colors.  Pixels that are invalid in b are
 * invalidated in the result.
 *
 * Correspondences whose pixels fall outside a or b are skipped.  If no
 * usable correspondence remains, the identity model (scale 1, offset 0) is
 * used.  Returns NULL if the temporary least-squares buffers cannot be
 * allocated.
 */
img_t *img_color_correct_cspnd(img_t *a, img_t *b, img_dmap_t *cspnd) 
{
    int i, num_cspnd, x, y, count;
    double *a_red = NULL, *b_red = NULL;
    double *a_green = NULL, *b_green = NULL;
    double *a_blue = NULL, *b_blue = NULL;
    /* Per-channel model [scale, offset]; identity unless a fit is made, so
     * the values are never read uninitialized */
    double x_red[2] = { 1.0, 0.0 };
    double x_green[2] = { 1.0, 0.0 };
    double x_blue[2] = { 1.0, 0.0 };
    img_t *b_cc;

    /* Count the correspondences (an upper bound on the usable rows) */
    num_cspnd = 0;
    for (i = 0; i < cspnd->w * cspnd->h; i++) {
	if (cspnd->dists[i] != DBL_MAX) {
	    num_cspnd++;
	}
    }

    count = 0;

    if (num_cspnd > 0) {
	size_t rows = (size_t) num_cspnd;

	/* Design matrices: one row [b_value, 1] per correspondence, so the
	 * least-squares solution is [scale, offset] */
	b_red = malloc(sizeof(double) * 2 * rows);
	b_green = malloc(sizeof(double) * 2 * rows);
	b_blue = malloc(sizeof(double) * 2 * rows);

	/* Right-hand sides: the matching values from a */
	a_red = malloc(sizeof(double) * rows);
	a_green = malloc(sizeof(double) * rows);
	a_blue = malloc(sizeof(double) * rows);

	if (b_red == NULL || b_green == NULL || b_blue == NULL ||
	    a_red == NULL || a_green == NULL || a_blue == NULL) {
	    fprintf(stderr, "[img_color_correct_cspnd] "
		    "Error: out of memory\n");
	    free(a_red);
	    free(a_green);
	    free(a_blue);
	    free(b_red);
	    free(b_green);
	    free(b_blue);
	    return NULL;
	}

	for (y = 0; y < cspnd->h; y++) {
	    for (x = 0; x < cspnd->w; x++) {
		color_t a_c, b_c;
		int nx, ny;
		int idx = y * cspnd->w + x;

		if (cspnd->dists[idx] == DBL_MAX)
		    continue;

		nx = (int) rint(Vx(cspnd->nns[idx]));
		ny = (int) rint(Vy(cspnd->nns[idx]));

		/* Rounding can push a match just past the image border;
		 * skip anything that would be read out of bounds */
		if (x >= a->w || y >= a->h ||
		    nx < 0 || nx >= b->w || ny < 0 || ny >= b->h)
		    continue;

		a_c = img_get_pixel(a, x, y);
		b_c = img_get_pixel(b, nx, ny);

		b_red[2 * count + 0] = (double) b_c.r;
		b_red[2 * count + 1] = 1.0;
		a_red[count] = a_c.r;

		b_green[2 * count + 0] = (double) b_c.g;
		b_green[2 * count + 1] = 1.0;
		a_green[count] = a_c.g;

		b_blue[2 * count + 0] = (double) b_c.b;
		b_blue[2 * count + 1] = 1.0;
		a_blue[count] = a_c.b;

		count++;
	    }
	}

	/* Solve the three linear systems over the rows actually filled */
	if (count > 0) {
	    dgelsy_driver(b_red, a_red, x_red, count, 2, 1);
	    dgelsy_driver(b_green, a_green, x_green, count, 2, 1);
	    dgelsy_driver(b_blue, a_blue, x_blue, count, 2, 1);
	}

	free(a_red);
	free(a_green);
	free(a_blue);
	free(b_red);
	free(b_green);
	free(b_blue);
    }

    if (count == 0) {
	fprintf(stderr, "[img_color_correct_cspnd] Warning: no usable "
		"correspondences; using identity color model\n");
    }

    printf("r channel: %0.3fx + %0.3f\n", x_red[0], x_red[1]);
    printf("g channel: %0.3fx + %0.3f\n", x_green[0], x_green[1]);
    printf("b channel: %0.3fx + %0.3f\n", x_blue[0], x_blue[1]);

    /* Do the color correction */
    b_cc = img_new(b->w, b->h);

    for (y = 0; y < b->h; y++) {
	for (x = 0; x < b->w; x++) {
	    color_t c;
	    double red, green, blue;

	    if (!img_pixel_is_valid(b, x, y)) {
		/* Propagate invalidity to the output image (the input is
		 * left untouched) */
		img_invalidate_pixel(b_cc, x, y);
		continue;
	    }

	    c = img_get_pixel(b, x, y);
	    red = x_red[0] * (double) c.r + x_red[1];
	    green = x_green[0] * (double) c.g + x_green[1];
	    blue = x_blue[0] * (double) c.b + x_blue[1];

	    img_set_pixel(b_cc, x, y,
			  CLAMP((int) rint(red), 0, 255),
			  CLAMP((int) rint(green), 0, 255),
			  CLAMP((int) rint(blue), 0, 255));
	}
    }

    return b_cc;
}
