CS312 Lecture 3
Lists and Other Recursive Datatypes

Recursive datatypes

In recitation you should have seen some simple examples of datatypes, which are SML types that can have more than one kind of value. This involved a new kind of declaration, a datatype declaration (where square brackets indicate optional syntax):

datatype Y = X1 [of t1] | ... | Xn [of tn]

and new expression and pattern forms:

e ::= ... |  X(e)  |  case e of p1=>e1 | ... | pn=>en
p ::=  X  |  X(x1:t1, ..., xn:tn)

We can use datatypes to define many useful data structures.  We saw in recitation that the bool type is really just a datatype, so we don't need to have booleans built into SML the way we do in Java. We can even define data structures that act like numbers, demonstrating that we don't really have to have numbers built into SML either! A natural number is either the value zero or the successor of some other natural number. This definition leads naturally to the following definition for values that act like natural numbers nat:

datatype nat = Zero | Succ of nat

This is how you might define the natural numbers in a mathematical logic course. We have defined a new type nat, and Zero and Succ are constructors for values of this type. This datatype is more sophisticated than the ones we saw in recitation: it is a recursive datatype because the definition of what a nat is mentions nat itself. Using this definition, we can define some values that act like natural numbers:

val zero = Zero
val one = Succ(Zero)
val two = Succ(Succ(Zero))
val three = Succ(two)
val four = Succ(three)

When we ask the compiler what four represents, we get

- four;
val it = Succ (Succ (Succ (Succ Zero))) : nat

Thus four is a nested data structure. The equivalent Java definitions would be

public interface nat { }
public class Zero implements nat { }
public class Succ implements nat { nat value; Succ(nat v) { value = v; } /* etc */ }

nat zero = new Zero();
nat one = new Succ(new Zero());
nat two = new Succ(new Succ(new Zero()));
nat three = new Succ(two);
nat four = new Succ(three);

And in fact the Java objects representing the various numbers are actually implemented similarly to the SML values representing the corresponding numbers.

Now we can write functions to manipulate values of this type.

fun iszero(n : nat) : bool = 
  case n of
    Zero => true
  | Succ(m) => false

The case expression allows us to do pattern matching on expressions. Here we're pattern-matching a value of type nat. If the value is Zero we evaluate to true; otherwise we evaluate to false.

fun pred(n : nat) : nat = 
  case n of
    Zero => raise Fail "predecessor on zero"
  | Succ(m) => m

Here we determine the predecessor of a number. If the value of n matches Zero then we raise an exception, since zero has no predecessor in the natural numbers. If the value matches Succ(m) for some value m (which of course also must be of type nat), then we return m.

Similarly we can define a function to add two numbers: (See if the students can come up with this with some coaching.)

fun add(n1:nat, n2:nat) : nat = 
  case n1 of
    Zero => n2
  | Succ(n_minus_1) => add(n_minus_1, Succ(n2))

If you were to try evaluating add(four,four), the compiler would respond with:

- add(four,four);
val it = Succ (Succ (Succ (Succ (Succ #)))) : nat

The compiler correctly performed the addition, but it has abbreviated the output because the data structure is nested so deeply. To easily understand the results of our computation, we would like to convert such values to type int:

fun nat_to_int(n:nat) : int = 
  case n of
    Zero => 0
  | Succ(n) => 1 + nat_to_int(n)

That was pretty easy. Now we can write nat_to_int(add(four,four)) and get 8. How about the inverse operation?

fun int_to_nat(i:int) : nat =
  if i < 0 then raise Fail "int_to_nat on negative number"
  else if i = 0 then Zero
  else Succ(int_to_nat(i-1))
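
As a quick sanity check, the two conversions should undo each other on non-negative integers. The sketch below repeats the definitions so that it can be loaded on its own; the helper name round_trip is an assumption introduced only for this illustration.

```sml
datatype nat = Zero | Succ of nat

fun nat_to_int(n:nat) : int =
  case n of
    Zero => 0
  | Succ(m) => 1 + nat_to_int(m)

fun int_to_nat(i:int) : nat =
  if i < 0 then raise Fail "int_to_nat on negative number"
  else if i = 0 then Zero
  else Succ(int_to_nat(i-1))

(* round_trip is a hypothetical helper: int -> nat -> int *)
fun round_trip(i:int) : int = nat_to_int(int_to_nat(i))

val three = int_to_nat(3)            (* Succ (Succ (Succ Zero)) *)
val same = (round_trip(8) = 8)       (* true *)

(* A negative argument raises Fail; here we catch it and return ~1 *)
val caught = round_trip(~5) handle Fail _ => ~1
```

Note that int_to_nat builds one Succ cell per unit, so converting a large integer creates a correspondingly deep data structure.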

To determine whether a natural number is even or odd, we can write a pair of mutually recursive functions:

fun even(n:nat) : bool =
  case n of
    Zero => true
  | Succ(n) => odd(n)
and odd (n:nat) : bool =
  case n of
    Zero => false
  | Succ(n) => even(n)

You have to use the keyword and to combine mutually recursive functions like this. Otherwise the compiler would flag an error when you refer to odd before it has been defined.
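
A short usage sketch, repeating the two definitions so that it loads on its own. The values four and five are assumptions introduced for the example.

```sml
datatype nat = Zero | Succ of nat

(* even and odd refer to each other, so they are joined with "and" *)
fun even(n:nat) : bool =
  case n of
    Zero => true
  | Succ(m) => odd(m)
and odd(n:nat) : bool =
  case n of
    Zero => false
  | Succ(m) => even(m)

val four = Succ(Succ(Succ(Succ Zero)))
val five = Succ four

val r1 = even(four)   (* true *)
val r2 = odd(four)    (* false *)
val r3 = odd(five)    (* true *)
```

Each call strips one Succ and hands the rest to the other function, so even(four) alternates even, odd, even, odd, even before reaching Zero.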

Finally we can define multiplication in terms of addition. (See if the students can figure this out.)

fun mul(n1:nat, n2:nat) : nat =
  case n1 of
    Zero => Zero
  | Succ(n_minus_1) => add(n2, mul(n_minus_1,n2))
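
To check the arithmetic, we can convert the results back to int. The following is a minimal sketch that repeats add, mul, and nat_to_int so it can be loaded by itself; the values three and four are defined here only for the example.

```sml
datatype nat = Zero | Succ of nat

fun add(n1:nat, n2:nat) : nat =
  case n1 of
    Zero => n2
  | Succ(n_minus_1) => add(n_minus_1, Succ(n2))

(* n1 * n2 is n2 added to itself n1 times *)
fun mul(n1:nat, n2:nat) : nat =
  case n1 of
    Zero => Zero
  | Succ(n_minus_1) => add(n2, mul(n_minus_1, n2))

fun nat_to_int(n:nat) : int =
  case n of
    Zero => 0
  | Succ(m) => 1 + nat_to_int(m)

val three = Succ(Succ(Succ Zero))
val four = Succ three

val seven = nat_to_int(add(three, four))    (* 7 *)
val twelve = nat_to_int(mul(three, four))   (* 12 *)
val nothing = nat_to_int(mul(Zero, four))   (* 0 *)
```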

Deep pattern matching

It turns out that the syntax of ML is a bit richer than the syntax we saw in the last lecture. In addition to new kinds of terms for creating and projecting tuple and record values, and creating and examining datatype values, we also have the ability to match patterns against values to pull them apart into their parts.

When used properly, ML pattern matching leads to concise, clear code. This is because ML supports deep pattern matching, in which one pattern can appear as a subpattern of another pattern. For example, we see above that Succ(n) is a pattern, but so is Succ(Succ(n)). This pattern matches only on a value that has the form Succ(Succ(v)) for some value v (that is, the successor of the successor of something), and binds the variable n to that something, v.
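
As an illustration of a deep pattern, here is a sketch of a function that halves a natural number by peeling off two Succ constructors at a time. The function name half and the value five are assumptions made for this example; they do not appear elsewhere in these notes.

```sml
datatype nat = Zero | Succ of nat

(* half(n) computes floor(n/2) *)
fun half(n:nat) : nat =
  case n of
    Succ(Succ(m)) => Succ(half(m))   (* deep pattern: n is at least two *)
  | _ => Zero                        (* n is Zero or Succ(Zero) *)

val five = Succ(Succ(Succ(Succ(Succ Zero))))
val two = half(five)                 (* Succ (Succ Zero) *)
```

Without deep patterns, the first branch would need a nested case expression: one to strip the outer Succ and another to strip the inner one.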

syntactic class
    syntactic variables and grammar rule(s)
    examples

identifiers
    x, y
    a, x, y, x_y, foo1000, ...

datatypes, datatype constructors
    X, Y
    Nil, Cons, list

constants
    c
    ...~2, ~1, 0, 1, 2 (integers)
    1.0, ~0.001, 3.141 (reals)
    true, false (booleans)
    "hello", "", "!" (strings)
    #"A", #" " (characters)

unary operator
    u
    ~, not, size, ...

binary operators
    b
    +, *, -, >, <, >=, <=, ^, ...

expressions (terms)
    e ::=  x  |  u e  |  e1 b e2  |  if e1 then e2 else e3  |  let d1...dn in e end  |  e (e1, ..., en)  |  (e1,...,en)  |  #n e  |  {x1=e1, ..., xn=en}  |  #x e  |  X(e)  |  case e of p1=>e1 | ... | pn=>en
    ~0.001, foo, not b, 2 + 2, Cons(2, Nil)

patterns
    p ::=  x  |  (p1,..., pn)  |  {x1=p1,...,xn=pn}  |  X  |  X(p)
    a:int, (x:int,y:int), I(x:int)

declarations
    d ::=  val p = e  |  fun y p : t = e  |  datatype Y = X1 [of t1] | ... | Xn [of tn]
    val one = 1
    fun square(x: int): int = x * x
    datatype d = N | I of int

types
    t ::=  int  |  real  |  bool  |  string  |  char  |  t1->t2  |  t1*...*tn  |  {x1:t1, x2:t2, ..., xn:tn}  |  Y
    int, string, int->int, bool*int->bool

values
    v ::=  c  |  (v1,...,vn)  |  {x1=v1, ..., xn=vn}  |  X(v)
    2, (2,"hello"), Cons(2,Nil)

Pattern matching on records

Of course, natural numbers aren't quite as good as integers, but we can simulate integers in terms of the natural numbers by using a representation consisting of a sign and magnitude:

datatype sign = Pos | Neg
type integer = { sign : sign, mag : nat }

The type keyword simply defines a name for a type. Here we've defined integer to refer to a record type with two fields: sign and mag. Remember that records are unordered, so there is no concept of a "first" field.

Note that a type declaration is different from a datatype declaration; it creates a new way to name an existing type, whereas a datatype declaration creates a new type and also happens to give it a name, which is needed to support recursion. For example, we could write a declaration type natural = nat. The type natural and nat would then be exactly the same type and usable interchangeably.
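
The interchangeability can be seen in a small sketch. It repeats nat and pred so that it loads on its own; the val names are chosen only for this example.

```sml
datatype nat = Zero | Succ of nat
type natural = nat      (* just another name for the existing type nat *)

fun pred(n:nat) : nat =
  case n of
    Zero => raise Fail "predecessor on zero"
  | Succ(m) => m

(* A value annotated as natural is built with nat's constructors... *)
val two : natural = Succ(Succ Zero)

(* ...and can be passed to a function that expects a nat. *)
val one : nat = pred(two)
```

Had natural been introduced with its own datatype declaration instead, it would be a distinct type with its own constructors, and pred(two) would be rejected by the type checker.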

We can use the definition of integer to write some integers:

val zero = {sign=Pos, mag=Zero}
val zero' = {sign=Neg, mag=Zero}
val one = {sign=Pos, mag=Succ(Zero)}
val neg_one = {sign=Neg, mag=Succ(Zero)}

Now we can write a function to determine the successor of any integer:

fun inc(i:integer) : integer =
    case i of
      {sign = _, mag = Zero} => {sign = Pos, mag = Succ(Zero)}
    | {sign = Pos, mag = n} => {sign = Pos, mag = Succ(n)}
    | {sign = Neg, mag = Succ(n)} => {sign = Neg, mag = n}

Here we're pattern-matching on a record type. Notice that in the third pattern we are doing deep pattern matching because the mag field is matched against a pattern itself, Succ(n). Remember that the patterns are tested in order. How does the meaning of this function change if the first two patterns are swapped?

The predecessor function is very similar, and it should be obvious that we could write functions to add, subtract, and multiply integers in this representation.
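
For instance, one possible predecessor function mirrors inc with the roles of Pos and Neg exchanged. This is a sketch under the same sign-and-magnitude representation; the name dec is our own choice, and the declarations are repeated so the block loads by itself.

```sml
datatype nat = Zero | Succ of nat
datatype sign = Pos | Neg
type integer = { sign : sign, mag : nat }

(* dec is a hypothetical name for the predecessor function on integer *)
fun dec(i:integer) : integer =
    case i of
      {sign = _, mag = Zero} => {sign = Neg, mag = Succ(Zero)}
    | {sign = Neg, mag = n} => {sign = Neg, mag = Succ(n)}
    | {sign = Pos, mag = Succ(n)} => {sign = Pos, mag = n}

val neg_one = dec({sign=Pos, mag=Zero})        (* 0 - 1 = -1 *)
val zero = dec({sign=Pos, mag=Succ(Zero)})     (* 1 - 1 = 0 *)
val neg_two = dec(neg_one)                     (* -1 - 1 = -2 *)
```

As in inc, the first pattern treats both representations of zero the same way, so the sign of a zero argument never matters.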


Another recursive datatype: integer lists

We can also use datatypes to define some useful data structures. One simple data structure that we're used to is the singly linked list. It turns out that SML has lists built in. For example, in SML the expression [] is an empty list. The expression [1,2,3] is a list containing three integers. The expression 1::[2,3] is 1 prepended onto the list [2,3], which is the same thing as [1,2,3]. It turns out that we can write our own version of lists using datatypes. We want to define values that act like linked lists of integers. A linked list is either empty, or it has an integer followed by another list containing the rest of the list elements. This leads to a very natural datatype declaration:

(* This datatype defines integer lists as either Nil (empty) or
 * a "Cons" cell containing an integer and an integer list.  The
 * term "Cons" comes from Lisp.
 *)
datatype intlist = Nil | Cons of (int * intlist)

(* Here are some example lists *)
val list1 = Nil 		(* the empty list:  []*)
val list2 = Cons(1,Nil) 	(* the list containing just 1:  [1] *)
val list3 = Cons(2,Cons(1,Nil)) (* the list [2,1] *)
val list4 = Cons(2,list2)       (* also the list [2,1] *)
(* the list [1,2,3,4,5] *)
val list5 = Cons(1,Cons(2,Cons(3,Cons(4,Cons(5,Nil)))))
(* the list [6,7,8,9,10] *)
val list6 = Cons(6,Cons(7,Cons(8,Cons(9,Cons(10,Nil)))))

(* test to see if the list is empty *)
fun is_empty(xs:intlist):bool = 
    case xs of
      Nil => true
    | Cons(_,_) => false

(* Return the number of elements in the list *)
fun length(xs:intlist):int = 
    case xs of
      Nil => 0
    | Cons(i:int,rest:intlist) => 1 + length(rest)

(* Notice that the case expressions for lists all have the same
 * form -- a case for the empty list (Nil) and a case for a Cons.
 * Also notice that for most functions, the Cons case involves a
 * recursive function call. *)
(* Return the sum of the elements in the list *)
fun sum(xs:intlist):int = 
    case xs of
      Nil => 0
    | Cons(i:int,rest:intlist) => i + sum(rest)

(* Create a string representation of a list *)
fun toString(xs: intlist):string = 
    case xs of
      Nil => ""
    | Cons(i:int, Nil) => Int.toString(i)
    | Cons(i:int, Cons(j:int, rest:intlist)) => 
       Int.toString(i) ^ "," ^ toString(Cons(j,rest))
    
(* Return the first element (if any) of the list *)
fun head(is: intlist):int = 
    case is of
      Nil => raise Fail("empty list!")
    | Cons(i,tl) => i

(* Return the rest of the list after the first element *)
fun tail(is: intlist):intlist = 
    case is of
      Nil => raise Fail("empty list!")
    | Cons(i,tl) => tl

(* Return the last element of the list (if any) *)
fun last(is: intlist):int = 
    case is of
      Nil => raise Fail("empty list!")
    | Cons(i,Nil) => i
    | Cons(i,tl) => last(tl)

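(* Return the ith element of the list, counting from 1.
 * The head of the list is bound to x rather than i, so that it does
 * not shadow the index, and the recursive call counts down from n. *)
fun ith(is: intlist, i:int):int = 
    case (i,is) of
      (_,Nil) => raise Fail("empty list!")
    | (1,Cons(x,tl)) => x
    | (n,Cons(x,tl)) =>
        if (n <= 0) then raise Fail("bad index")
        else ith(tl, n - 1)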

(* Append two lists:  append([1,2,3],[4,5,6]) = [1,2,3,4,5,6] *)
fun append(list1:intlist, list2:intlist):intlist = 
    case list1 of
      Nil => list2
    | Cons(i,tl) => Cons(i,append(tl,list2))

(* Reverse a list:  reverse([1,2,3]) = [3,2,1].
 * Notice that we compute this by reversing the tail of the
 * list first (e.g., compute reverse([2,3]) = [3,2]) and then
 * append the singleton list [1] to the end to yield [3,2,1]. *)
fun reverse(list:intlist):intlist = 
    case list of
      Nil => Nil
    | Cons(hd,tl) => append(reverse(tl), Cons(hd,Nil)) 

fun inc(x:int):int = x + 1;
fun square(x:int):int = x * x;

(* given [i1,i2,...,in] return [i1+1,i2+1,...,in+1] *)
fun addone_to_all(list:intlist):intlist = 
    case list of
      Nil => Nil
    | Cons(hd,tl) => Cons(inc(hd), addone_to_all(tl))

(* given [i1,i2,...,in] return [i1*i1,i2*i2,...,in*in] *)
fun square_all(list:intlist):intlist = 
    case list of
      Nil => Nil
    | Cons(hd,tl) => Cons(square(hd), square_all(tl))

(* given a function f and [i1,...,in], return [f(i1),...,f(in)].
 * Notice how we factored out the common parts of addone_to_all
 * and square_all. *)
fun do_function_to_all(f:int->int, list:intlist):intlist = 
    case list of
      Nil => Nil
    | Cons(hd,tl) => Cons(f(hd), do_function_to_all(f,tl))

(* now we can define addone_to_all in terms of do_function_to_all *)
fun addone_to_all(list:intlist):intlist = 
    do_function_to_all(inc, list);

(* same with square_all *)
fun square_all(list:intlist):intlist = 
    do_function_to_all(square, list);

(* given [i1,i2,...,in] return i1+i2+...+in (also defined above) *)
fun sum(list:intlist):int = 
    case list of
      Nil => 0
    | Cons(hd,tl) => hd + sum(tl)

(* given [i1,i2,...,in] return i1*i2*...*in *)
fun product(list:intlist):int = 
    case list of
      Nil => 1
    | Cons(hd,tl) => hd * product(tl)

(* given f, b, and [i1,i2,...,in], return f(i1,f(i2,...,f(in,b))).
 * Again, we factored out the common parts of sum and product. *)
fun collapse(f:(int * int) -> int, b:int, list:intlist):int = 
    case list of
      Nil => b
    | Cons(hd,tl) => f(hd,collapse(f,b,tl))

(* Now we can define sum and product in terms of collapse *)
fun sum(list:intlist):int = 
    let fun add(i1:int,i2:int):int = i1 + i2
    in 
        collapse(add,0,list)
    end

fun product(list:intlist):int = 
    let fun mul(i1:int,i2:int):int = i1 * i2
    in
        collapse(mul,1,list)
    end

(* Here, we use an anonymous function instead of declaring add and mul.
 * After all, what's the point of giving those functions names if all
 * we're going to do is pass them to collapse? *)
fun sum(list:intlist):int = 
    collapse((fn (i1:int,i2:int) => i1+i2),0,list);

fun product(list:intlist):int = 
    collapse((fn (i1:int,i2:int) => i1*i2),1,list);

(* And here, we just pass the operators directly... *)
fun sum(list:intlist):int = collapse(op +, 0, list);

fun product(list:intlist):int = collapse(op *, 1, list);
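
To connect intlist back to SML's built-in lists, the sketch below converts between the two representations and uses do_function_to_all and collapse on the result. The helper names from_list and to_list are assumptions introduced for this example; the other two functions are repeated from above so the block loads on its own. The built-in functions map and List.foldr play the same roles for built-in lists.

```sml
datatype intlist = Nil | Cons of (int * intlist)

(* from_list and to_list are hypothetical helper names *)
fun from_list(xs : int list) : intlist =
    case xs of
      [] => Nil
    | x::rest => Cons(x, from_list(rest))

fun to_list(xs : intlist) : int list =
    case xs of
      Nil => []
    | Cons(x, rest) => x :: to_list(rest)

fun do_function_to_all(f:int->int, list:intlist):intlist =
    case list of
      Nil => Nil
    | Cons(hd,tl) => Cons(f(hd), do_function_to_all(f,tl))

fun collapse(f:(int * int) -> int, b:int, list:intlist):int =
    case list of
      Nil => b
    | Cons(hd,tl) => f(hd,collapse(f,b,tl))

val xs = from_list [1,2,3,4]

(* [2,3,4,5], the same result as map (fn x => x + 1) [1,2,3,4] *)
val incremented = to_list(do_function_to_all(fn (x:int) => x + 1, xs))

val total = collapse(op +, 0, xs)                    (* 10 *)
val same_total = List.foldr (op +) 0 (to_list xs)    (* 10 *)
```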