Recitation 3: Higher-Order Functions & Datatype Constructors

Note to instructors: I don't expect you to necessarily get through all of this material. Go only as fast as the students can absorb it, and use more examples as necessary. Remind students that they can read the notes online and come to consulting hours for additional help. Make sure the students know that there's extra material that you didn't cover.


Higher-order functions

Functions are values just like any other value in SML. What does that mean exactly? It means that we can pass functions as arguments to other functions, store functions in data structures, and return functions as results from other functions. The full implication of this will not hit you until later, but believe us, it will.

Let us look at why it is useful to have higher-order functions. The first reason is that it allows you to write more general code, hence more reusable code. As a running example, consider functions double and square on integers:

fun double (x:int):int = 2 * x
fun square (x:int):int = x * x

Let us now come up with a function to quadruple a number. We could do it directly, but for utterly twisted motives we decide to use the function double above:

fun quad (x:int):int = double (double (x))

Straightforward enough. What about a function to raise an integer to the fourth power?

fun fourth (x:int):int = square (square (x))

There is an obvious similarity between these two functions: each applies a given function twice to a value. By passing the function to be applied as an argument to a new function, apply_twice, we can abstract this functionality and thus reuse code:

fun apply_twice (f:int -> int, x:int):int = f (f (x))

Using this, we can write:

fun quad (x:int):int = apply_twice(double,x)
fun fourth (x:int):int = apply_twice(square,x)

The advantage is that the similarity between these two functions has been made manifest. Doing this is very helpful. If someone comes up with an improved (or corrected) version of apply_twice, then every function that uses it profits from the improvement.

The function apply_twice is a so-called higher-order function: it is a function from functions to other values. Notice that the type of apply_twice is ((int -> int) * int) -> int.
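To see how general apply_twice is, the following sketch applies it to three different functions. The helper add_three and the names of the results are hypothetical, introduced only for illustration.

```sml
fun apply_twice (f:int -> int, x:int):int = f (f (x))

fun double (x:int):int = 2 * x
fun square (x:int):int = x * x
fun add_three (x:int):int = x + 3   (* hypothetical helper, not from the notes *)

val twenty = apply_twice (double, 5)       (* double (double 5) = 20 *)
val eighty_one = apply_twice (square, 3)   (* square (square 3) = 81 *)
val nine = apply_twice (add_three, 3)      (* (3 + 3) + 3 = 9 *)
```

Any function of type int -> int is an acceptable first argument; the type checker rejects anything else.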

In order not to pollute the top level namespace, it can be useful to locally define the function to pass in as an argument. For example:

fun fourth (x:int):int = 
  let 
    fun square (y:int):int = y * y
  in
    apply_twice (square,x)
  end

However, it seems silly to define and name a function simply to pass it in as an argument to another function. After all, all we really care about is that apply_twice gets a function that squares its argument. We can do that using some new syntax:

fun fourth (x:int):int = apply_twice (fn (y:int) => y*y, x)

We introduce a new expression to denote "a function that expects such and such an argument and returns such and such an expression":

e ::= ...  |  fn (id : type) => e

The fn expression creates an anonymous function: a function without a name. The type annotation on the argument makes things clearer. Unlike functions declared with fun, an anonymous function does not declare its return type; the return type is inferred automatically. What is the type of fn (y:int) => y = 3?

Answer: int -> bool
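One way to check such an answer is to bind the anonymous function to a name with an explicit type and let the compiler verify it. The names is_three, yes, and no below are hypothetical.

```sml
(* The annotation int -> bool is checked by the compiler against the fn expression. *)
val is_three : int -> bool = fn (y:int) => y = 3

val yes = is_three 3   (* true *)
val no = is_three 4    (* false *)
```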

The declaration val square : int -> int = fn (y:int) => y*y has the same effect as fun square (y:int):int = y * y. In fact, the declaration using fun is just syntactic sugar for the more tedious val declaration (strictly speaking, for val rec, which is what allows a function to call itself).
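The correspondence between the two forms can be sketched as follows. The names square_v, square_f, fact_v, and fact_f are hypothetical; the factorial pair is included only to show that a recursive function needs val rec when written without fun.

```sml
(* Two equivalent ways to define squaring. *)
val square_v : int -> int = fn (y:int) => y * y
fun square_f (y:int):int = y * y

(* A recursive function written with val must use val rec,
   so that the name is in scope inside its own body. *)
val rec fact_v = fn (n:int) => if n = 0 then 1 else n * fact_v (n - 1)
fun fact_f (n:int):int = if n = 0 then 1 else n * fact_f (n - 1)

val forty_nine = square_v 7   (* 49, same as square_f 7 *)
val one_twenty = fact_v 5     (* 120, same as fact_f 5 *)
```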

Anonymous functions are useful for creating functions to pass as arguments to other functions, but are also useful for writing functions that return other functions! Let us revisit the apply_twice function. We now write a function twice which takes a function as an argument and returns a new function that applies the original function twice:

fun twice (f: int->int) = 
  fn (x:int) => f (f (x))

This function takes a function f (of type int->int) as an argument and returns the value fn (x:int) => f (f (x)), a function that, when applied to an argument, applies f twice to that argument. Thus, we can write

val fourth = twice (fn (x:int) => x * x)
val quad = twice (fn (x:int) => 2 * x)

and trying to evaluate fourth (3) does indeed result in 81.
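Because twice returns an ordinary function of type int -> int, its result can be applied immediately or even passed back to twice. A small sketch, in which every name other than twice is hypothetical:

```sml
fun twice (f: int->int) =
  fn (x:int) => f (f (x))

val fourth = twice (fn (x:int) => x * x)
val quad = twice (fn (x:int) => 2 * x)

val eighty_one = fourth (3)                        (* (3*3)*(3*3) = 81 *)
val twelve = quad (3)                              (* 2*(2*3) = 12 *)
val two = (twice (fn (x:int) => x + 1)) (0)        (* applied on the spot: 2 *)

(* twice quad applies quad twice, i.e. multiplies by 16 *)
val times_sixteen = twice (quad)
val sixteen = times_sixteen (1)
```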

Here are more examples of useful higher-order functions that we will leave you to ponder (and try out at home):

fun compose (f:int -> int, g:int -> int) = 
  fn (x:int) => f (g (x))
fun ntimes (f:int -> int,n:int) = 
  if (n=0)
    then (fn (x:int) => x)
  else compose (f, ntimes (f,n-1))
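As a starting point for experimenting, here is a sketch of how these two functions might be used. The helpers double and square are repeated from earlier; the other names are hypothetical.

```sml
fun compose (f:int -> int, g:int -> int) =
  fn (x:int) => f (g (x))
fun ntimes (f:int -> int, n:int) =
  if (n=0)
    then (fn (x:int) => x)
  else compose (f, ntimes (f,n-1))

fun double (x:int):int = 2 * x
fun square (x:int):int = x * x

(* compose (f, g) applies g first, then f, so the order matters. *)
val eighteen = (compose (double, square)) (3)    (* double (square 3) = 18 *)
val thirty_six = (compose (square, double)) (3)  (* square (double 3) = 36 *)

(* ntimes (f, 0) is the identity; ntimes (double, 10) doubles ten times. *)
val seven = (ntimes (double, 0)) (7)             (* 7 *)
val kilo = (ntimes (double, 10)) (1)             (* 2^10 = 1024 *)
```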

Implementing Binary Trees with Tuples

In class yesterday, we saw an example of using a datatype to define integer lists in terms of an empty list (Nil) and cons cells (a head containing an integer and a tail consisting of another list). We were able to iterate over the list, doing various manipulations on the data, and we were able to represent this concisely using higher-order functions. Today we're going to start by doing the same thing with binary trees to make sure everyone is very comfortable with pattern matching in case expressions.

The obvious way to start is with the following datatype, which we saw at the end of the last lecture:

datatype inttree = Leaf | Branch of (int * inttree * inttree)

This defines a type inttree to be either a leaf node (containing no data) or a branch node (containing an int and left and right subtrees). We could have defined a leaf node to contain an integer and no subtrees (some people do this), but then we'd need another constructor to represent the empty tree. Consider the representation of a generic tree.

The first logical function to write is is_empty:

fun is_empty (xs:inttree) : bool =
    case xs of
        Leaf => true
      | _ => false

Then, just as we computed the length of a list, we can count the non-leaf nodes in a tree:

fun size (xs:inttree) : int =
    case xs of
        Leaf => 0
      | Branch(_, left, right) => 1 + size(left) + size(right)

The pattern matching done in this function is very powerful. (If you don't see the power yet, you certainly will when the datatypes become as complicated as our definition of expressions in ML.) We can make very trivial changes to this function to compute several other interesting values.
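For instance, two such variations on size might look like the following sketch. The functions sum_tree and height, and the example tree small, are hypothetical names chosen for illustration; they are not taken from the lecture.

```sml
datatype inttree = Leaf | Branch of (int * inttree * inttree)

(* Same shape as size, but adds up the data instead of counting nodes. *)
fun sum_tree (xs:inttree) : int =
    case xs of
        Leaf => 0
      | Branch(i, left, right) => i + sum_tree(left) + sum_tree(right)

(* Same shape again, but takes the larger of the two subtree results. *)
fun height (xs:inttree) : int =
    case xs of
        Leaf => 0
      | Branch(_, left, right) => 1 + Int.max(height(left), height(right))

val small = Branch(2, Branch(1, Leaf, Leaf),
                      Branch(3, Leaf, Branch(4, Leaf, Leaf)))
val total = sum_tree(small)   (* 2 + 1 + 3 + 4 = 10 *)
val h = height(small)         (* longest path 2, 3, 4 has 3 branch nodes *)
```

Only the right-hand sides of the two cases change; the pattern matching is identical to that of size.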


Implementing Binary Trees with Records

For both lists and trees, we've been using tuples to represent the nodes. But with trees, there may be some confusion with respect to the order of the fields: does the datum come before or after the left subtree? We can solve this problem using a record type. We can define it as

datatype inttree = Leaf | Branch of { datum:int, left:inttree, right:inttree }

(Note: Binary trees are simple enough that this probably would not be adequate motivation to use records in a real program, since we now have to remember the field names and spell them out every time we use them. Skipping to the next section on polynomials is fine if running low on time.)

Using this new representation, we can write size as

fun size (xs:inttree) : int =
    case xs of
        Leaf => 0
      | Branch{datum=i, left=lt, right=rt} => 1 + size(lt) + size(rt)

We've written several functions to analyze trees, but we don't yet have a way to generate large trees easily, so if you want to try these functions in your compiler, you'd have a lot of typing to do to spell out a tree of depth 10. Let's get the compiler to do it for us.

fun complete_tree (i:int, depth:int) : inttree =
    case depth of
        0 => Leaf
      | _ => Branch{datum=i,
                    left=complete_tree(2*i, depth-1),
                    right=complete_tree(2*i+1, depth-1)}

This function will take an integer i and a depth and recursively create a complete tree of the given depth whose nodes are given distinct indices based on i. If we start with i=1, then we get a complete tree whose level-order (breadth-first) node listing is 1, 2, 3, etc. Consider the example given by

val test_tree = complete_tree(1,3)
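As a sanity check, a complete tree of depth d should contain 2^d - 1 branch nodes, and a shallow one can be compared against a tree written out by hand. The sketch below repeats the definitions above; the names by_hand, same, nodes_3, and nodes_10 are hypothetical.

```sml
datatype inttree = Leaf | Branch of { datum:int, left:inttree, right:inttree }

fun size (xs:inttree) : int =
    case xs of
        Leaf => 0
      | Branch{datum=i, left=lt, right=rt} => 1 + size(lt) + size(rt)

fun complete_tree (i:int, depth:int) : inttree =
    case depth of
        0 => Leaf
      | _ => Branch{datum=i,
                    left=complete_tree(2*i, depth-1),
                    right=complete_tree(2*i+1, depth-1)}

(* The depth-2 tree spelled out by hand: root 1 with children 2 and 3. *)
val by_hand = Branch{datum=1,
                     left=Branch{datum=2, left=Leaf, right=Leaf},
                     right=Branch{datum=3, left=Leaf, right=Leaf}}

val same = (complete_tree(1,2) = by_hand)   (* true *)
val nodes_3 = size(complete_tree(1,3))      (* 2^3 - 1 = 7 *)
val nodes_10 = size(complete_tree(1,10))    (* 2^10 - 1 = 1023 *)
```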

Now that we have an example tree to work on, we need a cleaner way to visualize the tree than looking at the compiler's representation of records. Let's write a function to print the contents of a tree in order:

fun print_inorder (xs:inttree) : unit =
    case xs of
        Leaf => ()
      | Branch{datum=i, left, right} => (print_inorder(left);
                                         print(" " ^ Int.toString(i) ^ " ");
                                         print_inorder(right))

Notice that here we did not provide names for binding the left and right subtrees. Writing only the record label in a pattern is just syntactic sugar for binding a variable of the same name to that field's value, so we could have written "datum=i, left=left, right=right". Anyway, our function behaves as follows on our test tree:

- print_inorder(test_tree);
 4  2  5  1  6  3  7 val it = () : unit
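The label shorthand used in print_inorder can be compared with the fully spelled-out form in a small sketch. The functions root_long and root_short are hypothetical. The second one also uses the "..." wildcard, which SML allows in a record pattern to ignore the remaining fields when the record type is known (here it is known from the Branch constructor).

```sml
datatype inttree = Leaf | Branch of { datum:int, left:inttree, right:inttree }

(* Every field named explicitly. *)
fun root_long (xs:inttree) : int option =
    case xs of
        Leaf => NONE
      | Branch{datum=datum, left=_, right=_} => SOME datum

(* Label-only shorthand for datum; "..." ignores left and right. *)
fun root_short (xs:inttree) : int option =
    case xs of
        Leaf => NONE
      | Branch{datum, ...} => SOME datum

val t = Branch{datum=5, left=Leaf, right=Leaf}
val r1 = root_long(t)    (* SOME 5 *)
val r2 = root_short(t)   (* SOME 5 *)
```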

We could have applied many other functions to each element of the tree. A standard data structure operation is apply, which executes a given function on every element. The function is evaluated for side-effects only; the return value is ignored. How could we write apply_inorder for our trees?

fun apply_inorder (f:int->unit, xs:inttree) : unit =
    case xs of
        Leaf => ()
      | Branch{datum, left, right} => (apply_inorder(f,left);
                                       f(datum);
                                       apply_inorder(f,right))

Using this, we can write a very short version of print_inorder:

fun print_inorder (xs:inttree) : unit =
    apply_inorder(fn (i:int) => print(" " ^ Int.toString(i) ^ " "), xs)
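Printing is not the only useful side effect. As a sketch, and assuming the use of ref cells (which these notes have not introduced), apply_inorder can also accumulate a result in a mutable variable. The functions sum_via_apply and list_via_apply and the tree small are hypothetical.

```sml
datatype inttree = Leaf | Branch of { datum:int, left:inttree, right:inttree }

fun apply_inorder (f:int->unit, xs:inttree) : unit =
    case xs of
        Leaf => ()
      | Branch{datum, left, right} => (apply_inorder(f,left);
                                       f(datum);
                                       apply_inorder(f,right))

(* Adds each element to a running total held in a ref cell. *)
fun sum_via_apply (xs:inttree) : int =
    let val total = ref 0
    in apply_inorder(fn (i:int) => total := !total + i, xs);
       !total
    end

(* Records the elements in the order they are visited. *)
fun list_via_apply (xs:inttree) : int list =
    let val seen = ref []
    in apply_inorder(fn (i:int) => seen := i :: !seen, xs);
       rev (!seen)
    end

val small = Branch{datum=2,
                   left=Branch{datum=1, left=Leaf, right=Leaf},
                   right=Branch{datum=3, left=Leaf, right=Leaf}}
val six = sum_via_apply(small)       (* 1 + 2 + 3 = 6 *)
val order = list_via_apply(small)    (* [1,2,3]: left, root, right *)
```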

Another common operation is map, which generates a copy of the data structure in which a given function has been applied to every element. We can write map_tree as

fun map_tree (f:int->int, xs:inttree) : inttree =
    case xs of
        Leaf => Leaf
      | Branch{datum=i, left, right} => Branch{datum=f(i),
                                               left=map_tree(f,left),
                                               right=map_tree(f,right)}

How could we use this to triple every element of a tree?

val tripled_tree = map_tree(fn (i:int) => i*3, test_tree)
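To inspect the result of map_tree without printing, one option is to flatten the tree into a list. In the sketch below, to_list and the tree small are hypothetical; squaring is shown alongside tripling to emphasize that only the function argument changes.

```sml
datatype inttree = Leaf | Branch of { datum:int, left:inttree, right:inttree }

fun map_tree (f:int->int, xs:inttree) : inttree =
    case xs of
        Leaf => Leaf
      | Branch{datum=i, left, right} => Branch{datum=f(i),
                                               left=map_tree(f,left),
                                               right=map_tree(f,right)}

(* Hypothetical helper: the elements of the tree in inorder. *)
fun to_list (xs:inttree) : int list =
    case xs of
        Leaf => []
      | Branch{datum, left, right} => to_list(left) @ (datum :: to_list(right))

val small = Branch{datum=2,
                   left=Branch{datum=1, left=Leaf, right=Leaf},
                   right=Branch{datum=3, left=Leaf, right=Leaf}}

val tripled = to_list(map_tree(fn (i:int) => i*3, small))   (* [3,6,9] *)
val squared = to_list(map_tree(fn (i:int) => i*i, small))   (* [1,4,9] *)
val original = to_list(small)                               (* [1,2,3], unchanged *)
```

Note that map_tree builds a new tree; the original tree is not modified.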

A Little Algebra

Here's a new example. Suppose we want to represent polynomials of a single variable. Thus we are interested in sums of terms, where each term has a real coefficient and an integral power; for example, the polynomial 3x^3+5x^2-x+10.3. We can write this in ML as

type term = real * int
datatype termsum = Zero | Sum of (term * termsum)

Notice the use of the declaration starting with the keyword type. As we've seen earlier, this simply defines a name for a type; we've defined a term to be an ordered pair of a real and an int. Using this representation, the polynomial x^2+2x+3 can be written as

val test_poly = Sum((1.0,2), Sum((2.0,1), Sum((3.0,0), Zero)))
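To make the representation concrete, the sketch below also encodes the earlier example 3x^3+5x^2-x+10.3 and evaluates a polynomial at a point. The names big_poly and eval_poly are hypothetical; eval_poly relies on the Basis Library function Math.pow and on real to convert an int power to a real.

```sml
type term = real * int
datatype termsum = Zero | Sum of (term * termsum)

val test_poly = Sum((1.0,2), Sum((2.0,1), Sum((3.0,0), Zero)))

(* 3x^3 + 5x^2 - x + 10.3; note that SML writes negative literals with ~ *)
val big_poly = Sum((3.0,3), Sum((5.0,2), Sum((~1.0,1), Sum((10.3,0), Zero))))

(* Hypothetical helper: the value of the polynomial at x. *)
fun eval_poly (ts:termsum, x:real) : real =
    case ts of
        Zero => 0.0
      | Sum((coef,pow), rest) =>
            coef * Math.pow(x, real(pow)) + eval_poly(rest, x)

val at_two = eval_poly(test_poly, 2.0)   (* 4 + 4 + 3, i.e. about 11.0 *)
val at_one = eval_poly(big_poly, 1.0)    (* 3 + 5 - 1 + 10.3, i.e. about 17.3 *)
```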

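Since the compiler normally prints only the first few levels of a data structure, and long polynomials involve very deep nesting, we'll ask it to print more than usual for us (this setting is specific to SML/NJ, and recent releases spell it Control.Print.printDepth, without the Compiler prefix):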

Compiler.Control.Print.printDepth := 100;

This is useful for debugging, but of course we really want a way to print out polynomials in a more readable way. We can build such a function out of smaller pieces:

fun print_poly(ts:termsum) : unit =
    let fun print_term((coef,pow):term) : unit =
            (print(Real.toString(coef));
             if pow <> 0 then print("*x^" ^ Int.toString(pow)) else ())
    in case ts of
           Zero => print("0\n")
         | Sum(t, Zero) => (print_term(t); print("\n"))
         | Sum(t, xs) => (print_term(t); print(" + "); print_poly(xs))
    end

For simplicity, we'll require that all polynomials are ordered by strictly decreasing powers. This will make manipulation of polynomials a lot easier. We're allowed to do this because we're defining the data structure, but we have to be careful that all of our functions now preserve this invariant.
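The invariant itself can be expressed as a function, which is handy for testing. The following is a sketch of one possible checker; is_ordered is a hypothetical name and is not part of the notes.

```sml
type term = real * int
datatype termsum = Zero | Sum of (term * termsum)

(* True when the powers strictly decrease from the first term to the last. *)
fun is_ordered (ts:termsum) : bool =
    case ts of
        Zero => true
      | Sum(_, Zero) => true
      | Sum((_,p1), rest as Sum((_,p2), _)) =>
            p1 > p2 andalso is_ordered(rest)

val good = Sum((1.0,2), Sum((2.0,1), Sum((3.0,0), Zero)))   (* x^2 + 2x + 3 *)
val bad = Sum((2.0,1), Sum((1.0,2), Zero))                  (* 2x + x^2 *)
val repeated = Sum((1.0,1), Sum((4.0,1), Zero))             (* x + 4x *)

val ok = is_ordered(good)            (* true *)
val not_ok = is_ordered(bad)         (* false: 1 is not greater than 2 *)
val not_strict = is_ordered(repeated) (* false: powers must strictly decrease *)
```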

How can we add two polynomials using pattern matching?

fun add_poly(ts1:termsum, ts2:termsum) : termsum =
    case (ts1,ts2) of
        (Zero,_) => ts2
      | (_,Zero) => ts1
      | (Sum((c1,p1),xs1),Sum((c2,p2),xs2)) =>
        case Int.compare(p1,p2) of
            GREATER => Sum((c1,p1), add_poly(xs1,ts2))
          | LESS => Sum((c2,p2), add_poly(ts1,xs2))
          | EQUAL => Sum((c1+c2,p1), add_poly(xs1,xs2))

This definition does two things we haven't seen before. First, it actually constructs a tuple for the purpose of pattern-matching, which allows us to match on multiple values at the same time. Second, we use the library function Int.compare, which compares two integers and returns a value of type order. The SML Basis Library defines order to be

datatype order = LESS | EQUAL | GREATER

This is superior to the C convention of returning integers to represent order, since the compiler can type-check every use: a value of type order cannot be confused with an ordinary integer. (C uses integers for too many things, so it's easy for type errors to go unnoticed.)
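A short sketch of both ideas in action: Int.compare returning each of the three order values, and add_poly merging two polynomials that are ordered by decreasing power. The helpers powers and coefs and the values p, q, and s are hypothetical, added only to make the result easy to inspect.

```sml
type term = real * int
datatype termsum = Zero | Sum of (term * termsum)

fun add_poly(ts1:termsum, ts2:termsum) : termsum =
    case (ts1,ts2) of
        (Zero,_) => ts2
      | (_,Zero) => ts1
      | (Sum((c1,p1),xs1),Sum((c2,p2),xs2)) =>
        case Int.compare(p1,p2) of
            GREATER => Sum((c1,p1), add_poly(xs1,ts2))
          | LESS => Sum((c2,p2), add_poly(ts1,xs2))
          | EQUAL => Sum((c1+c2,p1), add_poly(xs1,xs2))

(* Hypothetical helpers: the powers and the coefficients, in order. *)
fun powers (ts:termsum) : int list =
    case ts of
        Zero => []
      | Sum((_,p),xs) => p :: powers(xs)
fun coefs (ts:termsum) : real list =
    case ts of
        Zero => []
      | Sum((c,_),xs) => c :: coefs(xs)

val less = Int.compare(3,5)      (* LESS *)
val equal = Int.compare(4,4)     (* EQUAL *)
val greater = Int.compare(5,3)   (* GREATER *)

val p = Sum((1.0,2), Sum((2.0,1), Sum((3.0,0), Zero)))   (* x^2 + 2x + 3 *)
val q = Sum((1.0,3), Sum((1.0,0), Zero))                 (* x^3 + 1 *)
val s = add_poly(p, q)                                   (* x^3 + x^2 + 2x + 4 *)
val s_powers = powers(s)                                 (* [3,2,1,0] *)
```

The constant terms share the power 0, so their coefficients are added; every other term is copied through in decreasing order of power.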

Finally, we can make use of add_poly to write a function to multiply polynomials:

fun mult_poly(ts1:termsum, ts2:termsum) : termsum =
    let fun mult_term((c1,p1):term, ts:termsum) =
            case ts of
                Zero => Zero
              | Sum((c2,p2),xs) => Sum((c1*c2,p1+p2),mult_term((c1,p1),xs))
    in case ts1 of
           Zero => Zero
         | Sum(t,xs1) => add_poly(mult_term(t,ts2), mult_poly(xs1,ts2))
    end

For example, we can try:

- print_poly(mult_poly(test_poly, test_poly));
1.0*x^4 + 4.0*x^3 + 10.0*x^2 + 12.0*x^1 + 9.0
val it = () : unit