Lecture 21: Monads


In the context of Async, we have been thinking of the >>= operator as being just like a let expression, except that it handles scheduling. That is, we think of

d >>= fun x ->
return e

as being analogous to

let x = d in
e

This is an instance of a general design pattern called a monad. You use a monad when you want to add some extra functionality to your let ... in expressions. A monad has the following module type:

module type Monad = sig
  (** a "wrapped up" value of type 'a *)
  type 'a t

  (** [m >>= f] unwraps m, passes the result to f, and returns the 'b t
      produced by f *)
  val (>>=) : 'a t -> ('a -> 'b t) -> 'b t

  (** wrap up a naked value *)
  val return : 'a -> 'a t
end

The Option monad

We might like to evaluate something like

(Some 3) + (Some 5)

and get Some 8; we would expect that if either of the arguments were None, the whole expression would evaluate to None. Unfortunately, as you know, the types don't work out: + expects two ints, not two int options. We can use a pattern-matching let to get almost the same effect:

let option_sum xo yo =
  let (Some x) = xo in
  let (Some y) = yo in
  Some (x + y)

but these let expressions contain inexhaustive pattern matches: if you call option_sum (Some 3) None, you will get a Match_failure exception instead of None. The correct implementation involves nested pattern matches, and quickly loses the clarity of the option_sum implementation above.
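
For comparison, here is one way the nested-match version might look. This is a sketch written for illustration, not code from the lecture; note how every option needs its own match, even though all the None cases do the same thing.

```ocaml
(* option_sum without a monad: each option is matched explicitly,
   and every None case has to be spelled out by hand *)
let option_sum xo yo =
  match xo with
  | None -> None
  | Some x ->
    (match yo with
     | None -> None
     | Some y -> Some (x + y))
```

With a third or fourth argument, the nesting grows by one level each time, which is exactly the boilerplate the monadic version hides.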

We can use monads to get the best of both worlds. As with the Deferred monad, we read xo >>= fun x -> as "let x = xo in", but with some extra details:

let option_sum xo yo =
  xo >>= fun x ->
  yo >>= fun y ->
  return (x + y)

This works exactly as expected if we define the option monad as follows:

module OptionMonad = struct
  type 'a t = 'a option

  let (>>=) xo f = match xo with
    | None -> None
    | Some x -> f x

  let return x = Some x
end
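
Putting the pieces together, the following is a small self-contained file that defines the option monad and the monadic option_sum, so the claim that it "works exactly as expected" can be checked directly. It is a sketch that simply opens the module to bring >>= and return into scope.

```ocaml
module OptionMonad = struct
  type 'a t = 'a option

  (* None short-circuits the rest of the computation;
     Some x passes x on to f *)
  let ( >>= ) xo f = match xo with
    | None -> None
    | Some x -> f x

  let return x = Some x
end

open OptionMonad

let option_sum xo yo =
  xo >>= fun x ->
  yo >>= fun y ->
  return (x + y)
```

Unlike the pattern-matching let version, option_sum (Some 3) None now evaluates to None instead of raising an exception.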

The List monad

In Python, you can write list comprehensions like

>>> [x*x for x in [1,2,3]]
[1, 4, 9]

which should be read as: create a list containing x*x for every x in the list [1,2,3]. You can iterate over multiple variables:

>>> [x+y for x in [0,1,2] for y in [10,20]]
[10, 20, 11, 21, 12, 22]

This expression considers every pair of x and y, and computes x + y for each.

We can think of this as

let x = an element of [0,1,2] in
let y = an element of [10,20] in
return all possible x + y

With this interpretation, we see we can encode the list comprehension as a monad expression:

let all_sums xs ys =
  xs >>= fun x ->
  ys >>= fun y ->
  return (x + y)

To make this work, we must define the list monad:

module ListMonad = struct
  type 'a t = 'a list

  let (>>=) l f =
    List.concat (List.map f l)

  let return x = [x]
end

The concatenation is required because f may return a list of possible values, and we want to keep every element of that list for each possible input. Mapping f over l gives a list of lists, and List.concat combines these into a single list.
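
As a check that the monadic all_sums really does what the Python comprehension does, here is a self-contained sketch combining the list monad with all_sums. Running it on the same inputs as the Python example should produce the same list, in the same order.

```ocaml
module ListMonad = struct
  type 'a t = 'a list

  (* apply f to every element, then flatten the resulting list of lists *)
  let ( >>= ) l f = List.concat (List.map f l)

  let return x = [x]
end

open ListMonad

(* every x + y, for x drawn from xs and y drawn from ys *)
let all_sums xs ys =
  xs >>= fun x ->
  ys >>= fun y ->
  return (x + y)

(* all_sums [0; 1; 2] [10; 20] = [10; 20; 11; 21; 12; 22] *)
```

The outer bind iterates over xs and the inner one over ys, matching the left-to-right order of the two for clauses in the Python version.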

The Arbitrary monad

QCheck uses the notion of an "arbitrary" value. We can think of an arbitrary value as an algorithm for randomly generating a value, given a random number generator. In fact this is how 'a Arbitrary.t is defined:

type 'a Arbitrary.t = Random.State.t -> 'a

Suppose we wanted to write a function that returns the total result after two 6-sided dice are rolled. The Arbitrary module gives us the value (1 -- 6) : int Arbitrary.t. To sum two dice rolls we might like to write the following function:

let sum_dice xa ya =
  (* [sample] is imaginary: there is no Random.State.t in scope to draw with *)
  let x = sample xa in
  let y = sample ya in
  x + y

sum_dice (1--6) (1--6) : int Arbitrary.t

As above, we can write this code monadically:

let sum_dice xa ya =
  xa >>= fun x ->
  ya >>= fun y ->
  return (x + y)

This requires us to define the Arbitrary monad operations:

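module Arbitrary = struct
  type 'a t = Random.State.t -> 'a

  (* run a with the generator, then run the arbitrary that f returns,
     using the same generator *)
  let (>>=) a f = fun generator ->
    let x = a generator in
    f x generator

  let return x = fun _generator -> x
end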
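
To try this without QCheck, the sketch below pairs the Arbitrary monad with a hand-rolled (--). That operator is an assumption standing in for QCheck's, built here on Random.State.int, which draws an integer in [0, bound). Seeding with Random.State.make makes a run reproducible.

```ocaml
module Arbitrary = struct
  type 'a t = Random.State.t -> 'a

  (* run a, then run the arbitrary that f builds from its result,
     threading the same generator through both *)
  let ( >>= ) a f = fun generator ->
    let x = a generator in
    f x generator

  let return x = fun _generator -> x

  (* hypothetical stand-in for QCheck's (--): uniform int in [lo, hi] *)
  let ( -- ) lo hi = fun generator ->
    lo + Random.State.int generator (hi - lo + 1)
end

open Arbitrary

let sum_dice xa ya =
  xa >>= fun x ->
  ya >>= fun y ->
  return (x + y)

let () =
  let rng = Random.State.make [| 42 |] in
  Printf.printf "two dice: %d\n" (sum_dice (1 -- 6) (1 -- 6) rng)
```

Nothing is random until the final int Arbitrary.t is applied to a generator; sum_dice itself only builds a recipe.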

Suppose that instead of just generating a random die roll, we wanted to compute the probability of each possible outcome. We could use exactly the same code, but run it in the Distribution monad:

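module Distribution = struct
  (** Abstraction function: the list [x1,p1; x2,p2; ...] represents the
      probability space with possible outcomes x1, x2, x3, ...; the
      probability of x1 is p1, the probability of x2 is p2, and so on. *)

  (** Rep invariant: the xi should be distinct, and the sum of the pi should
      be 1. *)

  type 'a t = ('a * float) list

  (* add probability p to outcome y, keeping the outcomes distinct *)
  let add acc (y, p) =
    if List.mem_assoc y acc then
      List.map (fun (y', q) -> if y' = y then (y', q +. p) else (y', q)) acc
    else acc @ [(y, p)]

  (** if ps is [x1, px1; x2, px2; ...]
       - for each i, apply f to xi, getting [y1, py1; y2, py2; ...]
       - multiply each pyj by pxi to generate the list [yj, pxi *. pyj]
       - combine the lists together
       - remove duplicate ys by summing their probabilities *)
  let (>>=) ps f =
    let scaled =
      List.concat (List.map (fun (x, px) ->
        List.map (fun (y, py) -> (y, px *. py)) (f x)) ps) in
    List.fold_left add [] scaled

  let return x = [(x, 1.)]

  (* uniform distribution over lo, lo+1, ..., hi ("to" is a keyword) *)
  let (--) lo hi =
    let count = hi - lo + 1 in
    let p = 1. /. float_of_int count in
    List.init count (fun i -> (lo + i, p))
end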

This suggests one could write the dice-rolling program as a functor over the monad:

module type MonadWithRange = sig
  include Monad

  val (--) : int -> int -> int t
end

module ListWithRange = struct
  include ListMonad
  let (--) lo hi = List.init (hi - lo + 1) (fun i -> lo + i)
end

module DiceRoll (M : MonadWithRange) = struct
  open M
  let result =
    (1 -- 6) >>= fun roll1 ->
    (1 -- 6) >>= fun roll2 ->
    return (roll1 + roll2)
end

(** using Arbitrary to sample; this assumes an Arbitrary that also provides (--) *)
module ArbitraryRoll = DiceRoll(Arbitrary);;
let rng = Random.State.make_self_init ();;
ArbitraryRoll.result rng;;
- : int = 7
ArbitraryRoll.result rng;;
- : int = 4
ArbitraryRoll.result rng;;
- : int = 10

(** using List to enumerate possibilities *)
module ListRoll = DiceRoll(ListWithRange);;
ListRoll.result;;
- : int list = [2; 3; 4; 5; 6; 7;
                3; 4; 5; 6; 7; 8;
                4; 5; 6; 7; 8; 9;
                5; 6; 7; 8; 9; 10;
                6; 7; 8; 9; 10; 11;
                7; 8; 9; 10; 11; 12]

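(** using Distribution to generate probabilities (shown truncated to two decimal places) *)
module DistributionRoll = DiceRoll(Distribution);;
DistributionRoll.result;;
- : (int * float) list = [(2,  0.02); (3,  0.05); (4,  0.08); (5,  0.11);
                          (6,  0.13); (7,  0.16); (8,  0.13); (9,  0.11);
                          (10, 0.08); (11, 0.05); (12, 0.02)]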
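
The whole functor program can be assembled into one self-contained file. The sketch below instantiates DiceRoll with the List and Distribution monads only (Arbitrary is left out so the results are deterministic). Ranges are built with List.init (OCaml 4.06 or later) in place of the "..." notation, and the lo/hi parameter names are our own, since to is an OCaml keyword.

```ocaml
module type Monad = sig
  type 'a t
  val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
  val return : 'a -> 'a t
end

module type MonadWithRange = sig
  include Monad
  val ( -- ) : int -> int -> int t
end

(* list monad: enumerate every outcome *)
module ListWithRange = struct
  type 'a t = 'a list
  let ( >>= ) l f = List.concat (List.map f l)
  let return x = [x]
  let ( -- ) lo hi = List.init (hi - lo + 1) (fun i -> lo + i)
end

(* distribution monad: distinct outcomes paired with probabilities *)
module Distribution = struct
  type 'a t = ('a * float) list

  (* add p to y's entry, or append (y, p) if y has not been seen yet *)
  let add acc (y, p) =
    if List.mem_assoc y acc then
      List.map (fun (y', q) -> if y' = y then (y', q +. p) else (y', q)) acc
    else acc @ [(y, p)]

  let ( >>= ) ps f =
    let scaled =
      List.concat (List.map (fun (x, px) ->
        List.map (fun (y, py) -> (y, px *. py)) (f x)) ps) in
    List.fold_left add [] scaled

  let return x = [(x, 1.)]

  let ( -- ) lo hi =
    let p = 1. /. float_of_int (hi - lo + 1) in
    List.init (hi - lo + 1) (fun i -> (lo + i, p))
end

(* the dice program, written once against any MonadWithRange *)
module DiceRoll (M : MonadWithRange) = struct
  open M
  let result =
    (1 -- 6) >>= fun roll1 ->
    (1 -- 6) >>= fun roll2 ->
    return (roll1 + roll2)
end

module ListRoll = DiceRoll (ListWithRange)
module DistRoll = DiceRoll (Distribution)
```

Because add keeps the first occurrence of each outcome in place, DistRoll.result lists the sums 2 through 12 in increasing order, with 7 carrying probability 6/36 (about 0.167, shown truncated as 0.16 above).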

Monad laws

If we want to interpret a monadic bind (>>=) as a let ... in expression, there are a few expectations that we might have. These are encapsulated in the "monad laws" (in the same way that we captured the number properties in the AddableSpec, MultipliableSpec, and so on, and we captured the list properties in ListLike.Spec in PS4).

Law 1 says that return just wraps something up and bind just unwraps it: we expect

return x >>= fun x -> e

to be the same as

(fun x -> e) x

More concisely:

return x >>= f    ====    f x

This is analogous to our expectation that let x = x in e is the same as e.

Note that ==== is not "="; it means that the programs should behave the same.

The second law corresponds to the idea that let x = e in x is the same as e. In the monadic version, we expect

m >>= fun x -> return x    ====    m
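
For the option monad, "behaves the same" can be approximated with structural equality, because option values can be compared with =. (For Arbitrary, whose values are functions, = cannot be used this way.) The sketch below spot-checks laws 1 and 2 on a few sample inputs. This is testing, not a proof, and the helper half is a hypothetical function chosen only because it can return either Some or None.

```ocaml
module OptionMonad = struct
  type 'a t = 'a option
  let ( >>= ) xo f = match xo with None -> None | Some x -> f x
  let return x = Some x
end

open OptionMonad

(* Law 1: return x >>= f  ====  f x *)
let law1 f x = (return x >>= f) = f x

(* Law 2: m >>= fun x -> return x  ====  m *)
let law2 m = (m >>= fun x -> return x) = m

(* hypothetical sample function: halve even numbers, fail on odd ones *)
let half n = if n mod 2 = 0 then Some (n / 2) else None
```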

The third law corresponds to the idea that we can factor out helper functions without changing the meaning of our code. For example, we'd like to think of the following two programs as the same:

(* version 1 *)
let both dx dy =
  dx >>= fun x ->
  dy >>= fun y ->
  return (x,y)
in
both (read "f.txt") (read "g.txt") >>= fun (f,g) ->
return (f^g)


(* version 2 *)
read "f.txt" >>= fun x ->
read "g.txt" >>= fun y ->
return (x,y) >>= fun (f,g) ->
return (f^g)

That is, we would like to be able to inline the body of both. However, if we look closely, the two programs are not syntactically the same. If we use the substitution model to evaluate the first version, we get something like

(* version 1 substituted *)
(read "f.txt" >>= fun x ->
 read "g.txt" >>= fun y ->
 return (x,y))
>>= fun (f,g) ->
return (f^g)

This looks similar to the second version, but the parentheses are in a different place. The third monad law states that these two versions should behave the same. To see the general form of this rule, we'll first factor out the common code from the two examples:

let f = fun x     -> read "g.txt" >>= fun y -> return (x,y)
let g = fun (f,g) -> return (f^g)
let m = read "f.txt"

(* version 1 substituted with f and g pulled out *)
(m >>= f) >>= g

(* version 2 with f and g pulled out *)
m   >>= fun x ->
f x >>= g

Law 3 says that these two should behave the same:

(m >>= f) >>= g    ====    m   >>= fun x ->
                           f x >>= g
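
As with the first two laws, law 3 can be spot-checked for a monad whose values support =. The sketch below does this for the list monad, where re-association matters because each bind may produce several results. The sample functions f and g are arbitrary illustrations, and passing on a few inputs is evidence, not a proof.

```ocaml
module ListMonad = struct
  type 'a t = 'a list
  let ( >>= ) l f = List.concat (List.map f l)
  let return x = [x]
end

open ListMonad

(* Law 3: (m >>= f) >>= g  ====  m >>= fun x -> f x >>= g *)
let law3 m f g = ((m >>= f) >>= g) = (m >>= fun x -> f x >>= g)

(* hypothetical sample functions: each input yields two outputs *)
let f x = [x; x + 10]
let g y = [y * 2; y * 3]
```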