// Author: Daniel Cabrini Hauagge <hauagge@cs.cornell.edu>
// Date: 2011-10-20
// Description: Implementation of symmetry distance, naive version

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>

// Debug-tracing hooks, currently disabled: both macros expand to nothing.
// To turn tracing on, give them these expansions instead:
//   LOG_EXPR(expr) -> std::cout << #expr << " = " << (expr) << std::endl
//   LOG_MSG(msg)   -> std::cout << msg << std::endl
#define LOG_EXPR(expr)
#define LOG_MSG(msg)

#ifndef NO_MEX
#include "mex.h"
#else
typedef size_t mwSize;
#endif

// Weighted dissimilarity between the pixel under fIt and its mirror under rIt,
// scaled by the kernel weight under kIt. Requires pointers named fIt, rIt and
// kIt to be in scope where it is expanded. The expansion is fully
// parenthesised and has no trailing ';', so it behaves as one expression
// (e.g. `x = PAIRWISE_DIST + y;` works as written).
// Alternative (squared difference):
//#define PAIRWISE_DIST (std::pow(*fIt - *rIt, 2) * (*kIt))
#define PAIRWISE_DIST (std::fabs(*fIt - *rIt) * (*kIt))


/// Rotational (180-degree) symmetry distance of the kWidth x kWidth window
/// centred on im_ij.
///
/// Image and kernel are stored column-major, with nRows elements per image
/// column. Each pixel in the first half of the window (column-major order,
/// stopping before the centre pixel) is compared, via PAIRWISE_DIST, with the
/// pixel obtained by rotating the window 180 degrees about its centre. The
/// forward pixel's kernel weight scales each term.
///
/// @param kWidth   kernel width; expected to be odd so the window has a centre.
/// @param kWeights kWidth*kWidth kernel weights, column-major.
/// @param im_ij    pointer to the window centre inside the image.
/// @param nCols    number of image columns (not used here).
/// @param nRows    number of image rows, i.e. the stride between columns.
/// @return         accumulated weighted distance (0 for a symmetric window
///                 under the current PAIRWISE_DIST).
inline
double
symdist_rot_ij(int kWidth, const double *kWeights,
               const double *im_ij, 
               mwSize nCols, mwSize nRows)
{
  // Use signed arithmetic: the window extends up and to the left of im_ij, so
  // its offset is negative. Holding it in unsigned mwSize and relying on
  // wrap-around pointer addition is undefined behaviour.
  const std::ptrdiff_t stride = static_cast<std::ptrdiff_t>(nRows);
  const int halfKWidth = (kWidth - 1) / 2;
  const double *topLeft = im_ij - (halfKWidth + halfKWidth * stride);

  // Only the first half of the window is visited. The centre pixel is its own
  // rotation and is skipped.
  const int nElems = (kWidth * kWidth - 1) / 2;

  double symdist = 0;
  int el = 0;
  for (int j = 0; j < kWidth && el < nElems; j++) {
    const double *fCol = topLeft + j * stride;                 // column j
    const double *rCol = topLeft + (kWidth - 1 - j) * stride;  // its rotation
    for (int i = 0; i < kWidth && el < nElems; i++, el++) {
      const double *fIt = fCol + i;                // forward pixel
      const double *rIt = rCol + (kWidth - 1 - i); // rotated counterpart
      const double *kIt = kWeights + el;           // weight of forward pixel
      symdist += PAIRWISE_DIST;
    }
  }
  LOG_EXPR(symdist);

  return symdist;
}

/// Mirror symmetry distance about the horizontal axis of the kWidth x kWidth
/// window centred on im_ij.
///
/// Image and kernel are stored column-major, with nRows elements per image
/// column. Within every column of the window, the pixel in row i (top half) is
/// compared with the pixel in row kWidth-1-i, i.e. the window flipped
/// top-to-bottom. The middle row is its own mirror and is skipped. Each term is
/// PAIRWISE_DIST scaled by the kernel weight of the top-half pixel.
///
/// @param kWidth   kernel width; expected to be odd so the window has a centre.
/// @param kWeights kWidth*kWidth kernel weights, column-major.
/// @param im_ij    pointer to the window centre inside the image.
/// @param nCols    number of image columns (not used here).
/// @param nRows    number of image rows, i.e. the stride between columns.
/// @return         accumulated weighted distance.
inline
double
symdist_hor_ij(int kWidth, const double *kWeights,
               const double *im_ij, 
               mwSize nCols, mwSize nRows)
{
  // Use signed arithmetic: the window extends up and to the left of im_ij.
  // Holding that negative offset in unsigned mwSize relied on wrap-around
  // pointer addition, which is undefined behaviour. Indexing per column also
  // avoids stepping a pointer past the end of the image after the last column.
  const std::ptrdiff_t stride = static_cast<std::ptrdiff_t>(nRows);
  const int halfKWidth = (kWidth - 1) / 2;
  const double *topLeft = im_ij - (halfKWidth + halfKWidth * stride);

  double symdist = 0;
  for (int j = 0; j < kWidth; j++) {
    const double *col = topLeft + j * stride;    // column j of the window
    const double *kCol = kWeights + j * kWidth;  // column j of the kernel
    for (int i = 0; i < halfKWidth; i++) {
      const double *fIt = col + i;                // row i (top half)
      const double *rIt = col + (kWidth - 1 - i); // mirrored row
      const double *kIt = kCol + i;               // weight of top-half pixel
      symdist += PAIRWISE_DIST;
    }
  }
  LOG_EXPR(symdist);

  return symdist;
}

/// Mirror symmetry distance about the vertical axis of the kWidth x kWidth
/// window centred on im_ij.
///
/// Image and kernel are stored column-major, with nRows elements per image
/// column. Each column j in the left half of the window is compared row by row
/// with column kWidth-1-j, i.e. the window flipped left-to-right. The middle
/// column is its own mirror and is skipped. Each term is PAIRWISE_DIST scaled
/// by the kernel weight of the left-half pixel.
///
/// @param kWidth   kernel width; expected to be odd so the window has a centre.
/// @param kWeights kWidth*kWidth kernel weights, column-major.
/// @param im_ij    pointer to the window centre inside the image.
/// @param nCols    number of image columns (not used here).
/// @param nRows    number of image rows, i.e. the stride between columns.
/// @return         accumulated weighted distance.
inline
double
symdist_vert_ij(int kWidth, const double *kWeights,
		const double *im_ij, 
		mwSize nCols, mwSize nRows)
{
  // Use signed arithmetic: the window extends up and to the left of im_ij.
  // Holding that negative offset in unsigned mwSize relied on wrap-around
  // pointer addition, which is undefined behaviour.
  const std::ptrdiff_t stride = static_cast<std::ptrdiff_t>(nRows);
  const int halfKWidth = (kWidth - 1) / 2;
  const double *topLeft = im_ij - (halfKWidth + halfKWidth * stride);

  double symdist = 0;
  for (int j = 0; j < halfKWidth; j++) {
    const double *fCol = topLeft + j * stride;                // left column j
    const double *rCol = topLeft + (kWidth - 1 - j) * stride; // mirrored column
    const double *kCol = kWeights + j * kWidth;               // kernel column j
    for (int i = 0; i < kWidth; i++) {
      const double *fIt = fCol + i;
      const double *rIt = rCol + i;
      const double *kIt = kCol + i;
      LOG_MSG(*fIt << ", " << *rIt);
      symdist += PAIRWISE_DIST;
    }
  }
  LOG_EXPR(symdist);

  return symdist;
}

/// Reports a fatal error.
///
/// In the standalone build (NO_MEX defined), the message is written to stderr
/// with an "ERROR: " prefix and the process exits with EXIT_FAILURE. In the MEX
/// build, it is handed to mexErrMsgTxt. Per the MATLAB MEX API, mexErrMsgTxt
/// aborts the current MEX call, so callers treat errMsg as non-returning.
void
errMsg(const char * msg)
{
#ifdef NO_MEX
  std::cerr << "ERROR: " << msg << std::endl;
  exit(EXIT_FAILURE);
#else
  mexErrMsgTxt(msg);
#endif
}

void
symdist(char symtype,
        int kWidth, const double *kWeights,
        const double *im, double *sd,
        mwSize nCols, mwSize nRows)
{
  assert(kWidth%2 == 1);

  double (*symfunc)(int, const double *, const double *, mwSize, mwSize);
  switch (symtype) {
  case 'r': case 'R':
    symfunc = symdist_rot_ij;
    break;
  case 'v': case 'V':
    symfunc = symdist_vert_ij;
    break;
  case 'h': case 'H':
    symfunc = symdist_hor_ij;
    break;    
  default:
    errMsg("Invalid value for symtype");
    break;
  }
        
  const mwSize kMid = (kWidth - 1)/2; // center of kernel
  const mwSize firstValid = nRows * kMid + kMid; // Fist valid pixel in symdist image
  const mwSize goodRows = nRows - kWidth + 1;
  const mwSize goodCols = nCols - kWidth + 1;
            
  double *sdIt = sd + firstValid;
  double const *imIt = im + firstValid;
  for(mwSize j = 0; j < goodCols; j++) {
    for(mwSize i = 0; i < goodRows; i++, sdIt++, imIt++) {
      //*sdIt = symdist_rot_ij(kWidth, kWeights, imIt, nCols, nRows);
      *sdIt = symfunc(kWidth, kWeights, imIt, nCols, nRows);
    }
    sdIt += kWidth - 1;
    imIt += kWidth - 1;
  }
}

#ifndef NO_MEX
void 
mexFunction(int nlhs, mxArray *plhs[],
            int nrhs, const mxArray *prhs[])
{
  // Make sure the number of parameters is correct
  if(nrhs != 3) mexErrMsgTxt("Three input arguments are required.");
  if(nrhs == 1 && !mxIsClass(prhs[0], "double")) {
    mexErrMsgTxt("Input must be of type double.");
  }
    
  char symtype = ((char *) mxGetPr(prhs[0]))[0];
  double *im = (double *)mxGetPr(prhs[1]);
  mwSize nCols = mxGetN(prhs[1]);
  mwSize nRows = mxGetM(prhs[1]);
    
  double *kWeights = (double *)mxGetPr(prhs[2]);
    
  int kWidth = mxGetN(prhs[2]);
  if (kWidth != mxGetM(prhs[2])) {
    mexErrMsgTxt("Kernel weights matrix should be a square matrix");
  }
            
  // Allocate output matrix
  plhs[0] = mxCreateDoubleMatrix(nRows, nCols, mxREAL);
  double *sd = mxGetPr(plhs[0]);
  for (int i = 0; i < nCols * nRows; i++) sd[i] = 0;
    
  // Call computational routine
  symdist(symtype, kWidth, kWeights, im, sd, nCols, nRows);
}
#else
/// Standalone (NO_MEX) smoke test: runs the horizontal-axis measure on a 5x5
/// ramp image with a uniform 3x3 kernel and prints the resulting map.
int
main(int argc, char **argv)
{
  const int imSize = 5;
  const int kWidth = 3;

  // std::vector owns the buffers, so nothing leaks. sd is zero-filled because
  // symdist leaves border pixels unwritten.
  std::vector<double> im(imSize * imSize);
  std::vector<double> sd(imSize * imSize, 0.0);
  std::vector<double> kWeights(kWidth * kWidth, 1.0);

  for (int i = 0; i < imSize * imSize; i++) im[i] = i;

  symdist('h', kWidth, &kWeights[0], &im[0], &sd[0], imSize, imSize);

  // Storage is column-major: element (row, col) lives at col * imSize + row.
  for (int row = 0; row < imSize; row++) {
    for (int col = 0; col < imSize; col++) {
      std::cout << sd[col * imSize + row] << ' ';
    }
    std::cout << '\n';
  }

  return EXIT_SUCCESS;
}
#endif
