
structure Unify : UNIFY = struct

(* Here unification and substitution are one-sided only.
 * In unification, variables or constants may occur
 * in the left term, but the right is always ground. *)

(* Published theorems contain only (universally quantified)
 * variables.  Tasks on the other hand are substitution
 * instances of these and are always ground. *)

  open Util
  open Term


  (* Render a substitution as "[x=t y=u]": each binding is "x=" followed by
   * toString of the term; mapAllButLast (from Util) presumably appends the
   * " " separator to every binding except the last. *)
  fun substToString (s:substitution) : string =
    "[" ^ concat (mapAllButLast (fn x => x ^ " ") (map (fn(x,t) => x ^ "=" ^ toString t) s)) ^ "]"

  (* Render a substitution without brackets as "x = t,y = u";
   * the "," separator is appended via mapAllButLast as above. *)
  fun substToStringWC (s:substitution) : string =
    concat (mapAllButLast (fn x => x ^ ",") (map (fn(x,t) => x ^ " = " ^ toString t) s))

  (* Same layout as substToString, but each bound term is rendered with
   * termStructure instead of toString. *)
  fun substStructure (s:substitution) : string =
    "[" ^ concat (mapAllButLast (fn x => x ^ " ") (map (fn(x,t) => x ^ "=" ^ termStructure t) s)) ^ "]"

  (* Print a substitution: "[ ", then "x=t " per binding through the
   * !print hook, then "]" through the !printtext hook. *)
  fun printSubst(s:substitution) =
    ((!print) "[ ";
     app (fn(x,t) => ((!print) (x ^ "=" ^ toString t ^ " "))) s;
     (!printtext) "]")


  (* true iff a TST_VAR or ACT_VAR named x occurs anywhere in t *)
  fun occurs (x:id) (t:term) : bool =
    case t of
      (TST_VAR y | ACT_VAR y) => x = y
    | (TST_CON _ | ACT_CON _ | ZERO | ONE) => false
    | (PLUS u | TIMES u) => List.exists (occurs x) u
    | (NOT u | STAR u) => occurs x u

  (* Replace every variable (test or action) named x in t by s, then
   * normalise the result with flatten. *)
  fun subst (x:id, s:term) (t:term) : term = let
    fun subst' (t:term) : term =
      case t of
        (TST_VAR y | ACT_VAR y) => if x = y then s else t
      | PLUS u => PLUS (map subst' u)
      | TIMES u => TIMES (map subst' u)
      | NOT u => NOT (subst' u)
      | STAR u => STAR (subst' u)
      | _ => t
  in
    flatten(subst' t)
  end

  (* Apply all bindings of s simultaneously: each variable bound in s is
   * replaced by its term, and the inserted terms are not themselves
   * substituted again.  The result is normalised with flatten. *)
  fun applyInParallel (s:substitution) (t:term) : term = let
    fun subst' (t:term) : term =
      case t of
        (TST_VAR y | ACT_VAR y) =>
          (case lookup y s of
            SOME u => u
          | NONE => t)
      | PLUS u => PLUS (map subst' u)
      | TIMES u => TIMES (map subst' u)
      | NOT u => NOT (subst' u)
      | STAR u => STAR (subst' u)
      | _ => t
  in
    flatten(subst' t)
  end

  (* Compose substitutions: apply t (in parallel) to every right-hand side
   * of s.  Bindings of t for variables not bound in s are NOT added. *)
  fun compose (s:substitution) (t:substitution) : substitution =
    map (fn (x,u) => (x, applyInParallel t u)) s

  (* Apply the bindings of s sequentially, last binding first:
   * apply [r1,r2,r3] t = subst r1 (subst r2 (subst r3 t)). *)
  fun apply (s:substitution) (t:term) : term =
    foldr (fn(r,u) => subst r u) t s

  (* Split a list of length m into n nonempty contiguous sublists
   * (order preserved), n <= m, in all possible ways.  Returns [] when
   * n = 0 or n > m.  Example:
   *   associate([1,2,3],2) = [[[1],[2,3]], [[1,2],[3]]] *)
  fun associate (s:'a list, n:int) : 'a list list list =
    if n = 0 then [] else
    if List.length s < n then [] else
    if n = 1 then [[s]] else
    case s of
      [] => []
    | (x::t) => let
        (* x forms a sublist of its own ... *)
        val t1 = associate(t,n-1)
        (* ... or x joins the first sublist of a split of the tail *)
        val t2 = associate(t,n)
        val s1 = map (fn u => [x]::u) t1
        val s2 = map (fn u => (x::hd u) :: tl u) t2
    in
      s1 @ s2
    end

  (* unify a pair *)
  (* Returns all unifiers found (duplicates removed, each sorted by key);
   * there is no most general unifier because of associativity of PLUS
   * and TIMES.  [] means s and t do not unify. *)
   fun unify' (s:term, t:term) : substitution list =
    case (s,t) of
      ((ZERO,ZERO) | (ONE,ONE)) => [[]]
    | ((PLUS sc,PLUS tc) | (TIMES sc,TIMES tc)) =>
        let
          (* rebuilds a group of right-hand operands under s's operator,
           * so the group can be matched against one operand of s *)
          val rebuild = headsymbol s
          (* every way to split tc into one contiguous group per operand of s *)
          val groupings = associate(tc,List.length sc)
          val grouped = map (map (fn g => flatten (rebuild g))) groupings
          val problems = map (fn g => ListPair.zip(sc,g)) grouped
          val solutions = map unifyList' problems
        in
          removeDuplicates(map sortByKey (List.concat solutions))
        end
    | ((NOT sc,NOT tc) | (STAR sc,STAR tc)) => unify'(sc,tc)
    (* test variables may only be bound to tests *)
    | (TST_VAR x,_) => if isTest t then [[(x,t)]] else []
    | (ACT_VAR x,_) => [[(x,t)]]
    | (ACT_CON x,ACT_CON y) =>  if x = y then [[]] else []
    | (TST_CON x,TST_CON y) =>  if x = y then [[]] else []
    | _ => []

  (* Unify a list of pairs simultaneously.  The tail is solved first;
   * the head pair is then unified under each tail solution (applied to
   * its left side), and each head solution is prepended to the tail
   * solution it was found under. *)
  and unifyList' (s:(term * term) list) : substitution list =
    case s of
      [] => [[]]
    | (x,y)::rest =>
      let
        val tailSols = unifyList' rest
        val headSols = map (fn z => unify'(apply z x,y)) tailSols
        val combined =
          ListPair.map (fn(hs,tsol) => map (fn h => h@tsol) hs) (headSols,tailSols)
      in
        removeDuplicates(map sortByKey (List.concat combined))
      end

  (* resolve ambiguities by user input *)
  (* precondition: length of s >= 2 *)
  (* This one is for CLI, may be updated for other UIs *)
  (* Lists the candidates numbered from 0 and reads one index from stdin.
   * Invalid input re-prompts.  End of input returns NONE instead of
   * re-prompting forever.  Assumes the Standard Basis TextIO.inputLine,
   * which returns string option. *)
  val resolveSubst = ref (let
    fun resolve (s:substitution list) : substitution option = let
      val z = List.tabulate(length s,Int.toString)
      val _ = println "ambiguous unification"
      val _ = println "specify desired bindings by number:\n"
      val _ = ListPair.app (fn(n,x) => ((!print) (n^": "); printSubst x)) (z,s)
      val _ = (!print) "? "
    in
      case TextIO.inputLine TextIO.stdIn of
        NONE => NONE
      | SOME line =>
          let
            (* exactly one token is accepted; a numeral too large for int
             * makes Int.fromString raise Overflow, treated as invalid *)
            val index =
              (case String.tokens Char.isSpace line of
                 [tok] => Int.fromString tok
               | _ => NONE)
              handle Overflow => NONE
            (* negative or too-large indices make List.nth raise Subscript *)
            val unif =
              Option.map (fn i => List.nth(s,i)) index
              handle Subscript => NONE
          in
            case unif of
              SOME x => SOME x
            | NONE =>
               ((println) "invalid number, try again";
                (resolve s))
          end
    end
  in
    fn (s:substitution list) => resolve(s)
  end)

  (* unify a pair, return a single unifier *)
  (* ask user to resolve ambiguities if necessary *)
  fun unify (s:term, t:term) : substitution option =
    case unify'(s,t) of
      [x] => SOME x
    | x::rest => ((!resolveSubst) (x::rest))
    | _ => NONE

  (* unify a pair, taking the first unifier found without asking the user *)
  fun unifyNoRes (s:term, t:term) : substitution option =
    case unify'(s,t) of
      x::_ => SOME x
    | _ => NONE

  (* unify a list of pairs, return a single unifier *)
  (* ask user to resolve ambiguities if necessary *)
  fun unifyList (s:(term * term) list) : substitution option =
    case unifyList' s of
      [x] => SOME x
    | x::rest => ((!resolveSubst) (x::rest))
    | _ => NONE

  (* unify a list of pairs, taking the first unifier found without asking *)
  fun unifyListNoRes (s:(term * term) list) : substitution option =
    case unifyList' s of
      x::_ => SOME x
    | _ => NONE

  (* Pair the corresponding sides of two equations of the same kind
   * (EQ with EQ, LE with LE); NONE if the kinds differ. *)
  fun eqnPairs (e1:eqn, e2:eqn) : (term * term) list option =
    case (e1,e2) of
      (EQ(s1,t1),EQ(s2,t2)) => SOME [(s1,s2),(t1,t2)]
    | (LE(s1,t1),LE(s2,t2)) => SOME [(s1,s2),(t1,t2)]
    | _ => NONE

  (* unify two equations side by side, asking the user on ambiguity *)
  fun unifyEqn(e1:eqn,e2:eqn) : substitution option =
    case eqnPairs(e1,e2) of
      SOME ps => unifyList ps
    | NONE => NONE

  (* unify two equations side by side, taking the first unifier *)
  fun unifyEqnNoRes(e1:eqn,e2:eqn) : substitution option =
    case eqnPairs(e1,e2) of
      SOME ps => unifyListNoRes ps
    | NONE => NONE


  (* Unify two equation lists position by position (asking the user on
   * ambiguity).  NONE if the lists differ in length or in equation kind
   * at some position, or if no unifier exists. *)
  fun unifyListOfEqn(f1:eqn list,f2:eqn list) : substitution option = let
    fun split(e1:eqn list,e2:eqn list) : (term * term) list option =
      case (e1,e2) of
        ([],[]) => SOME []
      | (h1::u1,h2::u2) =>
          (case eqnPairs(h1,h2) of
             NONE => NONE
           | SOME p => Option.map (fn t => p @ t) (split(u1,u2)))
      | _ => NONE
  in
    case split(f1,f2) of
      NONE => NONE
    | SOME t => unifyList t
  end

  (* apply a substitution (sequentially, via apply) to both sides *)
  fun applyToEqn (s:substitution) (e:eqn) : eqn =
    case e of
      EQ(u,v) => EQ(apply s u,apply s v)
    | LE(u,v) => LE(apply s u,apply s v)

  (* apply a substitution to every premise and to the conclusion *)
  fun applyToCondEqn (s:substitution) ((pre,con):cond_eqn) : cond_eqn =
    (map (applyToEqn s) pre, applyToEqn s con)

end