- val f = fn z => z+2 val f = fn : int->int - val ident = fn x => x val ident = fn : 'a -> 'a - let fun square z = z*z in fn f => fn x => fn y => if f x y then f (square x) y else f x (f x y) end val it = fn : (int->bool->bool)->int->bool->bool
To see how this works, we'll start with a simple type checker for an ML-like language and then extend it to support type inference. There are several things to notice about this code:
check_equal is a function that raises an exception if the two types are not equal. This is okay because the type
checker would raise an exception anyway whenever the types were unequal.
Env is implemented as an association list (an alternative implementation using
closures is included in a comment). Lookup in either version takes time linear in the size of the
environment, so a balanced-tree map (e.g., red-black trees) would likely be at least as good.
expr and decl give the language syntax, which includes both
anonymous and recursive (named) functions.
type id = string

structure Type = struct
  datatype type_ = Int
                 | Bool
                 | Arrow of type_ * type_   (* t1 -> t2 *)
                 | Product of type_ * type_ (* t1 * t2 *)

  (* check_equal(t1, t2) returns () when t1 and t2 are the same type.
   * Raises Fail at the first mismatch; the message names the kind of
   * type that t1 led it to expect in t2. *)
  fun check_equal(t1: type_, t2: type_): unit =
    case t1 of
      Int => (case t2 of Int => () | _ => raise Fail("expected int"))
    | Bool => (case t2 of Bool => () | _ => raise Fail("expected bool"))
    | Arrow(t3, t4) =>
        (case t2 of
           Arrow(t3', t4') => (check_equal(t3, t3'); check_equal(t4, t4'))
         | _ => raise Fail("expected an arrow type"))
    | Product(t3, t4) =>
        (case t2 of
           Product(t3', t4') => (check_equal(t3, t3'); check_equal(t4, t4'))
         | _ => raise Fail("expected a pair type"))
end

signature ENVIRONMENT = sig
  (* An 'a env is an environment: a partial function from
   * identifiers (id) to values of type 'a. *)
  type 'a env
  val empty: 'a env
  (* lookup(env, id) is the value id maps to in env.
   * Checks: env has a mapping for id. *)
  val lookup: 'a env * id -> 'a
  (* update(env, id, v) is an environment just like env
   * except that it maps id to v. *)
  val update: 'a env * id * 'a -> 'a env
end

(* Environments as association lists. update conses onto the front, so a
 * newer binding for an identifier shadows older ones; lookup walks the
 * list, taking time linear in its length. *)
structure Env : ENVIRONMENT = struct
  type 'a env = (id * 'a) list
  val empty = nil
  fun lookup(r: 'a env, x: id) =
    case r of
      nil => raise Fail("No such variable: "^x)
    | (y, v)::t => if x = y then v else lookup(t, x)
  fun update(r: 'a env, x: id, t: 'a) = (x, t) :: r
end

(* Another interesting O(n) implementation, using closures:
structure Env : ENVIRONMENT = struct
  type 'a env = id -> 'a
  fun empty(x: id) = raise Fail "no such variable";
  fun lookup(r: 'a env, x: id) = r(x)
  fun update(r: 'a env, x: id, t: 'a) =
    (fn(y: id) => if x = y then t else r(y))
end
*)

type env = Type.type_ Env.env
type type_ = Type.type_

datatype expr = Var of id                    (* id *)
              | True | False                 (* true, false *)
              | IntConst of int              (* n *)
              | If of expr * expr * expr     (* if e1 then e2 else e3 *)
              | Op of expr * expr            (* e1 + e2 *)
              | Select of int * expr         (* #i e *)
              | Pair of expr * expr          (* (e1, e2) *)
              | Let of (decl list) * expr    (* let d in e end *)
              | Fun of id * type_ * expr     (* fn (id:type) => expr *)
              | Apply of expr * expr         (* e1(e2) *)
     and decl = VarDecl of id * type_ * expr                (* val id:t = e *)
              | FunDecl of id * id * type_ * type_ * expr  (* fun id1(id2:t1):t2 = e *)

(* A convenient definition *)
val check_equal = Type.check_equal

(* tcheck(r, e) is the type of e in type environment r.
 * Raises Fail if the expression e does not type-check. *)
fun tcheck(r: env, e: expr): type_ =
  case e of
    Var(x) => Env.lookup(r, x)
  | True => Type.Bool
  | False => Type.Bool
  | IntConst(n) => Type.Int
  | Let(d, e) =>
      let val r' = foldl declcheck r d
      in tcheck(r', e) end
  | If(e1, e2, e3) =>
      let
        val t2 = tcheck(r, e2)
        val t3 = tcheck(r, e3)
      in
        check_equal(Type.Bool, tcheck(r, e1));
        check_equal(t2, t3);
        t2
      end
  | Op(e1, e2) =>
      (check_equal(tcheck(r, e1), Type.Int);
       check_equal(tcheck(r, e2), Type.Int);
       Type.Int)
  | Pair(e1, e2) => Type.Product(tcheck(r, e1), tcheck(r, e2))
  | Select(i, e) =>
      (case tcheck(r, e) of
         Type.Product(t1, t2) =>
           (case i of
              1 => t1
            | 2 => t2
            | _ => raise Fail("Illegal index"))
       | _ => raise Fail("expected pair"))
  | Fun(x, t, e) => Type.Arrow(t, tcheck(Env.update(r, x, t), e))
  | Apply(e1, e2) =>
      (* The operator must have a function type whose argument type
       * equals the type of e2. Parenthesized so that this case cannot
       * absorb any arm added after it. *)
      (case tcheck(r, e1) of
         Type.Arrow(t1, t2) => (check_equal(t1, tcheck(r, e2)); t2)
       | _ => raise Fail("expected a function type"))

(* declcheck(d, r) is the environment r extended with the declaration d;
 * the argument order lets it be passed directly to foldl.
 * Raises Fail if the declaration d does not type-check. *)
and declcheck(d: decl, r: env): env =
  case d of
    VarDecl(x, t, e) => (check_equal(t, tcheck(r, e)); Env.update(r, x, t))
  | FunDecl(f, x, t1, t2, e) =>
      (* f is bound before its body is checked, so it may be recursive. *)
      let val r' = Env.update(r, f, Type.Arrow(t1, t2))
      in
        check_equal(t2, tcheck(Env.update(r', x, t1), e));
        r'
      end

(* tc(e) is the type of the program e, checked in the empty environment.
 * Raises Fail if e does not type-check; an unbound variable raises Fail
 * from Env.lookup. *)
fun tc(e: expr): type_ = tcheck(Env.empty, e)

(* let fun f(n: int):int = n in f end *)
val e1 = Let(FunDecl("f", "n", Type.Int, Type.Int, Var("n"))::nil, Var("f"))
(* (2, true) *)
val e2: expr = Pair(IntConst(2), True)
(* fn(x: int->bool) => fn (y: int) => x (y + 1) *)
val e3: expr = Fun("x", Type.Arrow(Type.Int, Type.Bool),
                   Fun("y", Type.Int,
                       Apply(Var("x"), Op(Var("y"), IntConst(1)))))

(* SML/NJ-specific: print deeply nested values in full at the REPL. *)
val dummy = (Compiler.Control.Print.printDepth := 1000)
The key idea behind the unification-based type inference algorithm is to
introduce type variables to stand in place of types that the algorithm
hasn't figured out yet. When a variable is introduced without a declared type (for example, in a
let
expression), it is added to the environment, bound to a fresh type variable. During type
checking, type variables are solved for as necessary. This works because the only time that
the type checker above generates any constraints on types is when they are
compared for equality. We can build a type inference algorithm by having the
test for equality also simultaneously solve for type variables as needed to unify
the two types being compared: that is, make them equal.
For example, if we compare the two types T1->bool and (int->T3)->T2,
where T1, T2, and T3 are type variables, we can see that we can make these two
types equal by picking T1=int->T3 and T2=bool.
We can think of these two equations as substitutions that, if applied to
the types being compared, make them equal to one another. In this case, the
result of applying the substitutions to both types is (int->T3)->bool.
There are many substitutions that would make these two types equal, because
we can add T3=t
to the substitution for any arbitrary
type t
and still unify the two types. Comparing these two
types doesn't give us any information about T3, so we wouldn't want to do that.
Therefore unification of the two types finds the weakest substitution that unifies them
(often called the most general unifier). A substitution is weaker than another if the stronger substitution
can be described as applying the weaker substitution, followed by another
non-trivial substitution. For example, any substitution of the form (T1=int->t, T2=bool, T3=t)
can be achieved by first doing the substitution (T1=int->T3, T2=bool)
and then the substitution (T3=t).
The unification algorithm should find the weakest unifying substitution: (T1=int->T3, T2=bool).
The downside of unification is that it can lead to confusing error messages when an expression is not well-typed. For example:
- fn z => let val (x,y)=z in z(x) end stdIn:2.4-24.5 Error: operator is not a function [tycon mismatch] operator: 'Z * 'Y in expression: z x
Type variables such as 'Z and 'Y are how SML reports an
unsolved type variable in an error message. This error message tells us that SML
tried to unify a tuple type 'Z*'Y against a function type and
failed (as we would expect). SML hadn't figured out what the types of the tuple
elements were, so it just reported the type variables instead.
A nice way to implement unification-based type inference is to represent type
variables using ref cells. For an unsolved type variable, the ref cell points to
NONE
; once it is solved and set equal to some type t
,
the cell is updated to point to SOME(t)
. Here is an implementation
of type inference using that technique:
datatype expr = Var of id                  (* id *)
              | True | False               (* true, false *)
              | IntConst of int            (* n *)
              | If of expr * expr * expr   (* if e1 then e2 else e3 *)
              | Let of (decl list) * expr  (* let d in e end *)
              | Fun of id * expr           (* fn (id) => expr -- NO TYPE DECL *)
              | Op of expr * expr          (* e1 + e2 *)
              | Select of int * expr       (* #i e *)
              | Pair of expr * expr        (* (e1, e2) *)
              | Apply of expr * expr       (* e1(e2); *)
     and decl = VarDecl of id * expr       (* val id = e -- NO TYPE DECL *)
              | FunDecl of id * id * expr  (* fun id1(id2) = e -- NO TYPE DECL *)

structure Type' = struct
  (* A type_ is a type that may contain some type variables still to be
   * solved for. A type variable (TypeVar) is solved by assigning SOME
   * type to its ref; the int is the number used to print it. *)
  datatype type_ = Int
                 | Bool
                 | Arrow of type_ * type_
                 | Pair of type_ * type_
                 | TypeVar of type_ option ref * int

  (* Create a fresh, unsolved type variable, numbered one past the
   * current value of cur_index. *)
  val cur_index = ref 0
  fun freshTypeVar(): type_ =
    (cur_index := 1 + !cur_index;
     TypeVar(ref NONE, !cur_index))

  (* resolve(t) follows solved type variables starting at t until it
   * reaches an unsolved type variable or a non-variable type. *)
  fun resolve(t: type_): type_ =
    case t of
      TypeVar(r, _) => (case !r of SOME t' => resolve t' | NONE => t)
    | _ => t

  (* unparse(t) is t written in SML type syntax. Solved type variables
   * print as their solutions and unsolved ones as 'Tn. Parentheses are
   * inserted where SML syntax needs them: an arrow type on the left of
   * an arrow, and an arrow or product type inside a product. *)
  fun unparse(t: type_): string =
    let
      fun isArrow t' = (case resolve t' of Arrow _ => true | _ => false)
      fun isCompound t' =
        (case resolve t' of Arrow _ => true | Pair _ => true | _ => false)
      fun paren(s: string, wrap: bool) = if wrap then "(" ^ s ^ ")" else s
    in
      case t of
        Int => "int"
      | Bool => "bool"
      | Arrow(t1, t2) => paren(unparse t1, isArrow t1) ^ "->" ^ unparse t2
      | Pair(t1, t2) =>
          paren(unparse t1, isCompound t1) ^ "*" ^ paren(unparse t2, isCompound t2)
      | TypeVar(r, nm) =>
          (case !r of
             NONE => "'T" ^ Int.toString(nm)
           | SOME t' => unparse(t'))
    end

  (* occurs(r, t) is whether the type variable with cell r appears in t,
   * following solved type variables. *)
  fun occurs(r: type_ option ref, t: type_): bool =
    case resolve t of
      TypeVar(r', _) => r = r'
    | Arrow(t1, t2) => occurs(r, t1) orelse occurs(r, t2)
    | Pair(t1, t2) => occurs(r, t1) orelse occurs(r, t2)
    | _ => false

  (* Determine whether t1 can be made equal to t2, by unification.
   * Effects: prints a trace line for each call, and may solve for some
   * type variables, as necessary. Raises Fail if the two types cannot
   * be unified. *)
  fun unify(t1: type_, t2: type_): unit =
    ( print ("Unifying " ^ unparse(t1) ^ " with " ^ unparse(t2) ^ "\n");
      case t1 of
        TypeVar(r, _) => unifyVar(t2, r)
      | Int => (case t2 of
                  Int => ()
                | TypeVar(r, _) => unifyVar(t1, r)
                | _ => raise Fail("expected int"))
      | Bool => (case t2 of
                   Bool => ()
                 | TypeVar(r, _) => unifyVar(t1, r)
                 | _ => raise Fail("expected bool"))
      | Arrow(t3, t4) =>
          (case t2 of
             Arrow(t3', t4') => (unify(t3, t3'); unify(t4, t4'))
           | TypeVar(r, _) => unifyVar(t1, r)
           | _ => raise Fail("expected a function type"))
      | Pair(t3, t4) =>
          (case t2 of
             Pair(t3', t4') => (unify(t3, t3'); unify(t4, t4'))
           | TypeVar(r, _) => unifyVar(t1, r)
           | _ => raise Fail("expected a product type")) )

  (* unifyVar(t1, r) unifies t1 with the type variable whose cell is r.
   * If r is already solved, its solution is unified with t1. Otherwise r
   * is bound to t1 with solved variables followed first, so that no
   * cyclic type is ever built. If t1 resolves to r itself, nothing is
   * done. If t1 resolves to a type containing r, Fail is raised (the
   * occurs check). Without these checks, fn x => x x, or unifying two
   * variables in both orders, created cycles on which unparse and unify
   * never terminated. *)
  and unifyVar(t1: type_, r: type_ option ref): unit =
    case !r of
      SOME t => unify(t1, t)
    | NONE =>
        (case resolve t1 of
           t1' as TypeVar(r1, _) => if r1 = r then () else r := SOME t1'
         | t1' => if occurs(r, t1') then raise Fail("circular type")
                  else r := SOME t1')
end

type env = Type'.type_ Env.env
type type_ = Type'.type_

(* tcheck(r, e) is the inferred type of e in type environment r; it may
 * contain type variables that are still unsolved.
 * Raises Fail if the expression e does not type-check. *)
fun tcheck(r: env, e: expr): type_ =
  case e of
    Var(x) => Env.lookup(r, x)
  | True => Type'.Bool
  | False => Type'.Bool
  | IntConst(n) => Type'.Int
  | Let(d, e) => tcheck(foldl declcheck r d, e)
  | If(e1, e2, e3) =>
      let
        val t2 = tcheck(r, e2)
        val t3 = tcheck(r, e3)
      in
        Type'.unify(Type'.Bool, tcheck(r, e1));
        Type'.unify(t2, t3);
        t2
      end
  | Op(e1, e2) =>
      (Type'.unify(tcheck(r, e1), Type'.Int);
       Type'.unify(tcheck(r, e2), Type'.Int);
       Type'.Int)
  | Fun(x, e) =>
      (* The parameter's type is unknown, so it starts as a fresh variable. *)
      let val t = Type'.freshTypeVar()
      in Type'.Arrow(t, tcheck(Env.update(r, x, t), e)) end
  | Apply(e1, e2) =>
      (* e1 must have some function type t -> t' with t equal to the type
       * of e2; the application then has type t'. *)
      let
        val t = Type'.freshTypeVar()
        val t' = Type'.freshTypeVar()
        val t1 = tcheck(r, e1)
        val t2 = tcheck(r, e2)
        val tf = Type'.Arrow(t, t')
      in
        Type'.unify(t1, tf);
        Type'.unify(t2, t);
        t'
      end
  | Pair(e1, e2) => Type'.Pair(tcheck(r, e1), tcheck(r, e2))
  | Select(i, e) =>
      let
        val t1 = Type'.freshTypeVar()
        val t2 = Type'.freshTypeVar()
      in
        Type'.unify(Type'.Pair(t1, t2), tcheck(r, e));
        (case i of
           1 => t1
         | 2 => t2
         | _ => raise Fail("Illegal index"))
      end

(* declcheck(d, r) is the environment r extended with the declaration d.
 * Raises Fail if the declaration d does not type-check. *)
and declcheck(d: decl, r: env): env =
  case d of
    VarDecl(x, e) => Env.update(r, x, tcheck(r, e))
  | FunDecl(f, x, e) =>
      (* f : t1 -> t2 is in scope in its own body, so it may be recursive. *)
      let
        val t1 = Type'.freshTypeVar()
        val t2 = Type'.freshTypeVar()
        val tf = Type'.Arrow(t1, t2)
        val r' = Env.update(r, f, tf)
        val r'' = Env.update(r', x, t1)
        val te = tcheck(r'', e)
      in
        Type'.unify(te, t2);
        r'
      end

(* EXAMPLES *)

(* tc(e) is the printed inferred type of e; type-variable numbering
 * restarts on each call. *)
fun tc(e) = (Type'.cur_index := 0; Type'.unparse(tcheck(Env.empty, e)))

(* let fun f(n) = n + 2 in f end *)
val e1 = Let(FunDecl("f", "n", Op(Var("n"), IntConst(2)))::nil, Var("f"))
(* let fun f(n) = n in f end *)
val e2 = Let(FunDecl("f", "n", Var("n"))::nil, Var("f"))
(* let fun f(n) = n in f(2) end *)
val e2a = Let(FunDecl("f", "n", Var("n"))::nil, Apply(Var("f"), IntConst(2)))
(* let fun f(x) = fn(y) => x + y in f end *)
val e3 = Let(FunDecl("f", "x", Fun("y", Op(Var("x"), Var("y"))))::nil, Var("f"))
(* let fun loop(x) = loop(x+1) in loop end *)
val e4 = Let([FunDecl("loop", "x", Apply(Var("loop"), Op(Var("x"), IntConst(1))))], Var("loop"))

This implementation of type inference can be extended with a little more effort to provide SML-style polymorphism (called let-polymorphism) in which variables can have polymorphic type. Consider using the type inference algorithm above to find the type of a variable. If that type contains unsolved type variables that don't appear anywhere else in the program, they clearly can be replaced with any type we want. Therefore the variable with that type can actually be used with different type bindings in different places. For example, if we write
let fun ident(x) = x in ident(ident)(2) end
then the second ident
has type int->int
and the first ident
has type (int->int)->(int->int)
. The type inference
algorithm will find by checking the declaration of ident
that it
has some type 'X->'X
for a type variable 'X
that is
used nowhere else in the program (the type checker can tell this by looking in
the type environment to see whether 'X
appears there). At each use of ident, it replaces 'X with a fresh type variable
(say, 'Y at the first use and 'Z at the second). This decoupling permits them
to be solved independently, as desired, to obtain 'Y = int->int and 'Z = int,
so that the first use has type 'Y->'Y = (int->int)->(int->int) and the second
has type 'Z->'Z = int->int.
Below is some code that implements type inference with let-polymorphism.
There are a few changes from the simple type inference just given. The
environment no longer maps identifiers to types; it maps them to type schemas.
A type schema is a type along with a list of type variables that can be
substituted differently at every use of the identifier. In declcheck, the types
that are determined for variables are abstracted by schema
to
construct type schemas. Then, when type-checking an identifier, the instantiate
function is used to replace all the type parameters identified by schema
with fresh type variables.
(* Let-polymorphism: type inference and polymorphism ala SML *)
type type_ = Type'.type_
(* A type variable: its solution cell and its print number. *)
type typeVar = type_ option ref * int
(* A type schema (t, tvs): each use may replace the variables in tvs
 * with fresh type variables, independently of other uses. *)
type typeSchema = type_ * (typeVar list)

(* An environment now binds a variable to a type schema rather
 * than to just a type. *)
type env = typeSchema Env.env

(* Union of l1 and l2 considered as sets.
 * Requires: l1 and l2 contain no duplicates. *)
fun union(l1: typeVar list, l2: typeVar list): typeVar list =
  foldl (fn (r, l) => if List.exists (fn r' => r = r') l then l else r::l) l1 l2

(* Difference of l1 and l2 considered as sets.
 * Requires: l1 and l2 contain no duplicates. *)
fun diff(l1, l2): typeVar list =
  List.filter (fn r => not (List.exists (fn r' => r = r') l2)) l1

(* All unsolved type variables in t, following solved ones. *)
fun unsolved(t: Type'.type_): typeVar list =
  case t of
    Type'.Int => nil
  | Type'.Bool => nil
  | Type'.Arrow(t1, t2) => union(unsolved(t1), unsolved(t2))
  | Type'.Pair(t1, t2) => union(unsolved(t1), unsolved(t2))
  | Type'.TypeVar(r, nm) =>
      (case !r of
         NONE => [(r, nm)]
       | SOME(t) => unsolved(t))

(* All unsolved type variables free in the type environment: for each
 * binding (t, tvs), those of t other than the schema's own quantified
 * variables tvs. This implementation is not very efficient, and it
 * folds over r directly, so it relies on Env.env being an association
 * list. *)
fun envUnsolved(r: env): typeVar list =
  foldl (fn ((_, (t, tvs)), l0) => union(l0, diff(unsolved(t), tvs))) nil r

(* Given an environment r, produce a schema for t that identifies
 * all the new type variables in t (those not free in r), which can be
 * arbitrarily substituted for. *)
fun schema(t: type_, r: env): typeSchema =
  let
    val uv = unsolved(t)
    val ev = envUnsolved(r)
  in
    (t, diff(uv, ev))
  end

(* instantiate(t, tvs) is a type just like t except that every type
 * variable in tvs has been consistently replaced by a fresh type
 * variable. Solved type variables are followed. Otherwise a quantified
 * variable reachable only through a solved one (as in fun f(x) = x,
 * where x's variable is solved to the result variable) would stay
 * shared, and one use could constrain every later use. *)
fun instantiate(t: type_, tvs: typeVar list): type_ =
  let
    val tm = foldl (fn (tv: typeVar, tm: (typeVar * type_) list) =>
                      (tv, Type'.freshTypeVar()) :: tm) nil tvs
    fun instVar(tv: typeVar, tm: (typeVar * type_) list): type_ =
      case tm of
        nil => Type'.TypeVar(tv)
      | (tv1, tv2)::tm' => if tv1 = tv then tv2 else instVar(tv, tm')
    fun inst(t: type_) =
      case t of
        Type'.Int => t
      | Type'.Bool => t
      | Type'.Arrow(t1, t2) => Type'.Arrow(inst(t1), inst(t2))
      | Type'.Pair(t1, t2) => Type'.Pair(inst(t1), inst(t2))
      | Type'.TypeVar(tv as (r, _)) =>
          (case !r of
             SOME t' => inst(t')
           | NONE => instVar(tv, tm))
  in
    (* A monomorphic schema needs no copying. *)
    case tvs of [] => t | _ => inst(t)
  end

(* tcheck(r, e) is the type of e in type environment r.
 * Raises Fail if the expression e does not type-check. *)
fun tcheck(r: env, e: expr): type_ =
  case e of
    Var(x) => instantiate(Env.lookup(r, x)) (* instantiate schema here *)
  | True => Type'.Bool
  | False => Type'.Bool
  | IntConst(n) => Type'.Int
  | Let(d, e) => tcheck(foldl declcheck r d, e)
  | If(e1, e2, e3) =>
      let
        val t2 = tcheck(r, e2)
        val t3 = tcheck(r, e3)
      in
        Type'.unify(Type'.Bool, tcheck(r, e1));
        Type'.unify(t2, t3);
        t2
      end
  | Op(e1, e2) =>
      (Type'.unify(tcheck(r, e1), Type'.Int);
       Type'.unify(tcheck(r, e2), Type'.Int);
       Type'.Int)
  | Fun(x, e) =>
      let
        val t = Type'.freshTypeVar()
        val ts = (t, []) (* no polymorphism inside the -> *)
      in
        Type'.Arrow(t, tcheck(Env.update(r, x, ts), e))
      end
  | Apply(e1, e2) =>
      (* The operator's type need not already be an arrow: it may be an
       * unsolved type variable (like g in fn g => g 1) or a solved one.
       * So unify it with a fresh arrow type rather than matching on its
       * constructor. *)
      let
        val tArg = Type'.freshTypeVar()
        val tRes = Type'.freshTypeVar()
      in
        Type'.unify(Type'.Arrow(tArg, tRes), tcheck(r, e1));
        Type'.unify(tArg, tcheck(r, e2));
        tRes
      end
  | Pair(e1, e2) => Type'.Pair(tcheck(r, e1), tcheck(r, e2))
  | Select(i, e) =>
      let
        val t1 = Type'.freshTypeVar()
        val t2 = Type'.freshTypeVar()
      in
        Type'.unify(Type'.Pair(t1, t2), tcheck(r, e));
        (case i of
           1 => t1
         | 2 => t2
         | _ => raise Fail("Illegal index"))
      end

(* The environment r extended with the declaration d.
 * Raises Fail if the declaration d does not type-check. *)
and declcheck(d: decl, r: env): env =
  case d of
    VarDecl(x, e) => Env.update(r, x, schema(tcheck(r, e), r)) (* generate schema here *)
  | FunDecl(f, x, e) =>
      let
        val t1 = Type'.freshTypeVar()
        val t2 = Type'.freshTypeVar()
        val tf = Type'.Arrow(t1, t2)
        (* Inside its own body f is monomorphic. *)
        val r' = Env.update(r, f, (tf, []))
        val r'' = Env.update(r', x, (t1, []))
        val te = tcheck(r'', e)
      in
        Type'.unify(te, t2);
        (* Generalize against r, which binds neither f nor x. *)
        Env.update(r, f, schema(tf, r)) (* generate schema here *)
      end

val ident = Fun("x", Var("x"))
(* let fun f(x) = x in f(2) + (f f)(3) end *)
val e1 = Let([FunDecl("f", "x", Var("x"))],
             Op(Apply(Var("f"), IntConst(2)),
                Apply(Apply(Var("f"), Var("f")), IntConst(3))))