#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>

/* Capacity of the true[] and pred[] arrays (maximum number of input cases). */
#define MAX_ITEMS (100000)
/* Nonzero enables diagnostic printf output (item counts, means, table cells). */
#define debug (0)
/* Added to ratio denominators so a zero count yields 0 rather than 0/0. */
#define eps (1.0e-10)

  /* 
     modified 02.02.00
     to fix auto_threshold bug!!! and take negative of ACE

     modified 11.12.99
     to compile on unix, try "cc -o roc.auto roc.auto.c -lm"

     This program accepts a number of options:

     -stats                 : print stats (ACC, PPV, NPV, SEN, SPC)
     -threshold 0.83        : use value 0.83 as threshold for stats
     -plot                  : print ROC curve
     -accplot               : plot accuracy vs. threshold
     -noroc                 : don't print ROC area
     -ace                   : include ACE calibration score (mean negative log-likelihood)
     -breyce                : include breyce calibration score (Brier score: mean squared error)
     -monti                 : only output accuracies, roc area, ace, and breyce

     If no option is specified, only the ROC area is printed. Most
     options may be combined (the ROC plot and acc vs. thresh plots
     disable each other).  Options may be specified in any order.
     Some abbreviations are available for most options.


     Input to program is a sequence of lines, one line per case
     with format: "TRUE_VALUE whitespace PRED_VALUE"

     The file roc.test.data contains sample test data.  The file
     roc.test.results contains the results of running:

          roc -option < roc.test.data > roc.test.results

     TRUE_VALUE and PRED_VALUE can be any numbers, positive
     or negative.  The program computes the mean TRUE_VALUE and
     mean PRED_VALUE.  

     If TRUE_VALUE < MEAN_TRUE_VALUE, the truth for the case is
     considered to be a 0, otherwise it is considered a 1. This 
     lets you use any coding for the TRUE_VALUE's as long as the
     value used to represent 0's is less than the value used to
     represent 1's.  So you can code things as 0/1, -1/+1, etc.
     Note that the ACE and Breyce calibration scores both assume 
     the input data uses 0=false, 1=true, with the prediction 
     probabilities ranging between 0 and 1.    

     The program computes the ROC curve, ROC area, accuracy (ACC), 
     Positive Predictive Value (PPV), Negative Predictive Value (NPV),
     Sensitivity (SEN), and Specificity(SPC).

     Computing things like Accuracy requires a prediction threshold
     for the PRED_VALUE:

     (PRED_VALUE <  thresh) is interpreted as a prediction of 0.
     (PRED_VALUE >= thresh) is interpreted as a prediction of 1.

     The program takes an optional argument that lets you specify the
     prediction threshold.  If you don't specify a value, the program
     assumes 0.5, which is the right thing to do if the PRED_VALUES
     are properly calibrated probabilities.
     
     The program also computes two different prediction threshold
     values itself. The first is done by finding the threshold that
     would make the number of predicted 1's match the number of true
     1's in the data set.  If 15 of 100 points given to the program
     are true 1's, it sorts the points by PRED_VALUE (least first) and
     computes a threshold from the average PRED_VALUE of points 84 and
     85 (counting from 0) in the sort.  This prediction threshold makes
     15 of the 100 points be predicted 1.  (If two or more points have
     the same PRED_VALUE and they sort to the place where the threshold
     finder picks values, the number of points predicted to be 1 can't
     match the number of true 1's.)

     The second computed threshold is the prediction threshold that
     maximizes the Accuracy.  This threshold is found by trying all
     thresholds that fall half way between adjacent prediction values
     in the data set, and reporting the one that yields the highest
     accuracy.  It is not uncommon for more than one threshold to
     yield the same maximum accuracy.  In this case, the program
     reports the lowest threshold that achieved maximum accuracy, and
     prints a caution that more than one prediction threshold achieves
     this same accuracy.  When this happens, you can rerun the program
     with the -accplot option to see the accuracy for each threshold.
     (In Unix "roc.auto -accplot -noroc < infile | sort -n -k 2 | tail"
     will find just those thresholds that yielded highest accuracy.)

     The program computes the error measures listed above for the 0.5 
     threshold (or user supplied threshold) and for the two automatically
     computed thresholds.  (Neither affects the ROC curve or its area.)  

     Be careful using computed thresholds and their associated stats.
     It is probably improper to use the threshold computed on a test
     set to estimate statistics for that test set.  The threshold
     probably should be estimated with the training set, and then
     given to the program to calculate test set stats.  That's why the
     threshold can be specified as an optional parameter.

     The ROC curve is a plot of SENSITIVITY vs. 1-SPECIFICITY as the
     prediction threshold sweeps from 0 to 1.  A typical ROC curve for
     reasonably accurate predictions looks like this:

              |                                   *
          S   |                         *        
          E   |                 *
          N   |           *
          S   |               
          I   |       *  
          T   |       
          I   |    *
          V   |    
          I   |  *
          T   |  
          Y   |*
              - - - - - - - - - - - - - - - - - - - 
                             1 - SPECIFICITY

     If there is no relationship between the prediction and truth, the
     ROC curve is a diagonal line with area 0.5.  If the prediction
     strongly predicts truth, the curve rises quickly and has area near
     1.0.  If the prediction strongly predicts anti-truth, the ROC area
     is less than 0.5.

     Here's a definition of SPECIFICITY, SENSITIVITY, and the other
     error measures this program computes.  (This is from Constantin's
     email.)

                           MODEL PREDICTION

                      |       1       0       |
                - - - + - - - - - - - - - - - + - - - - -
     TRUE         1   |       A       B       |    A+B
    OUTCOME           |                       |
                  0   |       C       D       |    C+D
                - - - + - - - - - - - - - - - + - - - - -
                      |      A+C     B+D      |  A+B+C+D


                1 = POSITIVE
                0 = NEGATIVE


                ACC = (A+D) /(A+B+C+D)
                PPV = A / (A+C)
                NPV = D / (B+D)
                SEN = A / (A+B)
                SPC = D / (C+D)


     WARNING!: This code has not been thoroughly tested.  If you find
               an error, please email me: caruana@cs.cmu.edu

  */

/* ---- input cases (partition()/quicksort() permute both arrays together) ---- */
float  true[MAX_ITEMS];     /* truth per case; replaced by 0/1 labels after reading */
float  pred[MAX_ITEMS];     /* prediction per case */
int    no_item;             /* number of cases read */
int    item;                /* loop index shared by main() */
double mean_true;           /* mean truth value; cut point for the 0/1 labels */
double mean_pred;           /* mean prediction (printed only when debug is set) */
int    total_true_0;        /* cases labelled 0 */
int    total_true_1;        /* cases labelled 1 */

/* ---- command-line state ---- */
int    arg;                 /* option-parsing state */
int    taken;
int    area;                /* print the ROC area */
int    plot;                /* print ROC curve points */
int    stats;               /* print ACC/PPV/NPV/SEN/SPC */
int    thresh;              /* a threshold was supplied with -threshold */
int    ace;                 /* report ACE */
int    breyce;              /* report the breyce (squared error) score */
int    monti;               /* compact -monti output */
double acc_plot;            /* print accuracy vs. threshold (used as a flag) */

/* ---- thresholds and the A/B/C/D cells of each confusion table ---- */
double pred_thresh;         /* user-supplied or default threshold */
int    a;
int    b;
int    c;
int    d;
double freq_thresh;         /* threshold matching predicted to true class counts */
int    freq_a;
int    freq_b;
int    freq_c;
int    freq_d;
double max_acc_thresh;      /* lowest threshold achieving max_acc */
int    max_acc_a;
int    max_acc_b;
int    max_acc_c;
int    max_acc_d;

/* ---- accuracy-vs-threshold search ---- */
double threshold;
double acc;
double max_acc;
double last_acc_thresh;
int    max_acc_count;       /* distinct thresholds reaching max_acc */

/* ---- ROC sweep ---- */
int    tt;
int    tf;
int    ft;
int    ff;
double sens;
double spec;
double tpf;
double fpf;
double tpf_prev;
double fpf_prev;
double roc_area;

/* ---- calibration score accumulators ---- */
double ace_sum;
double breyce_sum;

/* compute the accuracy using the threshold */

double accuracy (double threshold)
{
  int a,b,c,d,item;
  a = 0; b = 0; c = 0; d = 0;
  for (item=0; item<no_item; item++)
    if ( true[item] == 1 )
    /* true outcome = 1 */
      {
	if ( pred[item] >= threshold )
	  a++;
	else
	  b++;
      }
    else
    /* true outcome = 0 */
      {
	if ( pred[item] >= threshold )
	  c++;
	else
	  d++;
      }
  return( ((double)(a+d)) / (((double)(a+b+c+d)) + eps) );
}

/*
 * partition -- Hoare partition step used by quicksort().
 *
 * Rearranges pred[p..r] and moves true[] in lockstep, so every case keeps
 * its own label.  It returns an index j with p <= j < r such that each
 * pred in [p..j] is <= each pred in [j+1..r].
 *
 * The middle element is swapped into slot p and used as the pivot.
 * Always pivoting on the first element degrades quicksort to O(n^2) time
 * and O(n) recursion depth on input that is already sorted by prediction.
 * Hoare's scheme needs the pivot at the low end to guarantee j < r.
 *
 * The scans are written as "pred > x" / "pred < x" so they stop on a NaN
 * prediction (or NaN pivot) instead of running past the subarray bounds.
 */
int partition (int p, int r)
{
  int i, j, mid;
  float x, tempf;

  mid = p + (r - p) / 2;
  tempf = pred[p]; pred[p] = pred[mid]; pred[mid] = tempf;
  tempf = true[p]; true[p] = true[mid]; true[mid] = tempf;

  x = pred[p];
  i = p - 1;
  j = r + 1;
  for (;;)
    {
      do j--; while (pred[j] > x);
      do i++; while (pred[i] < x);
      if (i >= j)
        return j;
      tempf = pred[i]; pred[i] = pred[j]; pred[j] = tempf;
      tempf = true[i]; true[i] = true[j]; true[j] = tempf;
    }
}

/*
 * quicksort -- sort pred[p..r] into ascending order; partition() moves
 * true[] in lockstep, so each case keeps its label.
 *
 * The function recurses only into the smaller of the two partitions and
 * loops on the larger one.  This bounds recursion depth to O(log n) even
 * if partition() splits badly; recursing into both halves can nest up to
 * n (<= MAX_ITEMS) calls deep.
 *
 * It is declared void with a prototype: implicit int was removed in C99
 * and is rejected by current compilers.
 */
void quicksort (int p, int r)
{
  while (p < r)
    {
      int q = partition (p, r);   /* p <= q < r: both halves non-empty */

      if ( (q - p) < (r - q) )
        {
          quicksort (p, q);
          p = q + 1;
        }
      else
        {
          quicksort (q + 1, r);
          r = q;
        }
    }
}

/* Ratio of two counts.  eps is added to the denominator so an empty
   denominator gives 0 instead of a division by zero. */
static double stat_ratio (int num, int den)
{
  return ((double) num) / (((double) den) + eps);
}

/* Print the ACC line for a confusion table whose cells are A=tp, B=fn,
   C=fp, D=tn (see the table in the header comment). */
static void print_acc_line (const char *label, double thresh,
                            int tp, int fn, int fp, int tn)
{
  printf ("ACC %7.4lf   %s %9.6lf\n",
          stat_ratio (tp + tn, tp + fn + fp + tn), label, thresh);
}

/* Print ACC, PPV, NPV, SEN and SPC for a confusion table. */
static void print_stats (const char *label, double thresh,
                         int tp, int fn, int fp, int tn)
{
  print_acc_line (label, thresh, tp, fn, fp, tn);
  printf ("PPV %7.4lf   %s %9.6lf\n", stat_ratio (tp, tp + fp), label, thresh);
  printf ("NPV %7.4lf   %s %9.6lf\n", stat_ratio (tn, fn + tn), label, thresh);
  printf ("SEN %7.4lf   %s %9.6lf\n", stat_ratio (tp, tp + fn), label, thresh);
  printf ("SPC %7.4lf   %s %9.6lf\n", stat_ratio (tn, fp + tn), label, thresh);
}

/* Warn when several distinct thresholds tie for the maximum accuracy. */
static void print_max_acc_caution (int count)
{
  if (count > 1)
    printf ("Caution: %d Different Thresholds Achieved Max Accuracy.\nRun with -accplot option to see accuracy vs. threshold.\n\n", count);
}

/* Fill the confusion table for 'thresh'.  A case is predicted 1 iff
   pred >= thresh, and is a true 1 iff its binarized true value is 1. */
static void count_confusion (double thresh, int *tp, int *fn, int *fp, int *tn)
{
  int i;

  *tp = 0;
  *fn = 0;
  *fp = 0;
  *tn = 0;
  for (i = 0; i < no_item; i++)
    {
      int predicted_one = ( pred[i] >= thresh );

      if ( true[i] == 1 )
        {
          if ( predicted_one )
            (*tp)++;
          else
            (*fn)++;
        }
      else
        {
          if ( predicted_one )
            (*fp)++;
          else
            (*tn)++;
        }
    }
}

/* Nonzero if 'opt' equals one of the NULL-terminated 'names'. */
static int option_in (const char *opt, const char *const *names)
{
  for ( ; *names != NULL; names++)
    if ( strcmp (opt, *names) == 0 )
      return 1;
  return 0;
}

/* Set the option flags from the command line.  -threshold consumes the
   following argument as its value; that value is never re-interpreted as
   an option.  A missing or non-numeric value is an error. */
static void parse_options (int argc, char **argv)
{
  static const char *const area_opts[] =
    { "-a", "-area", "-rocarea", "-roc_area", "-ROC_area", "-roc", "-ROC", NULL };
  static const char *const noarea_opts[] =
    { "-noa", "-noroc", "-noROC", "-noarea", NULL };
  static const char *const plot_opts[] =
    { "-p", "-rocplot", "-ROCplot", "-roc_plot", "-ROC_plot", "-plot", NULL };
  static const char *const accplot_opts[] =
    { "-accuracy", "-accuracyplot", "-accuracy_plot", "-ACCplot", "-ACC_plot",
      "-accplot", "-acc_plot", "-threshold_plot", "-thresholdplot", NULL };
  static const char *const stats_opts[] = { "-s", "-stat", "-stats", NULL };
  static const char *const thresh_opts[] = { "-t", "-thresh", "-threshold", NULL };
  static const char *const ace_opts[] = { "-ace", "-ACE", NULL };
  static const char *const breyce_opts[] = { "-breyce", NULL };
  static const char *const monti_opts[] = { "-monti", NULL };
  int i;

  for (i = 1; i < argc; i++)
    {
      const char *opt = argv[i];

      if ( option_in (opt, area_opts) )
        area = 1;
      else if ( option_in (opt, noarea_opts) )
        area = 0;
      else if ( option_in (opt, plot_opts) )
        {
          plot = 1;           /* the two plots disable each other */
          acc_plot = 0;
        }
      else if ( option_in (opt, accplot_opts) )
        {
          acc_plot = 1;
          plot = 0;
        }
      else if ( option_in (opt, stats_opts) )
        stats = 1;
      else if ( option_in (opt, thresh_opts) )
        {
          char *end;

          if ( i + 1 >= argc )
            {
              fprintf (stderr, "Error: option %s requires a threshold value.\n", opt);
              exit (1);
            }
          i++;
          pred_thresh = strtod (argv[i], &end);
          if ( end == argv[i] || *end != '\0' )
            {
              fprintf (stderr, "Error: invalid threshold value %s\n", argv[i]);
              exit (1);
            }
          thresh = 1;
          stats = 1;
        }
      else if ( option_in (opt, ace_opts) )
        ace = 1;
      else if ( option_in (opt, breyce_opts) )
        breyce = 1;
      else if ( option_in (opt, monti_opts) )
        {
          monti = 1;
          ace = 1;
          breyce = 1;
        }
      else
        printf ("\nWarning!: Unrecognized program option %s\n", opt);
    }
}

/* Read "TRUE PRED" pairs from stdin into true[]/pred[].  Along the way it
   computes the means and the ACE / breyce accumulators.  Up to MAX_ITEMS
   cases are accepted.  The program aborts on an extra case, on malformed
   input, or on empty input. */
static void read_cases (void)
{
  float t_in, p_in;
  int got;

  no_item = 0;
  mean_true = 0.0;
  mean_pred = 0.0;
  while ( (got = scanf ("%f %f", &t_in, &p_in)) == 2 )
    {
      if ( no_item >= MAX_ITEMS )
        {
          printf ("Aborting.  Exceeded %d items.\n", MAX_ITEMS);
          exit (1);
        }
      true[no_item] = t_in;
      pred[no_item] = p_in;
      mean_true += t_in;
      mean_pred += p_in;
      if ( ace )
        {
          /* Sum of log p(true label); reported negated and averaged
             (mean negative log-likelihood).  log(0) is replaced by -9e99.
             Assumes 0/1 truth and probabilities in [0,1]. */
          if ( t_in )
            ace_sum += ( p_in <= 0.0 ? -9e99 : log (p_in) );
          else
            ace_sum += ( p_in >= 1.0 ? -9e99 : log (1.0 - p_in) );
        }
      if ( breyce )
        breyce_sum += (t_in - p_in) * (t_in - p_in);
      no_item++;
    }
  if ( got != EOF )
    {
      /* scanf matched fewer than two numbers: a non-numeric token or a
         trailing lone value.  Stop instead of looping on the same token. */
      fprintf (stderr, "Aborting.  Malformed input after %d items.\n", no_item);
      exit (1);
    }
  if ( no_item == 0 )
    {
      fprintf (stderr, "Aborting.  No input items read.\n");
      exit (1);
    }
  mean_true = mean_true / ((double) no_item);
  mean_pred = mean_pred / ((double) no_item);
}

/* Threshold that makes the number of predicted 1's equal the number of
   true 1's.  It is the midpoint of sorted pred[total_true_0 - 1] and
   pred[total_true_0], computed in double so it lies strictly between two
   distinct neighbours.  When every case is a true 1 it returns the
   smallest prediction, so all cases are predicted 1.  When rounding of
   mean_true labels every case 0 it returns a value above the largest
   prediction.  Requires no_item >= 1 and pred[] sorted ascending. */
static double frequency_threshold (void)
{
  int k = total_true_0;

  if ( k <= 0 )
    return (double) pred[0];
  if ( k >= no_item )
    return (double) pred[no_item-1] + fabs ((double) pred[no_item-1]) + 1.0;
  return ((double) pred[k-1] + (double) pred[k]) / 2.0;
}

int main (int argc, char **argv)
{
  /* defaults: print only the ROC area; stats threshold 0.5 */
  area = 1;
  plot = 0;
  stats = 0;
  thresh = 0;
  acc_plot = 0;
  pred_thresh = 0.5;
  ace = 0;
  breyce = 0;
  monti = 0;

  parse_options (argc, argv);
  read_cases ();

  if (debug)
    {
      printf("%d pats read. mean_true %6.4lf. mean_pred %6.4lf\n", no_item, mean_true, mean_pred);
      fflush(stdout);
    }

  /* binarize the truth: below the mean -> 0, otherwise -> 1 */
  total_true_0 = 0;
  total_true_1 = 0;
  for (item = 0; item < no_item; item++)
    {
      if ( true[item] < mean_true )
        {
          true[item] = 0;
          total_true_0++;
        }
      else
        {
          true[item] = 1;
          total_true_1++;
        }
    }

  /* sort cases by predicted value (ascending) */
  quicksort (0, no_item-1);

  /* try every midpoint between adjacent predictions; keep the lowest
     threshold reaching the maximum accuracy and count how many distinct
     thresholds tie for it */
  max_acc = -9.9e10;
  max_acc_thresh = 0.0;
  last_acc_thresh = 0.0;
  max_acc_count = 1;
  for (item = 0; item < (no_item-1); item++)
    {
      /* midpoint in double: a float sum can round onto a neighbour */
      threshold = ((double) pred[item] + (double) pred[item+1]) / 2.0;
      acc = accuracy (threshold);
      if ( acc_plot )
        printf ("%lf %lf\n", threshold, acc);
      if ( acc > max_acc )
        {
          max_acc = acc;
          max_acc_thresh = threshold;
          last_acc_thresh = threshold;
          max_acc_count = 1;
        }
      if ( (acc == max_acc) && (threshold != last_acc_thresh) )
        {
          max_acc_count++;
          last_acc_thresh = threshold;
        }
    }

  freq_thresh = frequency_threshold ();

  /* confusion tables for the three thresholds */
  count_confusion (pred_thresh, &a, &b, &c, &d);
  count_confusion (freq_thresh, &freq_a, &freq_b, &freq_c, &freq_d);
  count_confusion (max_acc_thresh, &max_acc_a, &max_acc_b, &max_acc_c, &max_acc_d);

  /* ROC curve and area.  The threshold sweeps from above the largest
     prediction down past the smallest, moving one case at a time to
     "predicted 1".  A point (and a trapezoid of area) is emitted only
     where the prediction value changes, so tied predictions form a single
     segment.  With only one class present the rates are 0/0. */
  tt = 0;
  tf = total_true_1;
  ft = 0;
  ff = total_true_0;

  sens = ((double) tt) / ((double) (tt+tf));
  spec = ((double) ff) / ((double) (ft+ff));
  tpf = sens;
  fpf = 1.0 - spec;
  if ( plot )
    printf ("%6.4lf %6.4lf\n", fpf, tpf);
  roc_area = 0.0;
  tpf_prev = tpf;
  fpf_prev = fpf;

  for (item = no_item-1; item >= 0; item--)
    {
      tt += true[item];
      tf -= true[item];
      ft += 1 - true[item];
      ff -= 1 - true[item];
      sens = ((double) tt) / ((double) (tt+tf));
      spec = ((double) ff) / ((double) (ft+ff));
      tpf  = sens;
      fpf  = 1.0 - spec;
      if ( item == 0 || pred[item] != pred[item-1] )
        {
          if ( plot )
            printf ("%6.4lf %6.4lf\n", fpf, tpf);
          roc_area += 0.5*(tpf+tpf_prev)*(fpf-fpf_prev);
          tpf_prev = tpf;
          fpf_prev = fpf;
        }
    }

  /* report */
  if ( (stats || area) && (plot || acc_plot) )
    printf ("\n");
  if (debug) printf ("%d %d %d %d\n", a,b,c,d);
  if ( stats )
    {
      print_stats ("pred_thresh", pred_thresh, a, b, c, d);
      printf ("\n");
      if (debug) printf ("%d %d %d %d\n", freq_a,freq_b,freq_c,freq_d);
      print_stats ("freq_thresh", freq_thresh, freq_a, freq_b, freq_c, freq_d);
      printf ("\n");
      if (debug) printf ("%d %d %d %d\n", max_acc_a,max_acc_b,max_acc_c,max_acc_d);
      print_max_acc_caution (max_acc_count);
      print_stats ("max_acc_thresh", max_acc_thresh,
                   max_acc_a, max_acc_b, max_acc_c, max_acc_d);
      printf ("\n");
    }
  if ( monti && !stats )
    {
      print_acc_line ("pred_thresh", pred_thresh, a, b, c, d);
      print_acc_line ("freq_thresh", freq_thresh, freq_a, freq_b, freq_c, freq_d);
      print_max_acc_caution (max_acc_count);
      print_acc_line ("max_acc_thresh", max_acc_thresh,
                      max_acc_a, max_acc_b, max_acc_c, max_acc_d);
    }
  if ( ace )
    printf ("ACE %7.4lf\n", -1.0 * ace_sum / ((double) no_item));
  if ( breyce )
    printf ("BRE %7.4lf\n", breyce_sum / ((double) no_item));
  if ( (ace || breyce) && area && !monti )
    printf ("\n");
  if ( area )
    printf ("ROC %7.4lf\n", roc_area);
  return 0;
}
