#include <stdio.h>
#include <pthread.h>
#include <stdlib.h>
#include <stdbool.h>
#include <assert.h>
#include <sys/time.h>

// A fixed-capacity ring buffer of ints shared between threads.
// The blocking operations (bb_block_push, bb_block_pop, bb_finish) lock
// `mutex` before touching any other field. The plain helpers (bb_size,
// bb_full, bb_empty, bb_push, bb_pop) do no locking of their own; the
// blocking operations call them with `mutex` held.
typedef struct {
    int* data;     // Ring storage; at most capacity - 1 slots are ever in use (see bb_full).
    int capacity;  // The size of the `data` array.
    int head;      // The next index to pop.
    int tail;      // The next index to push.

    pthread_mutex_t* mutex;  // Guards every other field of this struct.
    bool done;               // Set by bb_finish once the buffer has drained; bb_block_pop
                             // reports it through its `done` out-parameter.

    pthread_cond_t* full_cv;   // Waited on while the buffer is full (bb_block_push) and
                               // while it is still non-empty (bb_finish); notified by bb_block_pop.
    pthread_cond_t* empty_cv;  // Waited on by bb_block_pop while the buffer is empty;
                               // signalled by bb_block_push, broadcast by bb_finish.
} bounded_buffer_t;

// Allocate and initialise a bounded buffer backed by `capacity` ints.
// One slot is always kept free (see bb_full), so at most capacity - 1 items
// can be stored and `capacity` must be at least 2.
//
// Returns NULL if `capacity` < 2 or if any allocation or pthread
// initialisation fails; nothing is leaked in that case. Release with bb_free.
bounded_buffer_t* bb_create(int capacity) {
    // capacity 1 could never hold an item (bb_full would always be true), and
    // capacity 0 would make the `% capacity` in push/pop divide by zero.
    if (capacity < 2) {
        return NULL;
    }

    bounded_buffer_t* bb = malloc(sizeof *bb);
    if (bb == NULL) {
        return NULL;
    }

    bb->data = malloc((size_t)capacity * sizeof *bb->data);
    // Size each primitive by its own type: the condition variables are
    // pthread_cond_t, which need not be the same size as pthread_mutex_t.
    bb->mutex = malloc(sizeof *bb->mutex);
    bb->full_cv = malloc(sizeof *bb->full_cv);
    bb->empty_cv = malloc(sizeof *bb->empty_cv);
    if (bb->data == NULL || bb->mutex == NULL || bb->full_cv == NULL ||
        bb->empty_cv == NULL) {
        goto fail_alloc;
    }

    if (pthread_mutex_init(bb->mutex, NULL) != 0) {
        goto fail_alloc;
    }
    if (pthread_cond_init(bb->full_cv, NULL) != 0) {
        goto fail_mutex;
    }
    if (pthread_cond_init(bb->empty_cv, NULL) != 0) {
        goto fail_full_cv;
    }

    bb->capacity = capacity;
    bb->head = 0;
    bb->tail = 0;
    bb->done = false;
    return bb;

    // Unwind in reverse order of acquisition; free(NULL) is a no-op.
fail_full_cv:
    pthread_cond_destroy(bb->full_cv);
fail_mutex:
    pthread_mutex_destroy(bb->mutex);
fail_alloc:
    free(bb->empty_cv);
    free(bb->full_cv);
    free(bb->mutex);
    free(bb->data);
    free(bb);
    return NULL;
}

// Release everything owned by a buffer from bb_create: the ring storage, the
// mutex and both condition variables (each destroyed before its memory is
// freed), and finally the struct itself. Per POSIX, destroying a mutex or
// condition variable that a thread still holds or waits on is undefined, so
// callers should only call this after every user thread has finished.
void bb_free(bounded_buffer_t* bb) {
    free(bb->data);

    pthread_mutex_destroy(bb->mutex);
    pthread_cond_destroy(bb->full_cv);
    pthread_cond_destroy(bb->empty_cv);

    free(bb->mutex);
    free(bb->full_cv);
    free(bb->empty_cv);

    free(bb);
}

// Number of items currently stored. When the tail has wrapped around behind
// the head, the raw index difference is negative; adding `capacity` folds it
// back into range.
int bb_size(bounded_buffer_t* bb) {
    int distance = bb->tail - bb->head;
    return distance < 0 ? distance + bb->capacity : distance;
}

// True when no further item fits.
bool bb_full(bounded_buffer_t* bb) {
    // One slot of `data` is deliberately left unused. If all `capacity` slots
    // could be filled, a full ring would have head == tail exactly like an
    // empty one, and the two states would be indistinguishable. Reserving a
    // slot is the simplest way around that ambiguity.
    int usable_slots = bb->capacity - 1;
    return !(bb_size(bb) < usable_slots);
}

// True when the buffer holds no items.
bool bb_empty(bounded_buffer_t* bb) {
    return !bb_size(bb);
}

// Store `value` at the tail slot and advance the tail, wrapping at
// `capacity`. Does no locking of its own; bb_block_push calls it with
// bb->mutex held. Asserts that the buffer is not already full.
void bb_push(bounded_buffer_t* bb, int value) {
    assert(!bb_full(bb));
    const int slot = bb->tail;
    bb->tail = (slot + 1) % bb->capacity;
    bb->data[slot] = value;
}

// Take the item at the head slot and advance the head, wrapping at
// `capacity`. Does no locking of its own; bb_block_pop calls it with
// bb->mutex held. Asserts that the buffer is not empty.
int bb_pop(bounded_buffer_t* bb) {
    assert(!bb_empty(bb));
    const int slot = bb->head;
    bb->head = (slot + 1) % bb->capacity;
    return bb->data[slot];
}

// ANCHOR: push
void bb_block_push(bounded_buffer_t* bb, int value) {
    pthread_mutex_lock(bb->mutex);
    // Re-test the predicate after every wakeup: condition waits may wake
    // spuriously, so the buffer must actually have room before pushing.
    for (;;) {
        if (!bb_full(bb)) {
            break;
        }
        pthread_cond_wait(bb->full_cv, bb->mutex);
    }
    bb_push(bb, value);
    pthread_mutex_unlock(bb->mutex);
    // One new item can satisfy at most one consumer, so waking one suffices.
    pthread_cond_signal(bb->empty_cv);
}
// ANCHOR_END: push

// ANCHOR: pop
// Remove and return the oldest item, blocking while the buffer is empty.
//
// On return *done is always written:
//   - false: an item was popped and is the return value;
//   - true:  bb_finish has marked the buffer done and it is empty; the
//            return value is 0 and carries no data.
int bb_block_pop(bounded_buffer_t* bb, bool* done) {
    pthread_mutex_lock(bb->mutex);
    while (bb_empty(bb) && !bb->done) {
        pthread_cond_wait(bb->empty_cv, bb->mutex);
    }
    // The loop exits with the buffer non-empty, or empty with `done` set.
    // Hand out any remaining item before reporting completion, so nothing
    // still queued when `done` was set gets dropped.
    bool finished = bb_empty(bb);
    int value = finished ? 0 : bb_pop(bb);
    pthread_mutex_unlock(bb->mutex);

    // Write the flag on every path so callers never read an indeterminate value.
    *done = finished;

    // full_cv is waited on for two different predicates: "not full" in
    // bb_block_push and "empty" in bb_finish. A single signal could wake a
    // waiter whose predicate is still false and lose the wakeup the other
    // waiter needed, so wake them all and let each re-check.
    pthread_cond_broadcast(bb->full_cv);
    return value;
}
// ANCHOR_END: pop

// Wait until consumers have drained every queued item, then mark the buffer
// done and wake all threads blocked in bb_block_pop so they can observe it.
// NOTE(review): nothing here stops further pushes after this returns; callers
// presumably stop producing first — confirm.
void bb_finish(bounded_buffer_t* bb) {
    pthread_mutex_lock(bb->mutex);
    for (;;) {
        if (bb_empty(bb)) {
            break;
        }
        pthread_cond_wait(bb->full_cv, bb->mutex);
    }
    bb->done = true;
    pthread_mutex_unlock(bb->mutex);
    // Every idle consumer must learn about completion, not just one.
    pthread_cond_broadcast(bb->empty_cv);
}

//////////

// Return true iff `n` is prime. Values below 2 (0, 1 and all negatives) are
// not prime.
//
// NOTE(review): trial division tries every candidate divisor below n rather
// than stopping at sqrt(n). main() times seq_primes/prodcon_primes, which use
// this as their per-number workload, so the naive loop is presumably
// deliberate — confirm before optimising, since it changes the benchmark.
bool is_prime(int n) {
    if (n < 2) {
        return false;
    }
    for (int divisor = 2; divisor < n; ++divisor) {
        if (n % divisor == 0) {
            return false;
        }
    }
    return true;
}

// Arguments for producer_thread: the half-open range [start_number,
// end_number) of integers to push, and the buffer to push them into.
typedef struct {
    int start_number;       // First value pushed (inclusive).
    int end_number;         // Pushing stops before this value (exclusive).
    bounded_buffer_t *buf;  // Buffer shared with the consumer threads.
} producer_args_t;

// Arguments for consumer_thread. prodcon_primes passes the same instance to
// every consumer thread, so `prime_count` is shared and must be updated
// under `mutex`.
typedef struct {
    bounded_buffer_t *buf;   // Buffer the consumer pops numbers from.
    int* prime_count;        // Shared counter, incremented once per prime found.
    pthread_mutex_t* mutex;  // Guards *prime_count.
} consumer_args_t;

// Thread entry point: push every integer in [start_number, end_number) into
// the shared buffer, then call bb_finish so consumers learn there is no more
// work. `arg_in` points to a producer_args_t. Always returns NULL.
void* producer_thread(void* arg_in) {
    const producer_args_t* args = arg_in;
    int next = args->start_number;
    while (next < args->end_number) {
        bb_block_push(args->buf, next);
        next++;
    }
    bb_finish(args->buf);
    return NULL;
}

// Thread entry point: pop numbers from the shared buffer until it reports
// completion, adding one to the shared prime counter (under its mutex) for
// each prime seen. `arg_in` points to a consumer_args_t. Always returns NULL.
void* consumer_thread(void* arg_in) {
    consumer_args_t* arg = (consumer_args_t*)arg_in;
    while (1) {
        // Initialised on every iteration so the check below never depends on
        // bb_block_pop writing the flag on every path.
        bool done = false;
        int number = bb_block_pop(arg->buf, &done);
        if (done) {
            break;
        }

        if (is_prime(number)) {
            pthread_mutex_lock(arg->mutex);
            (*(arg->prime_count))++;
            pthread_mutex_unlock(arg->mutex);
        }
    }
    return NULL;
}

// Single-threaded reference: count the primes in [start_number, end_number).
int seq_primes(int start_number, int end_number) {
    int count = 0;
    for (int n = start_number; n < end_number; n++) {
        count += is_prime(n);  // bool promotes to 0 or 1
    }
    return count;
}

// Elapsed time from `start` to `end`, in microseconds. The microsecond
// difference may be negative (when end.tv_usec < start.tv_usec); adding it
// to the whole-second span borrows correctly.
long time_diff(struct timeval start, struct timeval end) {
    return (end.tv_usec - start.tv_usec)
         + (end.tv_sec - start.tv_sec) * 1000000;
}

// Count the primes in [start_number, end_number) using one producer thread
// that feeds `workers` consumer threads through a 32-slot bounded buffer.
//
// Returns the number of primes found, or -1 if `workers` < 1 or if the
// buffer, the thread table or the producer thread could not be created. If
// only some consumers fail to launch, the run continues with those that did.
int prodcon_primes(int workers, int start_number, int end_number) {
    // Without consumers the producer would block forever once the buffer
    // fills (or in bb_finish), so refuse up front.
    if (workers < 1) {
        return -1;
    }

    int primes = 0;
    pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
    bounded_buffer_t* buf = bb_create(32);
    pthread_t* consumers = malloc((size_t)workers * sizeof *consumers);
    if (buf == NULL || consumers == NULL) {
        free(consumers);
        if (buf != NULL) {
            bb_free(buf);
        }
        pthread_mutex_destroy(&mutex);
        return -1;
    }

    // Launch consumers before the producer: if the producer then fails to
    // start, bb_finish() on the still-empty buffer releases them cleanly.
    consumer_args_t con_args = {buf, &primes, &mutex};
    int launched = 0;
    while (launched < workers &&
           pthread_create(&consumers[launched], NULL, consumer_thread,
                          &con_args) == 0) {
        ++launched;
    }

    // Launch 1 producer thread (only if someone is there to consume).
    bool ok = false;
    if (launched > 0) {
        producer_args_t prod_args = {start_number, end_number, buf};
        pthread_t producer;
        if (pthread_create(&producer, NULL, producer_thread, &prod_args) == 0) {
            pthread_join(producer, NULL);
            ok = true;
        }
    }
    if (!ok) {
        // No producer ran, so nothing was pushed: bb_finish sees an empty
        // buffer, sets `done`, and wakes any launched consumers.
        bb_finish(buf);
    }

    // Join only the threads that actually started, then free resources.
    for (int i = 0; i < launched; ++i) {
        pthread_join(consumers[i], NULL);
    }
    free(consumers);
    bb_free(buf);
    pthread_mutex_destroy(&mutex);

    return ok ? primes : -1;
}

// Benchmark driver: times the sequential prime count and the producer/consumer
// version with 1..15 consumer threads over the same range, printing CSV rows
// of (workers, average microseconds) and reporting any mismatched counts on
// stderr.
int main(void) {
    int start_number = 100000;
    int end_number = start_number + 7000;
    int reps = 10;

    printf("workers,us\n");

    // Run the sequential algorithm. Its result is the reference answer the
    // parallel runs are checked against; initialised defensively so it is
    // never read indeterminate.
    int primes = 0;
    long seq_total_time = 0;
    for (int rep = 0; rep < reps; ++rep) {
        struct timeval start, end;
        gettimeofday(&start, NULL);
        primes = seq_primes(start_number, end_number);
        gettimeofday(&end, NULL);
        seq_total_time += time_diff(start, end);
    }
    printf("sequential,%ld\n", seq_total_time / reps);

    // Run the parallel algorithm, and check the answer.
    for (int threads = 1; threads < 16; ++threads) {
        // Separate name so it does not shadow the sequential total above.
        long par_total_time = 0;
        for (int rep = 0; rep < reps; ++rep) {
            struct timeval start, end;
            gettimeofday(&start, NULL);
            int p_primes = prodcon_primes(threads, start_number, end_number);
            gettimeofday(&end, NULL);
            par_total_time += time_diff(start, end);
            if (primes != p_primes) {
                fprintf(
                    stderr,
                    "%d threads: got %d primes vs. %d from sequential\n",
                    threads, p_primes, primes
                );
            }
        }
        printf("%d,%ld\n", threads, par_total_time / reps);
    }

    return 0;
}
