structure FOUnify = struct

open SLang
open FirstOrder

(* Bindings produced by unification: each pattern location (a variable or a
   Sub(...) substitution key) paired with the term it was unified with. *)
val subList : (loc * C) list ref = ref nil

(* Locations that unification could not bind by itself (arising from
   substitutions s[x/t]) and for which the user must supply a term. *)
val missingList : loc list ref = ref nil

(* Unify commands *)
(* Unify an arithmetic pattern (genTerm) with a concrete arithmetic term
   (fromKAT), recording bindings in subList.

   - N n     : matches only the same numeral N n.
   - Loc loc : binds loc to varToGVarA fromKAT, or checks that an existing
               binding for loc is that same value.
   - RepA(..): the three substitution shapes below are keyed as Sub(...)
               locations and bound to fromKAT unchanged.
   - Times / Plus / Minus / Div / Mod : fromKAT must use the same operator;
               operands are unified pairwise, left operand first.
   Raises Fail "Not able to unify" on a mismatch, and Fail "Error in unifyA!"
   when an existing binding is not a Math(...) value. *)
fun unifyA(genTerm,fromKAT) =
  let
    (* Bind key to Math value when unbound; otherwise the existing binding
       must be exactly Math value. *)
    fun bindOrCheck key value =
      case lookup key (!subList) of
        NONE => subList := (key,Math(value))::(!subList)
      | SOME(Math(x)) => if x = value then ()
                         else raise Fail "Not able to unify"
      | _ => raise Fail "Error in unifyA!"

    (* Unify the operands of two applications of the same binary operator. *)
    fun operands (a1,a2) (b1,b2) = (unifyA(a1,b1); unifyA(a2,b2))
  in
    case (genTerm,fromKAT) of
      (* A literal only unifies with the identical literal. *)
      (N(n),N(m)) => if n = m then () else raise Fail "Not able to unify"
    | (Loc(loc),_) => bindOrCheck loc (varToGVarA(fromKAT))
    | (RepA(Loc(a),Loc(b),Loc(c)),_) => bindOrCheck (Sub(a,b,c)) fromKAT
    | (RepA(RepA(Loc(a1),Loc(b1),Loc(c1)),Loc(b),Loc(c)),_) =>
        bindOrCheck (Sub(Sub(a1,b1,c1),b,c)) fromKAT
    | (RepA(Loc(a),Loc(b),RepA(Loc(a1),Loc(b1),Loc(c1))),_) =>
        bindOrCheck (Sub(a,b,Sub(a1,b1,c1))) fromKAT
    | (Times(p),Times(q)) => operands p q
    | (Plus(p),Plus(q)) => operands p q
    | (Minus(p),Minus(q)) => operands p q
    | (Div(p),Div(q)) => operands p q
    | (Mod(p),Mod(q)) => operands p q
    | _ => raise Fail "Not able to unify"
  end

(* Unify a Boolean test pattern (genTerm) with a concrete test (fromKAT).
   Comparisons (Equals, LessEq, GreEq, Less, Gre) must use the same
   operator on both sides and have their operands unified with unifyA;
   Not / And / Or must match structurally and recurse.  Every other pair,
   including a True or False pattern, raises Fail "Not able to unify". *)
fun unifyB(genTerm,fromKAT) =
  let
    (* Unify two operand pairs component-wise, left operand first. *)
    fun pairwise unify ((l1,r1),(l2,r2)) = (unify(l1,l2); unify(r1,r2))
  in
    case (genTerm,fromKAT) of
      (Equals p,Equals q) => pairwise unifyA (p,q)
    | (LessEq p,LessEq q) => pairwise unifyA (p,q)
    | (GreEq p,GreEq q) => pairwise unifyA (p,q)
    | (Less p,Less q) => pairwise unifyA (p,q)
    | (Gre p,Gre q) => pairwise unifyA (p,q)
    | (Not b1,Not b2) => unifyB(b1,b2)
    | (And p,And q) => pairwise unifyB (p,q)
    | (Or p,Or q) => pairwise unifyB (p,q)
    | _ => raise Fail "Not able to unify"
  end

(* True when term is a Boolean command, or a non-empty Seq whose elements
   all satisfy isBooleans (nested Seqs are checked recursively).
   An empty Seq and every other command yield false. *)
fun isBooleans(term) =
  case term of
    Boolean(_) => true
  | Seq(elems as _::_) => List.all isBooleans elems
  | _ => false

(* Count the Boolean commands at the front of a command.  A bare Boolean
   counts as 1.  For a Seq, the result is the length of its leading run of
   Boolean elements; nested Seqs are not looked into.  Anything else is 0.
   Despite the name, this is a prefix count, not a maximum. *)
fun maxBoolean(term) =
  let
    fun leading (Boolean(_)::rest) = 1 + leading rest
      | leading _ = 0
  in
    case term of
      Boolean(_) => 1
    | Seq(elems) => leading elems
    | _ => 0
  end

(* Unify a command pattern (genTerm) with a concrete command (fromKAT),
   recording bindings in subList (and, for One, an entry in missingList).

   - Phi / Pre: bind LPhi / LPre to fromKAT, which must consist only of
     Boolean commands (isBooleans).  A repeated occurrence must agree with
     the first binding.
   - Seq(...) beginning with Pre, Phi or Rep(Phi,_,_): split fromKAT's
     leading run of Boolean commands (maxBoolean) between those pieces and
     unify the remainder with the rest of the pattern.
   - Any other Seq: unify element by element.
   - Boolean, Assign, FCall: fromKAT must have the same shape; the parts
     are unified with unifyB / unifyA.
   - Rep(_,Loc b,Loc c) and Rep(_,Loc b,RepA(...)): bind the substitution
     key Sub(LPhi,...) to fromKAT.
   - One: matches only One, and queues Var "x" on missingList.
   Raises Fail with a message naming the failing shape on mismatch. *)
fun unifyC(genTerm,fromKAT) =
  case genTerm of
    Phi => (case lookup LPhi (!subList) of
              NONE => (if isBooleans(fromKAT) then
                       subList := (LPhi,varToGVar(flatten(fromKAT)))::(!subList)
                       else raise Fail "Not able to unify Phi")
              (* The first binding was stored after varToGVar/flatten, so a
                 later occurrence is also compared in that form. *)
            | SOME(x) => if x = fromKAT orelse x = varToGVar(flatten(fromKAT))
                         then ()
                         else raise Fail "Not able to unify Phi")
  | Pre => (case lookup LPre (!subList) of
              NONE => (if isBooleans(fromKAT) then
                       subList := (LPre,varToGVar(fromKAT))::(!subList)
                       else raise Fail "Not able to unify Pre")
            | SOME(x) => if x = fromKAT orelse x = varToGVar(fromKAT)
                         then ()
                         else raise Fail "Not able to unify Pre")
  | Seq(Pre::Phi::s2) => (case fromKAT of
                     Seq(s1a) => (let val m = maxBoolean(fromKAT)
                                      (* With two or more leading tests, Pre
                                         takes all but the last one; with at
                                         most one, Pre and Phi share it. *)
                                      val bs = if m > 1
                                               then flatten(Seq(List.take(s1a,m-1)))
                                               else flatten(Seq(List.take(s1a,m)))
                                      val cs = if m > 1
                                               then flatten(Seq(List.drop(s1a,m-1)))
                                               else flatten(Seq(List.drop(s1a,m)))
                                  in
                                   if m > 1
                                   then (unifyC(Pre,bs); unifyC(Seq(Phi::s2),cs))
                                   else (unifyC(Pre,bs); unifyC(Phi,bs);
                                        unifyC(flatten(Seq(s2)),cs))
                                  end)
                   | _ => raise Fail "Not able to unify Pre/Phi Seq")

  | Seq(Phi::s2) => (case fromKAT of
                     Seq(s1a) => (let val m = maxBoolean(fromKAT)
                                      val bs = Seq(List.take(s1a,m))
                                      val cs = if (List.length(s1a) - m) > 1
                                               then flatten(Seq(List.drop(s1a,m)))
                                               else hd(rev(s1a))
                                  in
                                    unifyC(Phi,bs);
                                    if(s2 <> nil) then unifyC(flatten(Seq(s2)),cs)
                                    else ()
                                  end)
                   | _ => raise Fail "Not able to unify Phi Seq")

  | Seq(Rep(Phi,a,b)::s2) => (case fromKAT of
                     Seq(s1a) => (let val m = maxBoolean(fromKAT)
                                      val bs = Seq(List.take(s1a,m))
                                      val cs = if (List.length(s1a) - m) > 1
                                               then Seq(List.drop(s1a,m))
                                               else hd(rev(s1a))
                                  in
                                    unifyC(Rep(Phi,a,b),bs); unifyC(Seq(s2),cs)
                                  end)
                   | _ => raise Fail "Not able to unify Rep Phi Seq")

  | Seq(Pre::Rep(Phi,a,b)::s2) => (case fromKAT of
                     Seq(s1a) => (let val m = maxBoolean(fromKAT)
                                      val bs = if m > 1
                                               then flatten(Seq(List.take(s1a,m-1)))
                                               else flatten(Seq(List.take(s1a,m)))
                                      val cs = if m > 1
                                               then flatten(Seq(List.drop(s1a,m-1)))
                                               else flatten(Seq(List.drop(s1a,m)))
                                  in
                                   if m > 1
                                   then (unifyC(Pre,bs); unifyC(Seq(Rep(Phi,a,b)::s2),cs))
                                   else (unifyC(Pre,bs); unifyC(Rep(Phi,a,b),bs);
                                        unifyC(Seq(s2),cs))
                                  end)
                   | _ => raise Fail "Not able to unify Pre Rep Phi Seq")

  | Seq(Pre::s2) => (case fromKAT of
                     Seq(s1a) => (let val m = maxBoolean(fromKAT)
                                      val bs = Seq(List.take(s1a,m))
                                      val cs = Seq(List.drop(s1a,m))
                                  in
                                    unifyC(Pre,bs); unifyC(Seq(s2),cs)
                                  end)
                   | _ => raise Fail "Not able to unify Pre Seq")

  | Seq(s1::s2) => (case (fromKAT,s2) of
                   (_,nil) => unifyC(s1,fromKAT)
                 | (Seq(s1a::s2a),_) => (unifyC(s1,s1a); unifyC(flatten(Seq(s2)),flatten(Seq(s2a))))
                 | _ => raise Fail "Not able to unify Seq")
  | Seq(nil) => (case fromKAT of
                   Seq(nil) => ()
                 | _ => raise Fail ("Not able to unify Nil Seq"^(cTermToString(fromKAT))))
(*  | Seq([s1]) => (case fromKAT of
                 Seq([s1a]) => unifyC(s1,s1a)
                | _ => raise Fail "Not able to unify")*)
  | Boolean(b1) => (case fromKAT of
                     Boolean(b2) => (unifyB(b1,b2))
                   | _ => raise Fail "Not able to unify Boolean")
  | Assign(x as Var(_),s) => (case fromKAT of
                     Assign(Var(y),t) =>
                       ((case lookup x (!subList) of
                          NONE => subList := (x,Math(Loc(GVar(y))))::(!subList)
                        | SOME(Math(Loc(z))) => if z = Var(y) orelse z = GVar(y)
                                     then ()
                                     else raise Fail "Not able to unify Var Assign"
                        | _ => raise Fail "Error in unifyC!");
                        unifyA(s,t))
                     | _ => raise Fail "Not able to unify Var Assign")
  | Assign(x as Arr(_),s) => (case fromKAT of
                     Assign(Arr(an,y),t) =>
                       ((case lookup x (!subList) of
                          NONE => subList := (x,Math(Loc(Arr(an,map varToGVarA y))))::(!subList)
                          (* The binding stores the indices after varToGVarA,
                             so accept that form as well as the raw one. *)
                        | SOME(Math(Loc(z))) => if z = Arr(an,y) orelse z = Arr(an,map varToGVarA y)
                                     then ()
                                     else raise Fail "Not able to unify Arr Assign"
                        | _ => raise Fail "Error in unifyC!");
                        unifyA(s,t))
                     | _ => raise Fail "Not able to unify Arr Assign")
  | Assign(ASub(a,b,Loc(c)),s) => (case fromKAT of
                     Assign(x',t) =>
                       (* Look up the same Sub(a,b,c) key that is inserted
                          below, so a repeated occurrence is really checked. *)
                       ((case lookup (Sub(a,b,c)) (!subList) of
                          NONE => subList := (Sub(a,b,c),Math(varToGVarA (Loc(x'))))::(!subList)
                        | SOME(Math(Loc(z))) => if z = x' orelse Loc(z) = varToGVarA(Loc(x'))
                                     then ()
                                     else raise Fail "Not able to unify Arr Assign"
                        | _ => raise Fail "Error in unifyC!");
                        unifyA(s,t))
                     | _ => raise Fail "Not able to unify Arr Assign")

  | Rep(a,Loc(b),Loc(c)) => (case lookup (Sub(LPhi,b,c)) (!subList) of
              NONE => subList := (Sub(LPhi,b,c),flatten(fromKAT))::(!subList)
              (* Stored after flatten; compare in that form too. *)
            | SOME(x) => if x = fromKAT orelse x = flatten(fromKAT) then ()
                         else raise Fail "Not able to unify Rep")

  | Rep(a,Loc(b),RepA(Loc(a1),Loc(a2),Loc(a3))) => (case lookup (Sub(LPhi,b,Sub(a1,a2,a3))) (!subList) of
              NONE => subList := (Sub(LPhi,b,Sub(a1,a2,a3)),fromKAT)::(!subList)
            | SOME(x) => if x = fromKAT then ()
                         else raise Fail "Not able to unify Rep Rep")
  | FCall(name,args) => (case fromKAT of
                           (* ListPair.zip silently drops surplus elements,
                              so require equal arity explicitly. *)
                           FCall(n,a) => if n = name andalso length args = length a
                                         then app unifyA (ListPair.zip(args,a))
                                         else raise Fail "Not able to unify Fcall"
                         | _ => raise Fail "Not able to unify FCall")
  (* NOTE(review): queues the fixed variable Var "x" for the user to bind;
     the reason is not visible in this file. *)
  | One => if fromKAT = One then missingList := (Var("x"))::(!missingList)
           else raise Fail "Not able to unify One"
  | _ => raise Fail "Not able to unify Something Else"

(* Get the list of variables for which the user needs to give a unification.

   For every key of the form Sub(t,x,s) in subList, each component x or s
   that is not itself a key of subList is added to missingList (at most
   once).  A nested Sub in the x or s position is walked recursively
   instead of being checked, and t is always walked. *)
fun resolveSubs1() = let
  (* Add a location to missingList unless it is already present. *)
  fun remember loc =
    if member loc (!missingList) then ()
    else missingList := loc :: (!missingList)
  (* Record loc as missing when subList has no binding for it. *)
  fun requireBound loc =
    case lookup loc (!subList) of
      NONE => remember loc
    | SOME _ => ()
  fun walk (key, item) =
    case key of
      Sub(t, x as Sub _, s) => (requireBound s; walk (t, item); walk (x, item))
    | Sub(t, x, s as Sub _) => (requireBound x; walk (t, item); walk (s, item))
    | Sub(t, x, s) => (requireBound x; requireBound s; walk (t, item))
    | _ => ()
in
  List.app walk (!subList)
end
(* Scratch cells written by resolveSubs2's consistency check (the recomputed
   replacement and the expected item), kept for inspection after a failure. *)
val (debug1, debug2) = (ref Phi, ref Phi)
(* Resolve some more substitutions.

   Rewrites every entry (key, item) of subList through the local function
   `change` and stores the mapped list back into subList.  `change` handles:

   - Sub(t, Sub(x,y,z), s): resolve the inner Sub(x,y,z) first (against the
     same item), then retry with the key it produced in the x position.
   - Sub(t, x, Sub(s,r,u)): fetch the binding of the nested substitution
     (valOf raises Option if it is unbound), bind it under the fixed
     placeholder variable Var "z2323233343433321", and retry with that
     placeholder in the s position.
   - Sub(t, x, s): s must be bound to Math(exp) and x to Math(Loc x').
     t is resolved recursively to newt, a new binding for newt is built
     with subForXInC, and, when t is not LPhi, t's own binding is checked
     against item (Fail "Unable to unify!" on disagreement).
   - Arr(arr, [RepA(t, Loc x, Loc s)]): the analogous rewrite for an array
     location whose single index is a substitution; item must have the
     shape Math(Loc(Arr(tarr,[newt]))).
   - any other key is returned unchanged.

   subForXInC / subForXInA are not defined here; going by their names they
   presumably substitute their first argument for the second within the
   third -- TODO confirm.

   NOTE(review): `map change (!subList)` reads subList before the map runs,
   so entries that `change` prepends to subList while mapping (placeholder
   bindings and intermediate results) are visible to lookups during this
   call but are discarded by the final assignment.  Presumably intended as
   scratch bindings -- confirm. *)
fun resolveSubs2() = let
   (* Rewrite one subList entry; returns the replacement (key, item). *)
   fun change((key,item)) = 
     case key of
       (* Nested substitution in the x position: resolve it first and
          continue with the key it produces. *)
       Sub(t,Sub(x,y,z),s) => change(Sub(t,(#1(change(Sub(x,y,z),item))),s),item)
       (* Nested substitution in the s position: bind its value to a
          placeholder variable and continue with that placeholder. *)
     | Sub(t,x,Sub(s,r,u)) => let 
                                  val newSubC = valOf(lookup (Sub(s,r,u)) 
                                        (!subList))
                                  (* NOTE(review): the same placeholder name is
                                     reused for every nested Sub; this relies on
                                     lookup finding the most recently prepended
                                     binding -- confirm. *)
                                  val newSub = Var "z2323233343433321"
                                  val () = subList := ((newSub,newSubC)::(!subList))
                              in
                                  change(Sub(t,x,newSub),item)
                              end
     | Sub(t,x,s) => let
         (* s must be bound to a Math(...) value. *)
         val Math(exp) = case (lookup s (!subList)) of
                          NONE => raise Fail "Error in resolveSubs2!"
                        | SOME(Math(x)) => Math(x)
                        | _ => raise Fail "Error in resolveSubs2!"
        (* x must be bound to Math(Loc x'); an unbound x raises Option and
           any other shape raises Bind. *)
        val Math(Loc(x')) = valOf(lookup x (!subList))
        val (newt,_) = change(t,item)
        val () = (subList := ((newt,varToGVar(subForXInC(gVarToVarA(exp),
                              Math(gVarToVarA(Loc(x'))),item)))::(!subList)))
       (* Unless t is LPhi, t must be bound to Math(texp), and rewriting
          texp with subForXInC must reproduce item. *)
       val () = if t <> LPhi then 
        let val Math(texp) = case (lookup t (!subList)) of
                              NONE => raise Fail "Error in resolveSubs2!"
                            | SOME(Math(x)) => Math(x)
                            | _ => raise Fail "Error in resolveSubs2!"
           val newreplace = gVarToVar(subForXInC(gVarToVarA(Loc(x')),
                        Math(varToGVarA(exp)),gVarToVar(Math(texp))))
           (* Kept for inspection when the comparison below fails. *)
           val () = debug1 := newreplace
           val () = debug2 := gVarToVar(item)
       in if (newreplace = gVarToVar(item)) then () 
          else raise Fail "Unable to unify!"
       end else ()
   in
     (* NOTE(review): the returned binding passes Math(Loc(x')) while the
        entry prepended above passes Math(gVarToVarA(Loc(x'))) -- confirm
        which form is intended. *)
     ((newt,varToGVar(subForXInC(gVarToVarA(exp),Math(Loc(x')),item))))
   end

   (* Array location whose single index is the substitution RepA(t,x,s). *)
   | (Arr(arr,[RepA(t,Loc(x),Loc(s))])) => let
         val Math(exp) = case (lookup s (!subList)) of
                          NONE => raise Fail "Error in resolveSubs2!"
                        | SOME(Math(x)) => Math(x)
                        | _ => raise Fail "Error in resolveSubs2!"
        val Math(Loc(x')) = valOf(lookup x (!subList))
        (* item must be an array location with exactly one index;
           any other shape raises Bind. *)
        val Math(Loc(Arr(tarr,[newt]))) = item
        val () = (subList := ((Arr(arr,[t])),Math(Loc(Arr(tarr,[varToGVarA(subForXInA(gVarToVarA(exp),Math(gVarToVarA(Loc(x'))),newt))]))))::(!subList))
(*       val () = if t <> LPhi then 
        let val Math(Arr(arr,texp)) = case (lookup Arr(arr,[t]) (!subList)) of
                              NONE => raise Fail "Error in resolveSubs2!"
                            | SOME(Math(x)) => Math(x)
                            | _ => raise Fail "Error in resolveSubs2!"
           val newreplace = gVarToVar(subForXInC(gVarToVarA(Loc(x')),
                        Math(varToGVarA(exp)),gVarToVar(Math(texp))))
           val () = debug1 := newreplace
           val () = debug2 := gVarToVar(item)
       in if (newreplace = gVarToVar(item)) then () 
          else raise Fail "Unable to unify!"
       end else ()*)
   in
((Arr(arr,[t])),Math(Loc(Arr(tarr,[varToGVarA(subForXInA(gVarToVarA(exp),Math(gVarToVarA(Loc(x'))),newt))]))))
   end




   (* Any other key is left as it is. *)
   |  _ => (key,item)
in
     subList := (map change (!subList))
end



(* Get all user unifications.

   Replaceable hook (a ref) that asks for a term for every location in the
   given list, in order.  For each location x the default implementation
   prints the location followed by "?" and reads one line from standard
   input.  A bare newline keeps Loc x itself; otherwise the line is parsed
   with Parser.parseLine after a "#" prefix and must yield Math(...) (any
   other result raises Bind).  Loc x is then unified with the answer via
   unifyA.  cont is called once every location has been handled.

   NOTE(review): the value returned by TextIO.inputLine is compared with a
   string here, which matches the older Basis where inputLine returns
   string; the current Basis returns string option -- confirm the
   toolchain. *)
val getMissingFO:(loc list -> (unit -> unit) -> unit) ref = ref
  (fn pending => fn cont =>
    let
      fun ask nil = cont ()
        | ask (x::rest) =
            let
              val prompt = aTermToString(Loc(x)) ^ "?"
              val line = ((!print) prompt; TextIO.inputLine TextIO.stdIn)
              val Math(parsed) =
                if line = "\n" then Math(Loc(x))
                else Parser.parseLine ("#"^line)
            in
              unifyA(Loc(x),parsed);
              ask rest
            end
    in
      ask pending
    end)


(* Unify and resolve substitutions.

   Clears subList and missingList, unifies the pattern genTerm with the
   concrete command fromKAT (unifyC), collects the locations that still
   need a binding (resolveSubs1), and passes them together with cont to
   the getMissingFO hook.  The default hook calls cont after all of them
   have been answered. *)
fun unify(genTerm,fromKAT) cont =
  let
    val () = subList := nil
    val () = missingList := nil
    val () = unifyC(genTerm,fromKAT)
    val () = resolveSubs1()
  in
    (!getMissingFO) (!missingList) cont
  end

(* Create a substitution list for the citation in the proof.

   Walks the concrete command term alongside the pattern genterm and
   returns (pattern piece, KAT translation) pairs:
   - Assign and FCall pairs are translated with findCom, and Boolean pairs
     with findBool (a Not is translated as NOT of the inner test).
   - Inside a Seq, Phi and Rep(Phi,_,_) take the leading run of Boolean
     commands (maxBoolean), and Pre contributes no pair of its own.
   - A bare Phi, Seq[Phi] or Rep pattern is translated whole with
     slangCToKAT.
   Raises Fail naming the (sub)term when a findCom/findBool lookup comes
   back NONE, and Fail naming both terms when no case applies. *)
fun makeSubList(term,genterm) =
  let
    (* Split the command list c1 after its first k elements; both halves
       are flattened. *)
    fun splitAfter(c1,k) =
      (flatten(Seq(List.take(c1,k))),flatten(Seq(List.drop(c1,k))))

    (* Pattern Seq(p::rest), with p = Phi or Rep(Phi,_,_): p takes the
       leading tests of c1, and the remainder is matched against rest. *)
    fun phiThen(c1,p,rest) =
      let val (bs,cs) = splitAfter(c1,maxBoolean(term))
      in (p,slangCToKAT(bs))::makeSubList(cs,flatten(Seq(rest)))
      end

    (* Pattern Seq(Pre::p::rest): with two or more leading tests, Pre
       absorbs all but the last and p is matched again against what is
       left; otherwise p takes the (at most one) leading test. *)
    fun preThen(c1,p,rest) =
      let val m = maxBoolean(term)
      in
        if m > 1
        then makeSubList(flatten(Seq(List.drop(c1,m-1))),flatten(Seq(p::rest)))
        else phiThen(c1,p,rest)
      end
  in
    (* The case is parenthesised so the Option handler covers every arm;
       a `handle` written after the last arm would guard only that arm. *)
    (case (term,genterm) of
       (Assign(_,_),Assign(_,_)) => [(genterm,valOf(findCom(term)))]
     | (Boolean(Not(b1)),Boolean(Not(b2))) => [(genterm,NOT(valOf(findBool(b1))))]
     | (Boolean(b1),Boolean(b2)) => [(genterm,valOf(findBool(b1)))]
     | (Seq(c1),Seq(Pre::Phi::s2)) => preThen(c1,Phi,s2)
     | (Seq(c1),Seq(Pre::(p as Rep(Phi,_,_))::s2)) => preThen(c1,p,s2)
     | (Seq(c1),Seq(Pre::s2)) =>
         (* Pre consumes the leading tests and contributes no pair. *)
         makeSubList(flatten(Seq(List.drop(c1,maxBoolean(term)))),flatten(Seq(s2)))
     | (Seq(c1),Seq(Phi::gc1)) => phiThen(c1,Phi,gc1)
     | (Seq(c1),Seq((p as Rep(Phi,_,_))::gc1)) => phiThen(c1,p,gc1)
     | (Seq(c::c1),Seq(g::g1)) =>
         makeSubList(c,g) @ makeSubList(flatten(Seq(c1)),flatten(Seq(g1)))
     | (Seq(nil),Seq(nil)) => []
     | (One,One) => []
     | (_,Seq[Phi]) => [(genterm,slangCToKAT(term))]
     | (_,Phi) => [(genterm,slangCToKAT(term))]
     | (_,Pre) => []
     | (Boolean(Not(x)),Rep(Phi,_,_)) => [(genterm,NOT(valOf(findBool(x))))]
     | (Boolean(x),Rep(Phi,_,_)) => [(genterm,valOf(findBool(x)))]
     | (_,Rep(_,_,_)) => [(genterm,slangCToKAT(term))]
     | (Phi,_) => []
     | (Pre,_) => []
     | (FCall(n,a),FCall(_)) => [(genterm,valOf(findCom(term)))]
     | (a,b) => raise Fail ((cTermToString(a))^(cTermToString(b))))
    handle Option => raise Fail (cTermToString(term))
  end

end
