CS 312 Lecture 22
Type inference and unification

In our SML programming we've been writing down types in function declarations. But if we leave the types off, the SML type checker is able to figure out what the type annotations should have been. This is called type inference or type reconstruction. In this lecture we'll see how it works. Notice that even a simple type checker does some amount of type inference: we don't have to write down types on every expression, because the checker figures out many of the types itself. But it turns out that we can type-check the core of SML without any type declarations at all, which may seem surprising.
- val f = fn z => z+2
val f = fn : int->int
- val ident = fn x => x
val ident = fn : 'a -> 'a
- let fun square z = z*z
  in
    fn f => fn x => fn y =>
     if f x y then f (square x) y
     else f x (f x y)
  end
val it = fn : (int->bool->bool)->int->bool->bool

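To see how this works, we'll start with a simple type checker for an ML-like language and then extend it to support type inference. Things to notice about this code:
- Every binder carries an explicit type annotation: the parameter of a fn, and the names introduced by val and fun declarations.
- The checker only ever compares types for equality or inspects their outermost constructor (arrow or product).
- It never has to guess a type.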


code/lec22/type-checker.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
 
(* Raised by the type checker when an expression or declaration is
 * ill-typed; the string names the construct whose check failed. *)
exception TypeError of string

(* Identifiers are represented as strings. *)
type id = string

(* Abstract syntax of the explicitly typed source language. Every binder
 * (fn parameter, val and fun declaration) carries its type, so the
 * checker never has to guess a type. *)
datatype expr =
    Var of id                    (* id *)
  | True | False                 (* true, false *)
  | IntConst of int              (* n *)
  | If of expr*expr*expr         (* if e1 then e2 else e3 *)
  | Op of expr * expr            (* e1 + e2 *)
  | Select of int*expr           (* #i e *)
  | Pair of expr * expr          (* (e1, e2) *)
  | Let of (decl list) * expr    (* let d in e end *)
  | Fun of id * type_ * expr     (* fn (id:type) => expr *)
  | Apply of expr * expr         (* e1(e2) *)

(* Declarations that may appear in a let. *)
and decl =
    VarDecl of id * type_ * expr (* val id:t = e *)
  | FunDecl of id * id * type_ * type_ * expr (* fun id1(id2:t1):t2 = e *)

(* Types of the source language. *)
and type_ =
    Int                          (* int *)
  | Bool                         (* bool *)
  | Arrow of type_ * type_       (* t1 -> t2 *)
  | Product of type_ * type_     (* t1 * t2 *)

(* Type environments: identifiers mapped to their types, via the Env
 * structure defined elsewhere. *)
type env = type_ Env.env

(* tcheck(r, e) is the type of expression e in type environment r.
 * Variable types come from r. Every other type is computed from the
 * subexpressions or taken from an explicit annotation. The only checks
 * performed are equality tests between types and tests of a type's
 * outermost constructor.
 * Raises TypeError, tagged with the name of the failing construct, if e
 * does not type-check. (What happens for an unbound variable is up to
 * Env.lookup.) *)
fun tcheck(r: env, e: expr):type_ =
  case e of
    Var(x) => Env.lookup(r, x)

  | True => Bool

  | False => Bool

  | IntConst(n) => Int

  | Let(d,e) =>
      (* Check the declarations in order, each in the environment produced
       * by the ones before it, then check the body in the final one. *)
      let val r' = foldl declcheck r d
      in tcheck(r', e)
      end

  | If(e1, e2, e3) =>
      let val t2 = tcheck(r, e2)
          val t3 = tcheck(r, e3)
      in
        (* The guard must be a bool and both branches must agree. *)
        if Bool = tcheck(r, e1) andalso t2 = t3
          then t2 else raise TypeError "If"
      end

  | Op(e1,e2) =>
      if tcheck(r,e1) = Int andalso tcheck(r,e2) = Int
        then Int else raise TypeError "Op"

  | Pair(e1,e2) =>
      Product(tcheck(r, e1), tcheck(r, e2))

  | Select(i, e) =>
      (case tcheck(r, e)
         of Product(t1,t2) =>
              (case i
                 of 1 => t1
                  | 2 => t2
                  | _ => raise TypeError "Select")
          | _ => raise TypeError "Select")

  | Fun(x,t,e) =>
      (* The body is checked with the parameter bound to its declared type. *)
      Arrow(t, tcheck(Env.update(r,x,t), e))

  | Apply(e1, e2) =>
      (case tcheck(r, e1)
         of Arrow(t1, t2) =>
              if t1 = tcheck(r, e2) then t2
              else raise TypeError "Apply"
          | _ => raise TypeError "Apply")

(* declcheck(d, r) is the environment r extended with the declaration d.
 * The argument order (declaration first) lets it be passed directly to
 * foldl over a list of declarations.
 * Raises TypeError if the declaration d does not type-check. *)
and declcheck(d: decl, r:env):env =
  case d of
    VarDecl(x,t,e) =>
      if tcheck(r, e) = t then Env.update(r, x, t)
      else raise TypeError "VarDecl"

  | FunDecl(f,x,t1,t2,e) =>
      (* f is bound while its own body is checked, so recursive calls
       * type-check. The parameter x is added only for the body; the
       * resulting environment binds f but not x. *)
      let val r' = Env.update(r, f, Arrow(t1,t2))
      in
        if tcheck(Env.update(r', x, t1), e) = t2 then r'
        else raise TypeError "FunDecl"
      end


(* tc(e) is the type of the closed program e, computed by checking it in
 * the empty type environment.
 * Requires: e contains no unbound variables.
 * Raises TypeError if e is not a well-typed program. *)
fun tc (e: expr) : type_ =
  let
    val emptyEnv = Env.empty
  in
    tcheck (emptyEnv, e)
  end

(* let fun f(n: int):int = n in f end *)
val e1 = Let([FunDecl("f", "n", Int, Int, Var("n"))], Var("f"))

(* (2, true) *)
val e2: expr = Pair(IntConst(2), True)

(* fn(x: int->bool) => fn (y: int) => x (y + 1) *)
val e3: expr =
  let
    val body = Apply(Var("x"), Op(Var("y"), IntConst(1)))
  in
    Fun("x", Arrow(Int, Bool), Fun("y", Int, body))
  end

(* Raise SML/NJ's print depth so nested syntax trees and types are shown
 * in full at the REPL instead of being elided. *)
val dummy = (Compiler.Control.Print.printDepth := 1000)
 

The key idea behind the unification-based type inference algorithm is to introduce type variables that stand in place of types the algorithm hasn't figured out yet. The algorithm may encounter a variable whose type has not been declared, for example:
- the parameter of a fn expression, or
- a function and its parameter in a fun declaration inside a let.
It then adds the variable to the environment, bound to a fresh type variable. During type checking, type variables are solved for as necessary. This works because the only time the type checker above generates any constraints on types is when it compares them for equality. We can therefore build a type inference algorithm by having the test for equality also solve for type variables as needed to unify the two types being compared: that is, to make them equal.

For example, if we compare the two types T1->bool and (int->T3)->T2, where T1, T2, and T3 are type variables, we can see that we can make these two types equal by picking types T1=int->T3 and T2=bool. We can think of these two equations as substitutions that, if applied to the types being compared, make them equal to one another. In this case, the result of applying the substitutions to both types is (int->T3)->bool.

There are many substitutions that would make these two types equal, because we can add T3=t to the substitution for any arbitrary type t and still unify the two types. But comparing these two types gives us no information about T3, so we don't want to commit to any particular choice for it. Therefore unification finds the weakest substitution that unifies the two types, also called the most general unifier. A substitution is weaker than another if the stronger substitution can be described as applying the weaker substitution followed by another non-trivial substitution. For example, any substitution of the form (T1=int->t, T2=bool, T3=t) can be achieved by first doing the substitution (T1=int->T3, T2=bool) and then the substitution (T3=t). The unification algorithm should find the weakest unifying substitution: (T1=int->T3, T2=bool).

The downside of unification is that it can lead to confusing error messages when an expression is not well-typed. For example:

- fn z => let val (x,y)=z in z(x) end
stdIn:2.4-24.5 Error: operator is not a function [tycon mismatch]
  operator: 'Z * 'Y
  in expression:
    z x

The types 'Z and 'Y are the way that SML reports an unsolved type variable in an error message. This error message tells us that SML tried to unify a tuple type 'Z*'Y against a function type and failed (as we would expect). SML hadn't figured out what the types of the tuple elements were, so it just reported the type variables instead.

A nice way to implement unification-based type inference is to represent type variables using ref cells. For an unsolved type variable, the ref cell contains NONE. Once the variable is solved and set equal to some type t, the cell is updated to contain SOME(t). Here is an implementation of type inference using that technique:


code/lec22/type-inference.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
 
(* Identifiers are represented as strings. *)
type id = string

(* Raised by unify when two types cannot be made equal; the string
 * identifies which case of unify failed. *)
exception TypeUnifyError of string

(* Abstract syntax of the implicitly typed source language. Binders (fn
 * parameters, val and fun declarations) carry no type annotations;
 * tinfer reconstructs their types. *)
datatype expr =
    Var of id                    (* id *)
  | True | False                 (* true, false *)
  | IntConst of int              (* n *)
  | If of expr * expr * expr     (* if e1 then e2 else e3 *)
  | Let of (decl list) * expr    (* let d in e end *)
  | Fun of id * expr             (* fn (id) => expr -- NO TYPE DECL *)
  | Op of expr * expr            (* e1 + e2 *)
  | Select of int*expr           (* #i e *)
  | Pair of expr * expr          (* (e1, e2) *)
  | Apply of expr * expr         (* e1(e2); *)

(* Declarations that may appear in a let. *)
and decl =
    VarDecl of id * expr         (* val id = e -- NO TYPE DECL *)
  | FunDecl of id * id * expr    (* fun id1(id2) = e -- NO TYPE DECL *)

(* Types, extended with type variables. In TypeVar(r, n), the cell r holds
 * NONE while the variable is unsolved and SOME t once unification has
 * solved it to t. The index n, assigned by freshTypeVar, is used by
 * typeToString to name unsolved variables. *)
and type_ =
    Int
  | Bool
  | Arrow of type_ * type_
  | Product of type_ * type_
  | TypeVar of type_ option ref * int

(* Index to give the next fresh type variable; ti resets it to 0. *)
val cur_index = ref 0

(* freshTypeVar() is a new unsolved type variable whose index is the
 * current value of cur_index; the counter then advances by one, so
 * successive calls yield indices 0, 1, 2, ... *)
fun freshTypeVar(): type_ =
  let
    val index = !cur_index
  in
    cur_index := index + 1;
    TypeVar(ref NONE, index)
  end

(* typeToString(t) is a printable rendering of type t, fully parenthesized.
 * A solved type variable prints as the type it was solved to. An unsolved
 * type variable with index nm prints as 'a .. 'z for nm = 0..25. Beyond
 * that, a numeric suffix is added ('a1 .. 'z1, 'a2, ...), so every index
 * gets a distinct letter-based name. Previously, large indices printed as
 * punctuation, and indices of 159 or more raised Chr. *)
fun typeToString(t: type_): string =
  case t
    of Int => "int"
     | Bool => "bool"
     | Arrow(t1,t2) => "(" ^ typeToString(t1) ^ "->" ^ typeToString(t2) ^ ")"
     | Product(t1,t2) => "(" ^ typeToString(t1) ^ "*" ^ typeToString(t2) ^ ")"
     | TypeVar(r,nm) =>
         (case !r
            of SOME t' => typeToString(t')
             | NONE =>
                 let
                   (* Cycle through the 26 lowercase letters... *)
                   val letter = String.str(chr(nm mod 26 + ord #"a"))
                   (* ...and disambiguate later rounds with a counter. *)
                   val suffix = if nm < 26 then "" else Int.toString(nm div 26)
                 in
                   "'" ^ letter ^ suffix
                 end)

(* resolve(t) follows solved type variables starting from t. It returns the
 * first type reached that is not a solved type variable: either an
 * unsolved TypeVar, or a type whose outermost constructor is Int, Bool,
 * Arrow or Product. *)
fun resolve(t: type_): type_ =
  case t
    of TypeVar(r, _) => (case !r of SOME t' => resolve(t') | NONE => t)
     | _ => t

(* occurs(r, t) is true iff the type-variable cell r appears in t, looking
 * through solved type variables. This is the "occurs check": solving r to
 * a type that contains r would create a cyclic (infinite) type. *)
fun occurs(r: type_ option ref, t: type_): bool =
  case resolve(t)
    of TypeVar(r', _) => r = r'
     | Arrow(t1, t2) => occurs(r, t1) orelse occurs(r, t2)
     | Product(t1, t2) => occurs(r, t1) orelse occurs(r, t2)
     | _ => false

(* Determine whether t1 can be made equal to t2, by unification.
 * Effects: prints a trace line for every call. May solve some type
 * variables as necessary, by updating their ref cells.
 * Raises TypeUnifyError if the two types cannot be unified. This happens
 * when their constructors differ, or when solving a type variable would
 * require it to contain itself (e.g. solving 'a to 'a -> 'b, which is what
 * fn x => x x demands). *)
fun unify(t1: type_, t2: type_): unit =
  (
   print ("Unify " ^ typeToString(t1) ^ " with " ^ typeToString(t2) ^ "\n");
   case t1
     of TypeVar(r, _) => unifyVar(t2, r)
      | Int =>
          (case t2
             of Int => ()
              | TypeVar(r, _) => unifyVar(t1, r)
              | _ => raise TypeUnifyError "Int")
      | Bool =>
          (case t2
             of Bool => ()
              | TypeVar(r, _) => unifyVar(t1, r)
              | _ => raise TypeUnifyError "Bool")
      | Arrow(t3, t4) =>
          (case t2
             of Arrow(t3', t4') => (unify(t3, t3');
                                    unify(t4, t4'))
              | TypeVar(r, _) => unifyVar(t1, r)
              | _ => raise TypeUnifyError "Arrow")
      | Product(t3, t4) =>
          (case t2
             of Product(t3', t4') => (unify(t3, t3');
                                      unify(t4, t4'))
              | TypeVar(r, _) => unifyVar(t1, r)
              | _ => raise TypeUnifyError "Product")
  )

(* unifyVar(t1, r) unifies t1 with the type variable whose cell is r.
 * If r is already solved to some type t, t1 is unified with t.
 * If r is unsolved, it is solved to t1. Two cases are special:
 *   - if t1 already resolves to r itself, nothing is done;
 *   - if t1 contains r, TypeUnifyError "Occurs" is raised. *)
and unifyVar(t1: type_, r: type_ option ref): unit =
  case !r
    of NONE =>
         (case resolve(t1)
            of TypeVar(r1, _) =>
                 (* t1 leads to an unsolved variable. Link r to it unless it
                  * is r itself; binding r to a chain that ends at r would
                  * make a cycle. *)
                 if r1 <> r then r := SOME t1 else ()
             | t1' =>
                 if occurs(r, t1') then raise TypeUnifyError "Occurs"
                 else r := SOME t1)
     | SOME t => unify(t1, t)

(* Type environments: identifiers mapped to types that may contain type
 * variables, via the Env structure defined elsewhere. *)
type env = type_ Env.env

(* tinfer(r,e) is the type of e in type environment r.
 * Unknown types start out as fresh type variables and are solved by
 * calling unify. The returned type may therefore contain type variables
 * and should be read through their solutions (as typeToString does).
 * Inference is monomorphic: a variable's type is looked up as-is, never
 * generalized or instantiated, so a function bound by a fun declaration
 * has one type shared by all its uses.
 * Side effects: allocates fresh type variables (advancing cur_index),
 * solves type variables through unify, and prints unify's trace.
 * Raises TypeUnifyError if the expression e does not type-check. A Select
 * whose index is not 1 or 2 raises Fail "Illegal index", but only after
 * its operand has been unified with a pair type. *)
fun tinfer(r: env, e: expr):type_ =
  case e of
    Var(x) => Env.lookup(r, x)
  | True => Bool
  | False => Bool
  | IntConst(n) => Int

  | Let(d,e) => tinfer(foldl declcheck r d, e)

  | If(e1, e2, e3) =>
      (* All three parts are inferred first; then the guard must be bool
       * and the two branches must have the same type. *)
      let
        val t1 = tinfer(r, e1)
        val t2 = tinfer(r, e2)
        val t3 = tinfer(r, e3)
      in
        unify(t1, Bool);
        unify(t2, t3);
        t2
      end

  | Op(e1,e2) =>
      (* Both operands must be int; the result is int. *)
      (unify(tinfer(r,e1), Int);
       unify(tinfer(r,e2), Int);
       Int)

  | Fun(x,e) =>
      (* The parameter's type is unknown, so it starts as a fresh type
       * variable that inferring the body may solve. *)
      let val t = freshTypeVar()
      in Arrow(t, tinfer(Env.update(r,x,t), e))
      end

  | Apply(e1, e2) =>
      (* Fresh variables t and t' stand for e1's argument and result
       * types: e1 must be a function t -> t', and e2 must have type t. *)
      let
        val t = freshTypeVar()
        val t' = freshTypeVar()
      in
        unify(tinfer(r, e1), Arrow(t, t'));
        unify(tinfer(r, e2), t);
        t'
      end

  | Pair(e1,e2) => Product(tinfer(r, e1), tinfer(r, e2))

  | Select(i, e) =>
      (* e must be a pair whose component types start out unknown. *)
      let
        val t1 = freshTypeVar()
        val t2 = freshTypeVar()
      in
        unify(Product(t1,t2), tinfer(r, e));
        (case i
           of 1 => t1
            | 2 => t2
            | _ => raise Fail("Illegal index"))
      end

(* declcheck(d, r) is the environment r extended with the declaration d.
 * The argument order (declaration first) lets it be passed directly to
 * foldl over a list of declarations.
 * Raises TypeUnifyError if the declaration d does not type-check. *)
and declcheck(d: decl, r:env):env =
  case d of
    VarDecl(x,e) =>
      Env.update(r, x, tinfer(r, e))
  | FunDecl(f,x,e) =>
      (* f gets the type t1 -> t2 with both parts unknown. f is bound while
       * its own body is inferred (so recursion works), x is bound only for
       * the body, and the body's type is then unified with t2. *)
      let val t1 = freshTypeVar()
          val t2 = freshTypeVar()
          val tf = Arrow(t1,t2)
          val r' = Env.update(r,f,tf)
          val r'' = Env.update(r',x,t1)
          val te = tinfer(r'', e)
      in
        unify(te, t2); r'
      end

(* EXAMPLES *)

(* ti(e) infers the type of the closed expression e and returns it as a
 * string. The type-variable counter is reset first, so printed variable
 * names start at 'a. *)
fun ti(e) =
  let
    val () = cur_index := 0
    val t = tinfer(Env.empty, e)
  in
    typeToString(t)
  end

(* let fun f(n) = n + 2 in f end *)
val e1 = Let([FunDecl("f", "n", Op(Var("n"), IntConst(2)))], Var("f"))

(* let fun f(n) = n in f end *)
val e2 = Let([FunDecl("f", "n", Var("n"))], Var("f"))

(* let fun f(n) = n in f(2) end *)
val e2a = Let([FunDecl("f", "n", Var("n"))], Apply(Var("f"), IntConst(2)))

(* let fun f(x) = fn(y) => x + y in f end *)
val e3 = Let([FunDecl("f", "x", Fun("y", Op(Var("x"), Var("y"))))], Var("f"))

(* let fun loop(x) = loop(x+1) in loop end *)
val e4 =
  let
    val body = Apply(Var("loop"), Op(Var("x"), IntConst(1)))
  in
    Let(FunDecl("loop", "x", body)::nil, Var("loop"))
  end

(* S = fn x => fn y => fn z => (x z) (y z) *)
val S =
  let
    val xz = Apply(Var("x"), Var("z"))
    val yz = Apply(Var("y"), Var("z"))
  in
    Fun("x", Fun("y", Fun("z", Apply(xz, yz))))
  end

(* K = fn x => fn y => x *)
val K = Fun("x", Fun("y", Var("x")))

(* O = fn f => fn g => fn x => f (g x), i.e. function composition *)
val O =
  let
    val gx = Apply(Var("g"), Var("x"))
  in
    Fun("f", Fun("g", Fun("x", Apply(Var("f"), gx))))
  end

(* I = fn x => x *)
val I = Fun("x", Var("x"))