open Ast
open Corecursive
open Environment

(** Equality of runtime states, phrased as a coalgebra for the
    [Corecursive] solver: [gamma] unfolds one comparison step on a pair of
    states, and [solve] decides the resulting system of equations. *)
module Eq = struct
  (** Unknowns of the equation system. *)
  type variable = string

  (** One unfolding step of a comparison:
      - [I1 b]: the comparison is decided immediately, with result [b];
      - [I2 x]: it reduces to the single sub-comparison [x];
      - [I3 xs]: it reduces to all of the sub-comparisons [xs]
        (printed as a conjunction by [string_of_equation]). *)
  type 'b f = I1 of bool | I2 of 'b | I3 of 'b list

  (** [fh h] lifts [h] to ['c f], threading the extra state through every
      application of [h] from left to right. *)
  let fh (h : 'c * 'e -> 'a * 'e) : 'c f * 'e -> 'a f * 'e = function
    | I1 b, e -> I1 b, e
    | I2 c, e -> let a, e1 = h (c, e) in I2 a, e1
    | I3 l, e ->
      (* Map [h] over [l], feeding each returned state into the next call. *)
      let rec iterate l e = (match l with
        | [] -> [], e
        | hc :: tc ->
          let ha, ea1 = h (hc, e) in
          let la, ea2 = iterate tc ea1 in
          ha :: la, ea2)
      in
      let la, ea = iterate l e in
      I3 la, ea

  type equation = variable * (variable f)
  type coalgebra = state * state
  type algebra = bool

  (* Two pairs of states are equal when all of their components are
     physically equal ([==]).  Author note: when equality is invoked from
     the program, o1, o2, o3 and o4 are all the same environment, so the
     environment comparisons do not change the result in that case.
     NOTE(review): physical equality was presumably chosen because states
     may be cyclic values, on which structural [=] need not terminate;
     confirm before changing it. *)
  let equal ((s1, o1), (t1, o2)) ((s2, o3), (t2, o4)) =
    s1 == s2 && o1 == o3 && t1 == t2 && o2 == o4

  (** One unfolding step of the comparison of two states:
      - a variable on either side is resolved with [Environment.lookup] in
        its own environment and the comparison is retried ([I2]);
      - literals, symbols and units are decided at once ([I1]);
      - tuples reduce to their component-wise comparisons ([I3]);
      - injections compare their tags, then their payloads.
      The catch-all falls back to structural [=], which raises
      [Invalid_argument] at runtime if it reaches functional values. *)
  let gamma ((e1, o1), (e2, o2) : coalgebra) : coalgebra f =
    match e1, e2 with
    | EVar x1, EVar x2 ->
      I2 ((Environment.lookup x1 o1, o1), (Environment.lookup x2 o2, o2))
    | EVar x1, _ -> I2 ((Environment.lookup x1 o1, o1), (e2, o2))
    | _, EVar x2 -> I2 ((e1, o1), (Environment.lookup x2 o2, o2))
    | EInt i, EInt j -> I1 (i = j)
    | EFloat i, EFloat j -> I1 (i = j)
    | EString i, EString j -> I1 (i = j)
    | EBool i, EBool j -> I1 (i = j)
    | ESymbol s1, ESymbol s2 -> I1 (s1 = s2)
    | EUnit, EUnit -> I1 true
    | ETuple e1l, ETuple e2l ->
      (* [List.combine] raises [Invalid_argument] on tuples of different
         arity; typing is relied upon to rule that out. *)
      I3 (List.combine (List.map (fun x -> (x, o1)) e1l)
                       (List.map (fun x -> (x, o2)) e2l))
    | EInj (i1, None), EInj (i2, None) -> I1 (i1 = i2)
    | EInj (i1, Some e1), EInj (i2, Some e2) ->
      if i1 = i2 then I2 ((e1, o1), (e2, o2)) else I1 false
    | EInj _, EInj _ -> I1 false
    | _ -> I1 (e1 = e2)

  (** Human-readable form of one equation, for debugging. *)
  let string_of_equation = function
    | (v1, I1 b) -> v1 ^ " = " ^ (string_of_bool b)
    | (v1, I2 v2) -> v1 ^ " = " ^ v2
    | (v1, I3 lv) -> v1 ^ " = " ^ (String.concat " && " lv)

  (** [solve name eqs] returns [true] unless some equation in [eqs] has the
      right-hand side [I1 false].  [name] is only used by the disabled
      debug output below. *)
  let solve (name : variable) (eqs : equation list) =
    (* Debug output, disabled:
    print_string
      ("Equations: find " ^ name ^ " such that\n" ^
       String.concat "\n" (List.map string_of_equation eqs) ^
       "\n");
    print_string "Result: "; *)
    not (List.exists (function _, I1 false -> true | _ -> false) eqs)
end

(* Instantiates the generic [Corecursive] solver with the [Eq] coalgebra;
   its [main] function is what [equal] uses to compare pairs of states. *)
module Corecursive_eq = Corecursive(Eq)

(** [equal s1 s2] compares two states by handing the pair to the
    corecursive equality solver [Corecursive_eq.main]. *)
let equal (s1 : state) (s2 : state) : bool =
  let query = (s1, s2) in
  Corecursive_eq.main query

(** [assq x bindings] returns the value paired with the first key in
    [bindings] that is [equal] to [x].  It behaves like [List.assoc], except
    that keys are compared with the state equality [equal] rather than with
    structural [=].
    @raise Not_found if no key matches [x]. *)
let assq x bindings =
  let (_, value) = List.find (fun (key, _) -> equal x key) bindings in
  value
