package cs2110;

import cs2110.BipartiteGraph.Vertex;
import java.util.HashSet;
import java.util.LinkedList;

/**
 * Computes a maximum cardinality matching in a bipartite graph using augmenting
 * paths: for each left vertex, a depth-first search looks for an alternating
 * path ending at an unmatched right vertex, and the matching is flipped along
 * that path. All members are static; the class holds no state of its own.
 */
public class MaxBipartiteMatching {

    /**
     * Models an edge oriented from `v1` to `v2`.
     */
    record Edge(Vertex v1, Vertex v2) {

        /**
         * Returns the edge oriented from `v2` to `v1`.
         */
        public Edge reverse() {
            return new Edge(v2, v1);
        }

        /**
         * Returns this edge formatted as "(v1Label,v2Label)" using the
         * vertices' labels.
         */
        @Override
        public String toString() {
            return "(" + v1.label() + "," + v2.label() + ")";
        }
    }

    /**
     * A list of `Edge`s that form a contiguous path in a graph. The `v2` of one
     * `Edge` must be equal to the `v1` of the next `Edge`.
     */
    private static class Path extends LinkedList<Edge> {

    }

    /**
     * A set of bipartite graph `Edge`s oriented from their left vertex to their
     * right vertex.
     */
    public static class Matching extends HashSet<Edge> {

    }

    /**
     * Returns a maximum cardinality matching in the given bipartite `graph`.
     * Left vertices are taken from `graph.leftVertices()`; an empty result is
     * returned when there are none.
     */
    public static Matching maxMatching(BipartiteGraph graph) {
        Matching matching = new Matching();

        // Each left vertex is tried exactly once. This single pass relies on the
        // standard property of Kuhn's algorithm: if no augmenting path starts at
        // a vertex now, later augmentations cannot create one.
        for (Vertex v : graph.leftVertices()) {
            Path augmentingPath = findAugmentingPath(graph, v, matching);
            if (augmentingPath != null) {
                augment(matching, augmentingPath);
            }
        }
        return matching;
    }

    /**
     * Updates the given `matching` using the given `augmentingPath`.
     *
     * The path alternates between left-to-right edges that are not yet in the
     * matching (positions 0, 2, 4, ...) and right-to-left edges whose reverse is
     * currently matched (positions 1, 3, ...). The former are added and the
     * latter (reversed) removed, growing the matching by one edge.
     */
    private static void augment(Matching matching, Path augmentingPath) {
        // Iterate rather than index: `Path` is a LinkedList, so get(i) would
        // traverse the list on every call and make this loop quadratic.
        boolean addingEdge = true;
        for (Edge edge : augmentingPath) {
            if (addingEdge) {
                assert !matching.contains(edge);
                matching.add(edge);
            } else {
                assert matching.contains(edge.reverse());
                matching.remove(edge.reverse());
            }
            addingEdge = !addingEdge;
        }
    }

    /**
     * Returns an augmenting path in the given `graph` with the given source
     * vertex `v` and current `matching`, or returns `null` if no such
     * augmenting path exists. The `graph` argument is not consulted by the
     * current implementation; neighbors are obtained from `Vertex.neighbors()`.
     */
    private static Path findAugmentingPath(BipartiteGraph graph, Vertex v, Matching matching) {
        return dfsRecursive(v, matching, new HashSet<>());
    }

    /**
     * Returns a path from the left vertex `current` to a right vertex that is
     * not covered by the given `matching`, exploring only right vertices that
     * are not already in `discovered`. Returns `null` if no such path exists.
     *
     * Every matched right vertex that is explored is added to `discovered`
     * and never removed, so a failed branch is not retried. The returned path
     * alternates left-to-right (unmatched) and right-to-left (matched, reversed)
     * edges. Recursion depth grows by one per matched edge followed.
     */
    private static Path dfsRecursive(Vertex current, Matching matching, HashSet<Vertex> discovered) {
        for (Vertex neighbor : current.neighbors()) {
            if (discovered.contains(neighbor)) {
                continue; // already explored this right vertex earlier in the traversal
            }
            Vertex nextLeftVertex = matchOfRightVertex(neighbor, matching);
            if (nextLeftVertex == null) {
                // we've located an unmatched right vertex, ending the augmenting path
                Path path = new Path();
                path.add(new Edge(current, neighbor));
                return path;
            }
            // we've located a matched right vertex, so continue the search from its partner
            discovered.add(neighbor);
            Path path = dfsRecursive(nextLeftVertex, matching, discovered);
            if (path != null) { // found a path from `nextLeftVertex`
                path.addFirst(new Edge(neighbor, nextLeftVertex));
                path.addFirst(new Edge(current, neighbor));
                return path;
            }
        }
        return null;
    }

    /**
     * If `v` is the right vertex in some `Edge` in the given `matching`, then
     * the left vertex of this `Edge` is returned. Otherwise, `null` is
     * returned. Performs a linear scan over `matching`.
     */
    private static Vertex matchOfRightVertex(Vertex v, Matching matching) {
        for (Edge e : matching) {
            // Use equals() so this agrees with the HashSet-based membership
            // checks used elsewhere (Edge record equality, `discovered`).
            if (e.v2().equals(v)) {
                return e.v1();
            }
        }
        return null;
    }
}
