#include "sokoban.hpp"

#include <algorithm>
#include <cstdint>

// Builds a board of num_rows x num_cols cells from a mask stored row by row
// (cell (r, c) lives at index r * num_cols + c). A `false` entry marks a cell
// that the player and boxes cannot occupy.
sokoban::sokoban_board::sokoban_board(const uint32_t num_rows,
                                      const uint32_t num_cols,
                                      const std::vector<bool>& texture)
        : num_rows(num_rows), num_cols(num_cols), texture(texture) {}

// Returns a copy of this board in which every cell index listed in
// `unreachable` is marked as blocked (false). The original board is untouched.
sokoban::sokoban_board sokoban::sokoban_board::clone(
        const std::vector<uint32_t>& unreachable) const {
    std::vector<bool> mask = this->texture;
    std::for_each(unreachable.begin(), unreachable.end(),
                  [&mask](const uint32_t cell) { mask[cell] = false; });
    return sokoban_board(num_rows, num_cols, mask);
}

// Returns whether cell (row_num, col_num) is open. No bounds checking is done.
bool sokoban::sokoban_board::operator()(
        const uint32_t row_num, const uint32_t col_num) const {
    // Cells are laid out row by row.
    const auto index = row_num * num_cols + col_num;
    return texture[index];
}

// Total number of cells on the board (rows times columns).
uint32_t sokoban::sokoban_board::size() const {
    return num_cols * num_rows;
}

// Creates a state on `board` with the given box cells, player cell and move
// history. The box list is sorted so that two states holding the same boxes
// in a different order compare and match identically.
sokoban::sokoban_state::sokoban_state(const sokoban_board& board,
                                      const std::vector<uint32_t>& box_positions,
                                      const uint32_t player_position,
                                      const std::list<move_t>& trace)
        : board(board), box_positions(box_positions),
          player_position(player_position), trace(trace) {
    // The parameter shadows the member, so refer to the member explicitly.
    auto& boxes = this->box_positions;
    std::sort(boxes.begin(), boxes.end());
}

// Member-wise copy of `other`. NOTE(review): if `board` is declared as a
// reference member in the header, the copy refers to the same board object;
// confirm against sokoban.hpp.
sokoban::sokoban_state::sokoban_state(const sokoban_state& other)
        : board(other.board), box_positions(other.box_positions),
          player_position(other.player_position), trace(other.trace) {}

// Strict weak ordering used for the visited set. It orders first by player
// cell, then lexicographically by the (sorted) box cells. The trace is ignored.
bool sokoban::sokoban_state::operator<(
        const sokoban_state& other) const {
    if(player_position != other.player_position) {
        return player_position < other.player_position;
    }
    return std::lexicographical_compare(box_positions.begin(),
                                        box_positions.end(),
                                        other.box_positions.begin(),
                                        other.box_positions.end());
}

// Returns the state reached by applying `move` to this one.
// `move.first` indexes this state's sorted box list, and `move.second` is the
// cell that box is pushed into. Afterwards the player stands on the cell the
// box vacated, and (vacated cell, destination cell) is appended to the trace.
sokoban::sokoban_state sokoban::sokoban_state::make_move(
        const move_t& move) const {
    const auto box_index = move.first;
    const auto destination = move.second;
    const uint32_t origin = box_positions[box_index];

    sokoban_state next(*this);
    next.box_positions[box_index] = destination;
    next.player_position = origin;
    next.trace.push_back({origin, destination});

    // Keep the canonical sorted order used for comparisons.
    std::sort(next.box_positions.begin(), next.box_positions.end());
    return next;
}

std::vector<sokoban::move_t> sokoban::sokoban_state::find_all_moves() const {
    sokoban_board board(this->board.clone(box_positions));
    std::vector<bool> reachability_mask(board.size(), false);
    std::queue<uint32_t> positions;

    positions.push(player_position);

    while(!positions.empty()) {
        // int32_t to guard against underflow
        const int32_t pos = (int32_t)positions.front();
        positions.pop();
        reachability_mask[pos] = true;

        // down
        {
            const int32_t new_pos = pos + board.num_cols;
            if(new_pos < (int32_t)board.size()) {
                if(!reachability_mask[new_pos]
                   && board.texture[new_pos]) {
                    reachability_mask[new_pos] = true;
                    positions.push((uint32_t)new_pos);
                }
            }
        }

        // up
        {
            // int because of underflow
            const int32_t new_pos = pos - board.num_cols;
            if(new_pos >= 0) {
                if(!reachability_mask[new_pos]
                   && board.texture[new_pos]) {
                    reachability_mask[new_pos] = true;
                    positions.push((uint32_t)new_pos);
                }
            }
        }

        // right
        {
            const int32_t new_pos = pos + 1;
            if(new_pos < (int32_t)board.size()) {
                if(!reachability_mask[new_pos]
                   && board.texture[new_pos]) {
                    reachability_mask[new_pos] = true;
                    positions.push((uint32_t)new_pos);
                }
            }
        }

        // left
        {
            // int because of underflow
            const int32_t new_pos = pos - 1;
            if(new_pos >= 0) {
                if(!reachability_mask[new_pos]
                   && board.texture[new_pos]) {
                    reachability_mask[new_pos] = true;
                    positions.push((uint32_t)new_pos);
                }
            }
        }
    }

    std::vector<move_t> moves;
    for(uint32_t i = 0; i < box_positions.size(); ++i) {
        const int32_t pos = (int32_t)box_positions[i];
        // down or up
        {
            const int32_t pos_up = pos - board.num_cols;
            const int32_t pos_down = pos + board.num_cols;

            if(pos_up >= 0 && pos_down < (int32_t)board.size()) {
                if(reachability_mask[pos_up]
                   && board.texture[pos_down]) {
                    moves.push_back({i, (uint32_t)pos_down});
                }
                if(reachability_mask[pos_down]
                   && board.texture[pos_up]) {
                    moves.push_back({i, (uint32_t)pos_up});
                }
            }
        }

        // right or left
        {
            const int32_t pos_left = pos - 1;
            const int32_t pos_right = pos + 1;

            if(pos_right >= 0 && pos_left < (int32_t)board.size()) {
                if(reachability_mask[pos_left]
                   && board.texture[pos_right]) {
                    moves.push_back({i, (uint32_t)pos_right});
                }
                if(reachability_mask[pos_right]
                   && board.texture[pos_left]) {
                    moves.push_back({i, (uint32_t)pos_left});
                }
            }
        }
    }

    return moves;
}

// True when the boxes occupy exactly the target cells. The box list is kept
// sorted by this class, so `targets` must be sorted too for the element-wise
// comparison to work (sokoban_solver sorts its targets before calling this).
bool sokoban::sokoban_state::is_final_state(
        const std::vector<uint32_t>& targets) const {
    return targets == box_positions;
}

// Returns true when some box that is not on a target sits in a corner formed
// by walls or the board edge. Such a box can never be pushed again, so the
// state cannot lead to a solution. Only this simple corner pattern is
// detected; boxes are not treated as blocking.
bool sokoban::sokoban_state::is_dead(
        const std::vector<uint32_t>& targets) const {
    const int64_t cols = board.num_cols;
    const int64_t cells = board.size();
    if(cols == 0) {
        // Empty board: there is no cell a box could be stuck in.
        return false;
    }
    const int64_t rows = cells / cols;

    // Cell (row, col) blocks a box if it is a wall or lies outside the board.
    // Checking row and column separately prevents horizontal wrap-around.
    const auto blocked = [&](const int64_t row, const int64_t col) -> bool {
        if(row < 0 || row >= rows || col < 0 || col >= cols) {
            return true;
        }
        return !board.texture[row * cols + col];
    };

    for(const uint32_t pos : box_positions) {
        // A box resting on a target is fine even when it is cornered.
        if(std::find(targets.begin(), targets.end(), pos) != targets.end()) {
            continue;
        }
        const int64_t row = pos / cols;
        const int64_t col = pos % cols;

        const bool blocked_vertically
                = blocked(row - 1, col) || blocked(row + 1, col);
        const bool blocked_horizontally
                = blocked(row, col - 1) || blocked(row, col + 1);

        if(blocked_vertically && blocked_horizontally) {
            return true;
        }
    }

    return false;
}

// Returns an independent copy of the board this state was built on, made by
// calling clone() with no explicit unreachable cells.
sokoban::sokoban_board sokoban::sokoban_state::copy_board() const {
    return this->board.clone();
}

// Returns a state with this state's boxes, player cell and trace, but built on
// the given `board` (the parameter, not this state's own board).
sokoban::sokoban_state sokoban::sokoban_state::reinitialize(
        const sokoban_board& board) const {
    return sokoban_state(board,
                         this->box_positions,
                         this->player_position,
                         this->trace);
}

std::list<sokoban::move_t> sokoban::sokoban_state::get_trace() {
    return trace;
}

namespace {
// Returns `values` sorted in ascending order. The argument is taken by value,
// so the caller's vector is left untouched.
std::vector<uint32_t> sorted_copy(std::vector<uint32_t> values) {
    std::sort(values.begin(), values.end());
    return values;
}
} // namespace

// Sets up a solver for `initial_state`. The solver keeps its own copy of the
// board, and the queued initial state is rebuilt on that copy. Targets are
// stored sorted so they can be compared directly with a state's sorted boxes.
// NOTE(review): queued states appear to refer to the solver's board member;
// copying or moving the solver may leave them pointing at the old board.
sokoban::sokoban_solver::sokoban_solver(
        const sokoban_state& initial_state,
        const std::vector<uint32_t>& targets)
        : board(initial_state.copy_board()),
          targets(sorted_copy(targets)) {
    states_to_explore.push(initial_state.reinitialize(board));
}

// Searches for a sequence of pushes that puts every box on a target.
// States are explored in the order they were enqueued (breadth-first, given
// states_to_explore's front/pop/push queue interface). States with a cornered
// box are discarded, and each distinct state is enqueued at most once.
// Returns the trace of the first solved state found, or std::nullopt when the
// reachable state space is exhausted.
std::optional<std::list<sokoban::move_t>> sokoban::sokoban_solver::solve() {
    // The constructor enqueues the initial state without checking it or
    // marking it visited. If it is already solved, return its trace right away
    // instead of a detour that pushes a box off a target and back. Otherwise
    // record it so it is never queued a second time.
    if(!states_to_explore.empty()) {
        sokoban_state& initial = states_to_explore.front();
        if(initial.is_final_state(targets)) {
            return initial.get_trace();
        }
        if(visited.find(initial) == visited.end()) {
            visited.insert(initial);
        }
    }

    while(!states_to_explore.empty()) {
        sokoban_state cur_state(states_to_explore.front());
        states_to_explore.pop();

        for(const move_t& move : cur_state.find_all_moves()) {
            sokoban_state new_state(cur_state.make_move(move));
            // Prune states that can never be solved.
            if(new_state.is_dead(targets)) {
                continue;
            }
            if(new_state.is_final_state(targets)) {
                return new_state.get_trace();
            }
            if(visited.find(new_state) == visited.end()) {
                states_to_explore.push(new_state);
                visited.insert(new_state);
            }
        }
    }

    return {};
}
