package cs2110;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class DijkstraDemo {

    /**
     * Auxiliary vertex properties relevant to weighted shortest paths: the best known `distance`
     * of this vertex from the source `s` (i.e., the total weight of its shortest known `s -> v`
     * path), and the label of the `prev` vertex on this shortest known `s -> v` path. `prev` is
     * `null` for the source vertex itself.
     */
    public record PathInfo(double distance, String prev) {
        @Override
        public String toString() {
            return "{distance = " + distance + ", prev = " + prev + "}";
        }
    }

    /**
     * An implementation of Dijkstra's shortest path algorithm for a directed graph with non-negative
     * edge weights. Returns a map associating each vertex `v` in this graph reachable from the given
     * `source` vertex with a `PathInfo` object containing the shortest `source -> v` path `distance`
     * and the `prev` vertex in this shortest path. Entries appear in the order in which the
     * vertices were first discovered.
     *
     * @param source the vertex from which all shortest paths start
     * @return a map from the label of each vertex reachable from `source` to its shortest-path info
     * @throws IllegalArgumentException if an edge reachable from `source` has a negative or NaN
     *     weight, for which Dijkstra's algorithm would otherwise silently compute wrong distances
     */
    public static <V extends Vertex<? extends WeightedEdge<V>>> Map<String, PathInfo> dijkstra(V source) {

        // Queue of discovered vertices that have not yet been visited, prioritized by their best
        // known distance from `source`
        DemoDynamicPriorityQueue<V> frontier = new DemoDynamicPriorityQueue<>();

        // Map associating PathInfo to all discovered vertices
        Map<String, PathInfo> discovered = new LinkedHashMap<>();

        discovered.put(source.label(), new PathInfo(0, null)); // s does not have a "prev" vertex
        frontier.add(source, 0); // s has distance 0 from itself

        while (!frontier.isEmpty()) {
            V v = frontier.remove(); // v is frontier element that is closest to s
            // v's distance is final once it leaves the frontier; look it up once, not per edge
            double vDistance = discovered.get(v.label()).distance();

            for (WeightedEdge<V> edge : v.outgoingEdges()) { // iterate over v's neighbors
                V neighbor = edge.head();
                double weight = edge.weight();
                // `!(weight >= 0)` rejects negative weights AND NaN (all NaN comparisons are false);
                // either would make the distances computed below incorrect.
                if (!(weight >= 0)) {
                    throw new IllegalArgumentException("Edge weights must be non-negative, but edge \""
                            + v.label() + "\" -> \"" + neighbor.label() + "\" has weight " + weight + ".");
                }
                double dist = vDistance + weight; // neighbor's shortest-path distance via v
                PathInfo known = discovered.get(neighbor.label());
                if (known == null) { // neighbor is first discovered
                    discovered.put(neighbor.label(), new PathInfo(dist, v.label())); // add to discovered map
                    frontier.add(neighbor, dist); // add to frontier priority queue
                } else if (known.distance() > dist) { // we found a shorter path to neighbor
                    // With non-negative weights an already-visited vertex can never be improved,
                    // so `neighbor` is necessarily still in the frontier here.
                    discovered.put(neighbor.label(), new PathInfo(dist, v.label())); // update discovered map
                    frontier.updatePriority(neighbor, dist); // update priority to reflect this new distance
                }
            }
        }

        return discovered;
    }

    /**
     * Reconstructs and returns the shortest path from the vertex with label `srcLabel` to the
     * vertex with label `dstLabel` using the given `info` map produced by Dijkstra's Algorithm
     * (run from `srcLabel`). The returned list starts with `srcLabel` and ends with `dstLabel`.
     * Returns an empty list if `dstLabel` is unreachable (i.e., absent from `info`).
     *
     * @throws IllegalArgumentException if following the `prev` links from `dstLabel` ends without
     *     reaching `srcLabel`, meaning `info` was not produced by a search from `srcLabel`
     */
    public static List<String> reconstructPath(Map<String,PathInfo> info, String srcLabel, String dstLabel) {
        LinkedList<String> path = new LinkedList<>();
        if (!dstLabel.equals(srcLabel) && !info.containsKey(dstLabel)) {
            return path; // dstLabel was never discovered, so no path to it exists
        }
        path.add(dstLabel);
        // Walk the `prev` links backwards from dstLabel, prepending each vertex, until srcLabel
        while (!path.getFirst().equals(srcLabel)) {
            PathInfo step = info.get(path.getFirst());
            if (step == null || step.prev() == null) {
                throw new IllegalArgumentException("\"" + srcLabel
                        + "\" does not lie on the recorded shortest path to \"" + dstLabel + "\".");
            }
            path.addFirst(step.prev());
        }
        return path;
    }

    /**
     * An edge in a directed graph labeled with a non-negative double weight.
     */
    public record DWeightedEdge(AdjListVertex<DWeightedEdge> tail, AdjListVertex<DWeightedEdge> head, double weight)
            implements WeightedEdge<AdjListVertex<DWeightedEdge>> {
    }

    /**
     * Constructs a new weighted edge from the vertex with label `tailLabel` to the vertex with
     * label `headLabel` in `graph` with the given `weight`. The edge is only constructed, not added
     * to `graph`. Throws an `IllegalArgumentException` if one of these endpoint vertices does not
     * exist in the graph, if there is already an edge between them, or if `weight` is negative or
     * NaN.
     */
    private static DWeightedEdge makeEdge(AdjListGraph<DWeightedEdge> graph, String tailLabel,
            String headLabel, double weight) {
        if (!graph.hasVertex(tailLabel)) {
            throw new IllegalArgumentException("No vertex labeled \"" + tailLabel + "\" in this graph.");
        }
        if (!graph.hasVertex(headLabel)) {
            throw new IllegalArgumentException("No vertex labeled \"" + headLabel + "\" in this graph.");
        }
        if (graph.hasEdge(tailLabel, headLabel)) {
            throw new IllegalArgumentException("Graph already has edge from \"" + tailLabel
                    + "\" to \"" + headLabel + "\".");
        }
        // `weight < 0` alone is false for NaN, which would poison every distance computed through it
        if (Double.isNaN(weight) || weight < 0) {
            throw new IllegalArgumentException("Edge weights must be non-negative");
        }
        return new DWeightedEdge(graph.getVertex(tailLabel), graph.getVertex(headLabel), weight);
    }

    /**
     * Returns (does not print) the sequence of vertices in the given `path`, separated by " -> ".
     * An empty `path` yields the empty string.
     */
    public static String printPath(List<String> path) {
        StringBuilder sb = new StringBuilder();
        Iterator<String> it = path.iterator();
        while (it.hasNext()) {
            sb.append(it.next());
            if (it.hasNext()) {
                sb.append(" -> ");
            }
        }
        return sb.toString();
    }

    /**
     * Builds a small sample graph, runs Dijkstra's algorithm from "s", and prints each reachable
     * vertex's shortest-path info along with the reconstructed path.
     */
    public static void main(String[] args) {
        AdjListGraph<DWeightedEdge> graph = new AdjListGraph<>();

        graph.addVertex("s");
        graph.addVertex("a");
        graph.addVertex("b");
        graph.addVertex("c");
        graph.addVertex("d");
        graph.addVertex("t");
        graph.addEdge(makeEdge(graph, "s", "a", 4));
        graph.addEdge(makeEdge(graph, "s", "c", 5));
        graph.addEdge(makeEdge(graph, "a", "b", 5));
        graph.addEdge(makeEdge(graph, "a", "d", 4));
        graph.addEdge(makeEdge(graph, "b", "t", 5));
        graph.addEdge(makeEdge(graph, "c", "a", 3));
        graph.addEdge(makeEdge(graph, "c", "b", 2));
        graph.addEdge(makeEdge(graph, "c", "d", 1));
        graph.addEdge(makeEdge(graph, "d", "b", 2));
        graph.addEdge(makeEdge(graph, "d", "t", 7));

        Map<String, PathInfo> map = dijkstra(graph.getVertex("s"));

        // Iterate entries directly instead of re-looking each key up in the map
        for (Map.Entry<String, PathInfo> entry : map.entrySet()) {
            String vLabel = entry.getKey();
            System.out.println(vLabel + ": " + entry.getValue());
            System.out.println("Shortest Path: " + printPath(reconstructPath(map, "s", vLabel)));
        }
    }
}