/* lec06mc.c --
 *
 * This code is a prototype Monte Carlo computation.  Each trial draws
 * NPTS points whose coordinates come from the uniform generator and
 * records the smallest distance between any two of them; the program
 * estimates the expected value of that minimum distance.  It has the
 * following interesting features:
 *
 * 1.  The pseudorandom numbers are generated by independently-seeded
 *     instances of the Mersenne twister RNG (where the seeds are
 *     generated on a single thread via the system random() function).  
 *     Note that this generator is thread-safe because the state
 *     variable is an explicit argument at each step.  This is not
 *     always the case!  Also, note that the random number generator
 *     is often the most subtle part of a parallel Monte Carlo code.
 *
 * 2.  The code uses adaptive error estimation to terminate as soon as
 *     it has enough data to get the 1-sigma error bars below some relative
 *     tolerance.  Unlike an a priori decision (i.e. "run a million trials
 *     and then take stock"), this termination criterion involves some
 *     coordination between the threads.  The coordination can be made
 *     relatively inexpensive by only updating global counts after doing
 *     a large enough batch on each thread.
 *
 * 3.  Timing is done using the gettimeofday function, which returns the
 *     wall clock time (as opposed to the CPU time for a particular
 *     process or thread).
 *
 * 4.  The code uses the getopt library to process the arguments.  While
 *     this has nothing in particular to do with the numerics or the parallel
 *     operation, it's still a good thing to know about.
 * 
 * In timing experiments on my laptop, this code gets very good speedup
 * on two processors (as it should).
 */

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <pthread.h>
#include <math.h>
#include <sys/time.h>
#include "mt19937p.h"

/* Number of random points generated per trial (see run_trial) */
#define NPTS 10

/* Parameters for termination criterion (may be overridden from the
 * command line; see process_args) */
double rtol      = 1e-2;     /* Target relative 1-sigma error of the mean  */
long   maxtrials = 1000000;  /* Stop once the trial count exceeds this     */
int    nbatch    = 500;      /* Trials a thread runs between global updates */

/* Monte Carlo results, accumulated over all threads */
double all_sum_X   = 0;  /* Sum of the trial values X     */
double all_sum_X2  = 0;  /* Sum of the squared values X*X */
long   all_ntrials = 0;  /* Number of trials accumulated  */

/* Lock on MC results */
pthread_mutex_t counts_lock;


/* Decide whether the computation may stop.
 *
 * Returns 1 when either
 *   - the 1-sigma error bar on the estimated mean, sqrt(varX/ntrials),
 *     is below rtol times the magnitude of the mean, or
 *   - more than maxtrials trials have been accumulated,
 * and 0 otherwise (including when no trials have been recorded yet).
 *
 * The global accumulators are read without taking counts_lock, so the
 * caller must already hold it (thread_main does).
 */
int is_converged(double rtol, long maxtrials)
{
    double EX, EX2, varX;

    /* No data yet: nothing to judge, and avoid dividing by zero. */
    if (all_ntrials <= 0)
        return 0;

    EX   = all_sum_X / all_ntrials;
    EX2  = all_sum_X2 / all_ntrials;
    varX = EX2-EX*EX;

    /* sqrt(varX/n)/|EX| < rtol, rearranged to avoid dividing by EX:
     * varX < rtol^2 * EX^2 * n.  The factor n is what makes the error
     * bar shrink as more trials are accumulated. */
    return (varX < rtol*rtol*EX*EX*all_ntrials || all_ntrials > maxtrials);
}


/* Run one Monte Carlo trial.
 *
 * Draws NPTS points, taking each point's two coordinates from
 * consecutive genrand calls on the generator state mt, and returns the
 * smallest Euclidean distance between any two of the points.
 */
double run_trial(struct mt19937p* mt)
{
    double px[NPTS];
    double py[NPTS];
    double min_sq = -1;  /* negative means no pair has been examined yet */
    int a, b;

    /* Draw every point first: x coordinate, then y coordinate. */
    for (a = 0; a < NPTS; ++a) {
        px[a] = genrand(mt);
        py[a] = genrand(mt);
    }

    /* Visit each unordered pair once, tracking the least squared distance. */
    for (a = 1; a < NPTS; ++a) {
        for (b = 0; b < a; ++b) {
            const double dx = px[b]-px[a];
            const double dy = py[b]-py[a];
            const double sq = dx*dx + dy*dy;
            if (min_sq < 0 || sq < min_sq)
                min_sq = sq;
        }
    }

    /* Only one square root is needed, on the winning pair. */
    return sqrt(min_sq);
}


void* thread_main(void* arg)
{
    struct mt19937p mt;
    long seed = (*(long*) arg);
    const int tnbatch = nbatch;
    int done_flag = 0;
    sgenrand(seed, &mt);

    do {

        /* Run batch of experiments */
        int t;
        double sum_X = 0;
        double sum_X2 = 0;
        for (t = 0; t < tnbatch; ++t) {
            double X = run_trial(&mt);
            sum_X += X;
            sum_X2 += X*X;
        }

        /* Update global counts and test for termination */
        pthread_mutex_lock(&counts_lock);
        done_flag = (done_flag || is_converged(rtol, maxtrials));
        all_sum_X += sum_X;
        all_sum_X2 += sum_X2;
        all_ntrials += tnbatch;
        done_flag = (done_flag || is_converged(rtol, maxtrials));
        pthread_mutex_unlock(&counts_lock);

    } while (!done_flag);
    return NULL;
}


/* Parse s as a base-10 integer, accepting it only if the whole string
 * is consumed and the value fits in a long.  Returns 1 and stores the
 * value in *out on success; returns 0 and leaves *out alone otherwise. */
static int parse_long_arg(const char* s, long* out)
{
    char* end;
    long val;

    errno = 0;
    val = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE)
        return 0;
    *out = val;
    return 1;
}


/* Parse s as a floating-point number, accepting it only if the whole
 * string is consumed and the value is in range.  Returns 1 and stores
 * the value in *out on success; returns 0 otherwise. */
static int parse_double_arg(const char* s, double* out)
{
    char* end;
    double val;

    errno = 0;
    val = strtod(s, &end);
    if (end == s || *end != '\0' || errno == ERANGE)
        return 0;
    *out = val;
    return 1;
}


/* Report a bad command line on stderr and terminate the program. */
static void usage_error(const char* msg)
{
    fprintf(stderr, "%s\n", msg);
    exit(-1);
}


/* Process the command line with getopt.
 *
 * Recognized options (each takes an argument):
 *   -p nthreads   number of threads, in [1,32]       (default 1)
 *   -t rtol       relative tolerance, >= 0           (sets global rtol)
 *   -n maxtrials  trial limit, >= 1                  (sets global maxtrials)
 *   -b nbatch     trials per batch, >= 1             (sets global nbatch)
 *
 * Returns the requested number of threads.  On a malformed, out-of-range
 * or unknown option, or if any non-option argument is present, prints a
 * message to stderr and exits with status -1.
 */
int process_args(int argc, char** argv)
{
    int nthreads = 1;
    long lval = 0;
    double dval = 0;
    int c;

    while ((c = getopt(argc, argv, "p:t:n:b:")) != -1) {
        switch (c) {
        case 'p':
            if (!parse_long_arg(optarg, &lval) || lval < 1 || lval > 32)
                usage_error("nthreads must be in [1,32]");
            nthreads = (int) lval;
            break;
        case 't':
            /* !(x >= 0) rather than (x < 0) so that NaN is rejected too */
            if (!parse_double_arg(optarg, &dval) || !(dval >= 0))
                usage_error("rtol must be non-negative");
            rtol = dval;
            break;
        case 'n':
            if (!parse_long_arg(optarg, &lval) || lval < 1)
                usage_error("maxtrials must be positive");
            maxtrials = lval;
            break;
        case 'b':
            if (!parse_long_arg(optarg, &lval) || lval < 1)
                usage_error("nbatch must be positive");
            if (lval > INT_MAX)  /* nbatch is an int */
                usage_error("nbatch is too large");
            nbatch = (int) lval;
            break;
        case '?':
        default:
            if (optopt == 'p' || optopt == 't' ||
                optopt == 'n' || optopt == 'b')
                fprintf(stderr, "Option -%c requires argument\n", optopt);
            else
                fprintf(stderr, "Unknown option '-%c'.\n", optopt);
            exit(-1);
        }
    }
    if (optind < argc)
        usage_error("No non-option arguments allowed");

    return nthreads;
}


/* Program entry point.
 *
 * Parses the command line, seeds the system random() generator, runs
 * thread_main on the requested number of threads (the calling thread
 * counts as one of them), and prints the estimated mean, its 1-sigma
 * error bar, and the elapsed wall clock time.
 */
int main(int argc, char** argv)
{
    /* process_args restricts nthreads to [1,32], matching the arrays below */
    int nthreads = process_args(argc, argv);
    long seeds[32];
    pthread_t threads[32];
    int i;
    double EX, EX2, stdX, t_elapsed;
    struct timeval tseed, t1, t2;

    /* Seed from the wall clock.  CPU time (clock()) is close to zero
     * this early in the run, so it gives nearly the same seed each time. */
    gettimeofday(&tseed, NULL);
    srandom((unsigned int) ((unsigned long) tseed.tv_sec ^
                            (unsigned long) tseed.tv_usec));

    /* Run parallel experiments on nthreads threads */
    gettimeofday(&t1, NULL);
    pthread_mutex_init(&counts_lock, NULL);
    for (i = 1; i < nthreads; ++i) {
        int rc;
        seeds[i] = random();
        rc = pthread_create(&threads[i], NULL, thread_main, &seeds[i]);
        if (rc != 0) {
            /* threads[i] is not valid; carry on with the i threads we
             * have (i-1 workers plus this one) and join only those. */
            fprintf(stderr, "pthread_create failed (error %d); "
                    "continuing with %d threads\n", rc, i);
            nthreads = i;
            break;
        }
    }
    seeds[0] = random();
    thread_main(&seeds[0]);
    for (i = 1; i < nthreads; ++i)
        pthread_join(threads[i], NULL);
    pthread_mutex_destroy(&counts_lock);
    gettimeofday(&t2, NULL);

    /* Compute expected value and 1 sigma error bars */
    EX   = all_sum_X / all_ntrials;
    EX2  = all_sum_X2 / all_ntrials;
    stdX = sqrt((EX2-EX*EX) / all_ntrials);

    /* Output value, error bar, and elapsed time */
    t_elapsed = (t2.tv_sec-t1.tv_sec) + (t2.tv_usec-t1.tv_usec)/1.0e6;
    printf("%d threads (pthreads): %g (%g): %e s\n",
           nthreads, EX, stdX, t_elapsed);

    return 0;
}
