Type inference and unification

    - val f = fn z => z+2
    val f = fn : int->int
    - val ident = fn x => x
    val ident = fn : 'a -> 'a
    - let fun square z = z*z in fn f => fn x => fn y => if f x y then f (square x) y else f x (f x y) end
    val it = fn : (int->bool->bool)->int->bool->bool

To see how this works, we'll start with a simple type checker for an ML-like language and extend it to support type inference. There are several things to notice about this code:

- Instead of a function that checks type equality, we've implemented the comparison of two types using a function `check_equal` that raises an exception if the two types are not equal. This is okay because the type checker would always raise an exception if the types were unequal.
- The environment has been made polymorphic so we can store different kinds of types into it. Just for fun, it is implemented in `Env` using closures, though any functional set implementation (e.g., red-black trees) would likely be at least as good.
- The datatypes `expr` and `decl` give the language syntax, which includes both anonymous and recursive (named) functions.
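The `Env` structure itself is not shown in these notes. As a rough illustration of the closure-based approach, here is a minimal sketch whose interface is inferred only from how the listings call `Env.empty`, `Env.lookup`, and `Env.update`. The exception name `Unbound` is an assumption, and the course's actual implementation may differ.

```sml
(* Hypothetical sketch of a closure-based environment.
 * The interface is guessed from the calls made by the type checker;
 * the exception name Unbound is an assumption. *)
structure Env =
struct
  exception Unbound of string

  (* An environment is simply a lookup function from identifiers to values. *)
  type 'a env = string -> 'a

  (* The empty environment binds nothing, so every lookup fails. *)
  val empty : 'a env = fn x => raise Unbound x

  (* lookup(r, x) is the value bound to x in r. *)
  fun lookup (r : 'a env, x : string) : 'a = r x

  (* update(r, x, v) is r extended (or shadowed) with the binding x -> v.
   * The old environment r is captured by the closure and left unchanged. *)
  fun update (r : 'a env, x : string, v : 'a) : 'a env =
    fn y => if y = x then v else r y
end

(* Usage: later bindings shadow earlier ones. *)
val r1 = Env.update (Env.update (Env.empty, "x", 1), "y", 2)
val r2 = Env.update (r1, "x", 10)
val xInR1 = Env.lookup (r1, "x")   (* 1: r1 is unaffected by the update *)
val xInR2 = Env.lookup (r2, "x")   (* 10 *)
```

Because each update wraps the previous function, a lookup costs time proportional to the number of updates, which is why a balanced-tree implementation would likely be at least as good.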

code/lec22/type-checker.sml

    exception TypeError of string

    type id = string

    datatype expr =
        Var of id                     (* id *)
      | True | False                  (* true, false *)
      | IntConst of int               (* n *)
      | If of expr*expr*expr          (* if e1 then e2 else e3 *)
      | Op of expr * expr             (* e1 + e2 *)
      | Select of int*expr            (* #i e *)
      | Pair of expr * expr           (* (e1, e2) *)
      | Let of (decl list) * expr     (* let d in e end *)
      | Fun of id * type_ * expr      (* fn (id:type) => expr *)
      | Apply of expr * expr          (* e1(e2) *)

    and decl =
        VarDecl of id * type_ * expr  (* val id:t = e *)
      | FunDecl of id * id * type_ * type_ * expr
                                      (* fun id1(id2:t1):t2 = e *)

    and type_ =
        Int                           (* int *)
      | Bool                          (* bool *)
      | Arrow of type_ * type_        (* t1 -> t2 *)
      | Product of type_ * type_      (* t1 * t2 *)

    type env = type_ Env.env

    (* tcheck(r,e) is the type of e in type environment r.
     * Raises TypeError if the expression e does not type-check. *)
    fun tcheck(r: env, e: expr):type_ =
      case e of
        Var(x) => Env.lookup(r, x)
      | True => Bool
      | False => Bool
      | IntConst(n) => Int
      | Let(d,e) =>
          let val r' = foldl declcheck r d
          in tcheck(r', e)
          end
      | If(e1, e2, e3) =>
          let val t2 = tcheck(r, e2)
              val t3 = tcheck(r, e3)
          in
            if Bool = tcheck(r, e1) andalso t2 = t3
              then t2
              else raise TypeError "If"
          end
      | Op(e1,e2) =>
          if tcheck(r,e1) = Int andalso tcheck(r,e2) = Int
            then Int
            else raise TypeError "Op"
      | Pair(e1,e2) => Product(tcheck(r, e1), tcheck(r, e2))
      | Select(i, e) =>
          (case tcheck(r, e) of
             Product(t1,t2) =>
               (case i of
                  1 => t1
                | 2 => t2
                | _ => raise TypeError "Select")
           | _ => raise TypeError "Select")
      | Fun(x,t,e) => Arrow(t, tcheck(Env.update(r,x,t), e))
      | Apply(e1, e2) =>
          (case tcheck(r, e1) of
             Arrow(t1, t2) =>
               if t1 = tcheck(r, e2) then t2
               else raise TypeError "Apply"
           | _ => raise TypeError "Apply")

    (* declcheck(d, r) is the environment r extended with the
     * declaration d. Raises TypeError if the declaration d does not
     * type-check. *)
    and declcheck(d: decl, r:env):env =
      case d of
        VarDecl(x,t,e) =>
          if tcheck(r, e) = t then Env.update(r, x, t)
          else raise TypeError "VarDecl"
      | FunDecl(f,x,t1,t2,e) =>
          let val r' = Env.update(r, f, Arrow(t1,t2))
          in
            if tcheck(Env.update(r', x, t1), e) = t2 then r'
            else raise TypeError "FunDecl"
          end

    (* tc(e) is the type of the program e.
     * Requires: e contains no unbound variables.
     * Checks: that e is a well-typed program. *)
    fun tc(e: expr):type_ = tcheck(Env.empty, e)

    (* let fun f(n: int):int = n in f end *)
    val e1 = Let(FunDecl("f", "n", Int, Int, Var("n"))::nil, Var("f"))

    (* (2, true) *)
    val e2:expr = Pair(IntConst(2), True)

    (* fn(x: int->bool) => fn (y: int) => x (y + 1) *)
    val e3:expr = Fun("x", Arrow(Int,Bool),
                      Fun("y", Int,
                          Apply(Var("x"), Op(Var("y"), IntConst(1)))))

    val dummy = (Compiler.Control.Print.printDepth := 1000)

The key idea behind the unification-based type inference algorithm is to
introduce **type variables** to stand in place of types that the algorithm
hasn't figured out yet. If an identifier with an undeclared type is encountered
(for example, in a `let` expression), it is added to the environment, bound to a
fresh type variable. During type checking, type variables are solved for as
necessary. This works because the only time that the type checker above
generates any constraints on types is when they are compared for equality. We
can build a type inference algorithm by having the test for equality also
simultaneously solve for type variables as needed to **unify** the two types
being compared: that is, make them equal.

For example, if we compare the two types `T1->bool` and `(int->T3)->T2`,
where T1, T2, and T3 are type variables, we can see that we can make these two
types equal by picking types `T1=int->T3` and `T2=bool`.
We can think of these two equations as *substitutions* that, if applied to
the types being compared, make them equal to one another. In this case, the
result of applying the substitutions to both types is `(int->T3)->bool`.
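To make the example concrete, here is a small sketch of unification that returns an explicit substitution rather than updating anything in place. It is only an illustration: the datatype `ty`, the association-list representation of substitutions, and the names `apply`, `occurs`, `unify`, and `show` are all assumptions made for this sketch, and it covers only `int`, `bool`, and function types.

```sml
(* Illustrative sketch only: types with named type variables. *)
datatype ty = TInt | TBool | TArrow of ty * ty | TVar of string

(* A substitution is an association list from variable names to types. *)
type subst = (string * ty) list

exception CannotUnify

(* Apply a substitution to a type, following chains of bindings. *)
fun apply (s : subst) (t : ty) : ty =
  case t of
    TInt => TInt
  | TBool => TBool
  | TArrow (a, b) => TArrow (apply s a, apply s b)
  | TVar v =>
      (case List.find (fn (v', _) => v' = v) s of
         SOME (_, t') => apply s t'
       | NONE => t)

(* Does variable v appear in t? Used to reject cyclic solutions. *)
fun occurs (v : string) (t : ty) : bool =
  case t of
    TVar v' => v = v'
  | TArrow (a, b) => occurs v a orelse occurs v b
  | _ => false

(* unify (t1, t2, s) extends s just enough to make t1 and t2 equal. *)
fun unify (t1 : ty, t2 : ty, s : subst) : subst =
  case (apply s t1, apply s t2) of
    (TInt, TInt) => s
  | (TBool, TBool) => s
  | (TVar a, TVar b) => if a = b then s else (a, TVar b) :: s
  | (TVar a, t) => if occurs a t then raise CannotUnify else (a, t) :: s
  | (t, TVar a) => if occurs a t then raise CannotUnify else (a, t) :: s
  | (TArrow (a1, b1), TArrow (a2, b2)) => unify (b1, b2, unify (a1, a2, s))
  | _ => raise CannotUnify

fun show (t : ty) : string =
  case t of
    TInt => "int"
  | TBool => "bool"
  | TVar v => v
  | TArrow (a, b) => "(" ^ show a ^ "->" ^ show b ^ ")"

(* The example from the text: T1->bool against (int->T3)->T2. *)
val left = TArrow (TVar "T1", TBool)
val right = TArrow (TArrow (TInt, TVar "T3"), TVar "T2")
val s = unify (left, right, [])
val unified = show (apply s left)   (* "((int->T3)->bool)" *)
```

The resulting substitution binds only T1 and T2. T3 is left alone because nothing in the comparison constrains it.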

There are many substitutions that would make these two types equal, because
we can add `T3=`*t* to the substitution for any arbitrary type *t* and still
unify the two types. Comparing these two types doesn't give us any information
about T3, so we wouldn't want to commit to a choice for it.
Therefore **unification** of the two types finds the **weakest substitution**
that unifies the two types. A substitution is weaker than another if the
stronger substitution can be described as applying the weaker substitution,
followed by another non-trivial substitution. For example, any substitution of
the form (`T1=int->`*t*, `T2=bool`, `T3=`*t*) can be achieved by first doing a
substitution (`T1=int->T3`, `T2=bool`) and then a substitution (`T3=`*t*).
The unification algorithm should find the weakest unifying substitution:
(`T1=int->T3`, `T2=bool`).
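The relationship between a weaker and a stronger substitution can be checked mechanically. The sketch below picks *t* = `bool` as one arbitrary choice and confirms that the stronger substitution gives the same result as the weakest one followed by `T3=bool`. The datatype `ty` and the function `subst` are hypothetical names used only for this illustration.

```sml
(* Illustrative sketch only: types with named type variables. *)
datatype ty = TInt | TBool | TArrow of ty * ty | TVar of string

(* Apply a substitution (an association list) once to a type. *)
fun subst (s : (string * ty) list) (t : ty) : ty =
  case t of
    TArrow (a, b) => TArrow (subst s a, subst s b)
  | TVar v =>
      (case List.find (fn (v', _) => v' = v) s of
         SOME (_, t') => t'
       | NONE => t)
  | _ => t

(* The weakest unifier from the text, and a further choice for T3. *)
val weakest = [("T1", TArrow (TInt, TVar "T3")), ("T2", TBool)]
val extra = [("T3", TBool)]   (* t = bool, chosen arbitrarily *)

(* A stronger unifier that commits to T3 = bool from the start. *)
val stronger = [("T1", TArrow (TInt, TBool)), ("T2", TBool), ("T3", TBool)]

val left = TArrow (TVar "T1", TBool)
val right = TArrow (TArrow (TInt, TVar "T3"), TVar "T2")

(* Applying weakest and then extra agrees with applying stronger. *)
val viaWeakest = subst extra (subst weakest left)
val direct = subst stronger left
val agree = (viaWeakest = direct)

(* Both substitutions really do unify the two types. *)
val weakestUnifies = (subst weakest left = subst weakest right)
val strongerUnifies = (subst stronger left = subst stronger right)
```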

The downside of unification is that it can lead to confusing error messages when an expression is not well-typed. For example:

    - fn z => let val (x,y)=z in z(x) end
    stdIn:2.4-24.5 Error: operator is not a function [tycon mismatch]
      operator: 'Z * 'Y
      in expression:
        z x

The types `'Z` and `'Y` are the way that SML reports an
unsolved type variable in an error message. This error message tells us that SML
tried to unify a tuple type `'Z*'Y` against a function type and
failed (as we would expect). SML hadn't figured out what the types of the tuple
elements were, so it just reported the type variables instead.

A nice way to implement unification-based type inference is to represent type
variables using ref cells. For an unsolved type variable, the ref cell points to
`NONE`; once it is solved and set equal to some type `t`,
the cell is updated to point to `SOME(t)`. Here is an implementation
of type inference using that technique:

code/lec22/type-inference.sml

    type id = string

    exception TypeUnifyError of string

    datatype expr =
        Var of id                   (* id *)
      | True | False                (* true, false *)
      | IntConst of int             (* n *)
      | If of expr * expr * expr    (* if e1 then e2 else e3 *)
      | Let of (decl list) * expr   (* let d in e end *)
      | Fun of id * expr            (* fn (id) => expr -- NO TYPE DECL *)
      | Op of expr * expr           (* e1 + e2 *)
      | Select of int*expr          (* #i e *)
      | Pair of expr * expr         (* (e1, e2) *)
      | Apply of expr * expr        (* e1(e2); *)

    and decl =
        VarDecl of id * expr        (* val id = e -- NO TYPE DECL *)
      | FunDecl of id * id * expr   (* fun id1(id2) = e -- NO TYPE DECL *)

    and type_ =
        Int
      | Bool
      | Arrow of type_ * type_
      | Product of type_ * type_
      | TypeVar of type_ option ref * int

    (* Create a fresh type variable. *)
    val cur_index = ref 0
    fun freshTypeVar(): type_ =
      (cur_index := 1 + !cur_index;
       TypeVar(ref NONE, !cur_index-1))

    fun typeToString(t: type_): string =
      case t of
        Int => "int"
      | Bool => "bool"
      | Arrow(t1,t2) => "("^typeToString(t1)^"->"^typeToString(t2)^")"
      | Product(t1,t2) => "("^typeToString(t1)^"*"^typeToString(t2)^")"
      | TypeVar(r,nm) =>
          (case !r of
             NONE => "'"^Char.toString(chr(nm + ord #"a"))
           | SOME t => typeToString(t))

    (* Determine whether t1 can be made equal to t2, by unification.
     * Effects: may solve for some type variables, as necessary. Raises
     * TypeUnifyError if the two types cannot be unified. *)
    fun unify(t1: type_, t2: type_): unit = (
      print ("Unify " ^ typeToString(t1) ^ " with " ^
             typeToString(t2) ^ "\n");
      case t1 of
        TypeVar(r,nm) => unifyVar(t2, r)
      | Int =>
          (case t2 of
             Int => ()
           | TypeVar(r,_) => unifyVar(t1, r)
           | _ => raise TypeUnifyError "Int")
      | Bool =>
          (case t2 of
             Bool => ()
           | TypeVar(r,_) => unifyVar(t1, r)
           | _ => raise TypeUnifyError "Bool")
      | Arrow(t3, t4) =>
          (case t2 of
             Arrow(t3',t4') => (unify(t3,t3'); unify(t4,t4'))
           | TypeVar(r,_) => unifyVar(t1,r)
           | _ => raise TypeUnifyError "Arrow")
      | Product(t3,t4) =>
          (case t2 of
             Product(t3', t4') => (unify(t3,t3'); unify(t4,t4'))
           | TypeVar(r,_) => unifyVar(t1, r)
           | _ => raise TypeUnifyError "Product")
    )

    (* unifyVar(t1, r) unifies t1 with the type variable whose cell is r. *)
    and unifyVar(t1:type_, r:type_ option ref): unit =
      case !r of
        NONE =>
          (case t1 of
             TypeVar(r1,_) => if r1 <> r then r := SOME t1
                              else () (* cycle *)
           | _ => r := SOME t1)
      | SOME t => unify(t1, t)

    type env = type_ Env.env

    (* tinfer(r,e) is the type of e in type environment r.
     * The execution of tinfer(r,e) triggers type unifications.
     * Raises TypeUnifyError if the expression e does not type-check
     * (and Fail for an illegal Select index). *)
    fun tinfer(r: env, e: expr):type_ =
      case e of
        Var(x) => Env.lookup(r, x)
      | True => Bool
      | False => Bool
      | IntConst(n) => Int
      | Let(d,e) => tinfer(foldl declcheck r d, e)
      | If(e1, e2, e3) =>
          let val t1 = tinfer(r, e1)
              val t2 = tinfer(r, e2)
              val t3 = tinfer(r, e3)
          in
            unify(t1, Bool);
            unify(t2, t3);
            t2
          end
      | Op(e1,e2) => (unify(tinfer(r,e1), Int);
                      unify(tinfer(r,e2), Int);
                      Int)
      | Fun(x,e) =>
          let val t = freshTypeVar()
          in Arrow(t, tinfer(Env.update(r,x,t), e))
          end
      | Apply(e1, e2) =>
          let val t = freshTypeVar()
              val t' = freshTypeVar()
          in
            unify(tinfer(r, e1), Arrow(t, t'));
            unify(tinfer(r, e2), t);
            t'
          end
      | Pair(e1,e2) => Product(tinfer(r, e1), tinfer(r, e2))
      | Select(i, e) =>
          let val t1 = freshTypeVar()
              val t2 = freshTypeVar()
          in
            unify(Product(t1,t2), tinfer(r, e));
            (case i of
               1 => t1
             | 2 => t2
             | _ => raise Fail("Illegal index"))
          end

    (* declcheck(d, r) is the environment r extended with the
     * declaration d. Raises TypeUnifyError if the declaration d
     * does not type-check. *)
    and declcheck(d: decl, r:env):env =
      case d of
        VarDecl(x,e) => Env.update(r, x, tinfer(r, e))
      | FunDecl(f,x,e) =>
          let val t1 = freshTypeVar()
              val t2 = freshTypeVar()
              val tf = Arrow(t1,t2)
              val r' = Env.update(r,f,tf)
              val r'' = Env.update(r',x,t1)
              val te = tinfer(r'', e)
          in
            unify(te, t2);
            r'
          end

    (* EXAMPLES *)

    fun ti(e) = (cur_index := 0; typeToString(tinfer(Env.empty, e)))

    (* let fun f(n) = n + 2 in f end *)
    val e1 = Let(FunDecl("f", "n", Op(Var("n"), IntConst(2)))::nil, Var("f"))

    (* let fun f(n) = n in f end *)
    val e2 = Let(FunDecl("f", "n", Var("n"))::nil, Var("f"))

    (* let fun f(n) = n in f(2) end *)
    val e2a = Let(FunDecl("f", "n", Var("n"))::nil,
                  Apply(Var("f"), IntConst(2)))

    (* let fun f(x) = fn(y) => x + y in f end *)
    val e3 = Let(FunDecl("f", "x", Fun("y", Op(Var("x"), Var("y"))))::nil,
                 Var("f"))

    (* let fun loop(x) = loop(x+1) in loop end *)
    val e4 = Let([FunDecl("loop", "x",
                          Apply(Var("loop"), Op(Var("x"),IntConst(1))))],
                 Var("loop"))

    val S = Fun("x", Fun("y", Fun("z",
              Apply(Apply(Var("x"),Var("z")), Apply(Var("y"),Var("z"))))))
    val K = Fun("x", Fun("y", Var("x")))
    val O = Fun("f", Fun("g", Fun("x",
              Apply(Var("f"), Apply(Var("g"), Var("x"))))))
    val I = Fun("x", Var("x"))
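In the listing above, `unifyVar` only avoids binding a type variable directly to itself. It does not look inside a larger type, so an expression such as `fn x => x x` would ask it to solve a variable against an arrow type that contains that same variable. Implementations commonly guard against this with an *occurs check*. The following is a minimal sketch of that idea on a cut-down type representation. The names `ty`, `resolve`, `occurs`, `bind`, and the exception `Occurs` are assumptions for this illustration and are not part of the code above.

```sml
(* Illustrative sketch only: a cut-down type with ref-cell variables. *)
datatype ty = TInt | TArrow of ty * ty | TVar of ty option ref

exception Occurs

(* Follow solved variables until reaching an unsolved variable
 * or a type constructor. *)
fun resolve (t : ty) : ty =
  case t of
    TVar (ref (SOME t')) => resolve t'
  | _ => t

(* Does the unsolved cell r appear anywhere inside t? *)
fun occurs (r : ty option ref, t : ty) : bool =
  case resolve t of
    TVar r' => r = r'
  | TArrow (a, b) => occurs (r, a) orelse occurs (r, b)
  | TInt => false

(* Solve the unsolved cell r as t, unless that would build a cyclic type. *)
fun bind (r : ty option ref, t : ty) : unit =
  case resolve t of
    TVar r' => if r = r' then () else r := SOME (TVar r')
  | t' => if occurs (r, t') then raise Occurs else r := SOME t'

(* Two unsolved variables, standing for 'a and 'b. *)
val a = ref NONE : ty option ref
val b = ref NONE : ty option ref

(* Solving 'a = 'a -> 'b is rejected instead of creating a cycle. *)
val rejected = (bind (a, TArrow (TVar a, TVar b)); false)
               handle Occurs => true

(* An ordinary binding still succeeds. *)
val () = bind (b, TInt)
val solved = (resolve (TVar b) = TInt)
```

Resolving the right-hand side before binding also keeps two variables from ending up pointing at each other.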