#include "bignum.hpp"

#include <algorithm>
#include <cctype>
#include <iostream>

#include "exceptions.hpp"

// Symbols of the arithmetic operators this file implements
// (operator+, operator-, operator*).
const std::vector<char> bignum::Bignum::operators{'+', '-', '*'};

/**
 * Builds a Bignum from its decimal text form.
 *
 * @param strnum decimal digits, most significant first; "0" is the only
 *               accepted string that starts with a zero.
 * @throws bignum::parse_error if strnum is empty, has a leading zero, or
 *         contains a non-digit character (raised by parse()).
 */
bignum::Bignum::Bignum(const std::string& strnum)
        : digits(parse(strnum)) {}

/**
 * Returns the sum of *this and other.
 *
 * Digits are stored least significant first, so columns are added from
 * index 0 upward. Each column sum is at most 9 + 9 + 1 = 19, so the carry
 * passed to the next column is always 0 or 1.
 */
bignum::Bignum bignum::Bignum::operator+(
        const Bignum& other) const {
    const uint32_t longest = std::max(num_digits(), other.num_digits());
    // One spare slot for a final carry; trimmed below if unused.
    Bignum sum(longest + 1);

    uint8_t carry = 0;
    for(uint32_t pos = 0; pos < longest; ++pos) {
        // A missing digit in the shorter operand counts as 0.
        const uint8_t lhs = pos < num_digits() ? digits[pos] : 0;
        const uint8_t rhs = pos < other.num_digits() ? other[pos] : 0;
        const uint8_t column = lhs + rhs + carry;
        sum[pos] = column % 10;
        carry = column / 10;
    }

    if(carry != 0) {
        sum[longest] = carry;
    } else {
        // Drop the unused spare slot so no zero sits in the top position.
        sum.digits.resize(longest);
    }

    return sum;
}

/**
 * Returns *this - other.
 *
 * @throws bignum::negative_number_error if other is greater than *this,
 *         because the result would be negative.
 *
 * Subtraction proceeds from the least significant digit with a borrow
 * of 0 or 1. The result is purged of most-significant zeros.
 */
bignum::Bignum bignum::Bignum::operator-(
        const Bignum& other) const {
    if(*this < other) {
        throw bignum::negative_number_error();
    }

    Bignum diff(num_digits());

    int borrow = 0;
    for(uint32_t pos = 0; pos < num_digits(); ++pos) {
        // Past the end of the (shorter) subtrahend only the borrow remains.
        const int subtrahend
                = (pos < other.num_digits() ? other[pos] : 0) + borrow;
        int column = static_cast<int>(digits[pos]) - subtrahend;
        if(column < 0) {
            column += 10;
            borrow = 1;
        } else {
            borrow = 0;
        }
        diff[pos] = static_cast<uint8_t>(column);
    }

    diff.purge();
    return diff;
}

/**
 * Returns the product of *this and other using schoolbook long
 * multiplication on the least-significant-first digit vectors.
 *
 * The outer loop runs over the smaller operand. Each partial cell is at
 * most 9 + 9*9 + 9 = 99, so it fits comfortably in a digit slot before
 * the modulo is applied.
 */
bignum::Bignum bignum::Bignum::operator*(const Bignum& other) const {
    // An n-digit by m-digit product has at most n + m digits.
    Bignum prod(num_digits() + other.num_digits());

    const bool this_is_smaller = *this < other;
    const Bignum& outer = this_is_smaller ? *this : other;
    const Bignum& inner = this_is_smaller ? other : *this;

    for(uint32_t a = 0; a < outer.num_digits(); ++a) {
        uint32_t carry = 0;
        uint32_t pos = a;
        for(uint32_t b = 0; b < inner.num_digits(); ++b, ++pos) {
            const uint32_t cell = prod[pos] + outer[a] * inner[b] + carry;
            prod[pos] = static_cast<uint8_t>(cell % 10);
            carry = cell / 10;
        }

        // Flush any remaining carry into the higher positions.
        while(carry != 0) {
            const uint32_t cell = prod[pos] + carry;
            prod[pos] = static_cast<uint8_t>(cell % 10);
            carry = cell / 10;
            ++pos;
        }
    }

    prod.purge();
    return prod;
}

/**
 * Numeric less-than comparison.
 *
 * Assumes both operands are canonical (no zero in the most significant
 * position), so a longer digit vector always means a larger value.
 */
bool bignum::Bignum::operator<(const Bignum& other) const {
    const uint32_t lhs_len = num_digits();
    const uint32_t rhs_len = other.num_digits();
    if(lhs_len != rhs_len) {
        return lhs_len < rhs_len;
    }

    // Equal length: the first differing digit from the top decides.
    for(uint32_t pos = lhs_len; pos > 0; --pos) {
        const uint8_t lhs = digits[pos - 1];
        const uint8_t rhs = other.digits[pos - 1];
        if(lhs != rhs) {
            return lhs < rhs;
        }
    }
    return false;
}

/// Two Bignums are equal when their digit vectors match element for element.
bool bignum::Bignum::operator==(const Bignum& other) const {
    return num_digits() == other.num_digits()
           && std::equal(digits.begin(), digits.end(), other.digits.begin());
}

/// Mutable reference to the digit at idx (index 0 is the least significant
/// digit). No bounds checking is performed.
uint8_t& bignum::Bignum::operator[](const uint32_t idx) {
    uint8_t& digit = digits[idx];
    return digit;
}

/// Value of the digit at idx (index 0 is the least significant digit).
/// No bounds checking is performed.
uint8_t bignum::Bignum::operator[](const uint32_t idx) const {
    const uint8_t digit = digits[idx];
    return digit;
}

/// Number of stored digits; the value zero is stored with no digits.
uint32_t bignum::Bignum::num_digits() const {
    // Narrowing from size_t is intentional: the interface uses uint32_t.
    return static_cast<uint32_t>(digits.size());
}

/**
 * Renders the number as decimal text, most significant digit first.
 * The empty digit vector (zero) renders as "0".
 */
std::string bignum::Bignum::to_string() const {
    if(digits.empty()) {
        return "0";
    }

    std::string text;
    text.reserve(digits.size());
    // Stored least significant first, so emit in reverse.
    for(auto it = digits.rbegin(); it != digits.rend(); ++it) {
        text.push_back(static_cast<char>('0' + *it));
    }
    return text;
}

bignum::Bignum::Bignum(const uint32_t num_digits)
        : digits(num_digits) {
}

/**
 * Converts decimal text into a digit vector stored least significant first.
 *
 * @param strnum decimal digits, most significant first.
 * @return the digits of strnum in reverse order; the empty vector for "0".
 * @throws bignum::parse_error if strnum is empty, starts with '0' (other
 *         than the single string "0"), or contains a non-digit character.
 */
std::vector<uint8_t> bignum::Bignum::parse(
        const std::string& strnum) const {
    if(strnum.empty()) {
        throw bignum::parse_error(strnum);
    }
    if(strnum[0] == '0') {
        // "0" maps to the empty vector. Any other leading zero is rejected
        // to keep the representation canonical, because operator< and
        // operator== compare digit vectors directly.
        if(strnum.size() == 1) {
            return {};
        }
        throw bignum::parse_error(strnum);
    }

    std::vector<uint8_t> parsed(strnum.size());
    std::transform(strnum.begin(), strnum.end(), parsed.rbegin(),
                   [&strnum](const char ch) -> uint8_t {
                       // std::isdigit is undefined for negative arguments,
                       // which a plain char holds for non-ASCII bytes where
                       // char is signed, so go through unsigned char.
                       if(!std::isdigit(static_cast<unsigned char>(ch))) {
                           throw bignum::parse_error(strnum);
                       }
                       return static_cast<uint8_t>(ch - '0');
                   });

    return parsed;
}

/**
 * Removes zero digits from the most significant end so the stored form is
 * canonical. A value of zero ends up as the empty vector.
 */
void bignum::Bignum::purge() {
    // Digits are stored least significant first, so leading zeros sit at
    // the back. Popping them needs no signed index, which avoids the
    // implementation-defined size_t -> int32_t conversion on empty input.
    while(!digits.empty() && digits.back() == 0) {
        digits.pop_back();
    }
}
