11  Automatic Differentiation

11.1 Dual numbers

Computing with variations is not only useful for pen-and-paper computations; it is also the basis for automatic differentiation with so-called dual numbers. We describe some of the fundamentals of working with dual numbers below. A more full-featured implementation of the dual numbers is given in the ForwardDiff.jl package.

11.1.1 Scalar computations

A dual number is a pair \((x, \delta x)\) consisting of a value \(x\) and a variation \(\delta x\). As above, we can think of this as a function \(\tilde{x}(\epsilon) = x + \epsilon \, \delta x + o(\epsilon)\). For ordinary scalar numbers, we represent this with a Julia structure.

struct Dual{T<:Number} <: Number
    val :: T  # Value
    δ :: T    # Variation
end

value(x :: Dual) = x.val
variation(x :: Dual) = x.δ

We want to allow constants, which are dual numbers with a zero variation. We can construct these directly or convert them from other numbers.

Dual{T}(x :: T) where {T} = Dual{T}(x, zero(T))
Base.convert(::Type{Dual{T}}, x :: S) where {T,S<:Number} =
    Dual{T}(x, zero(T))

We also want to be able to convert between types of Julia dual numbers, and we want promotion rules for doing arithmetic that involves dual numbers together with other types of numbers.

Base.convert(::Type{Dual{T}}, x :: Dual{S}) where {T,S} =
    Dual{T}(x.val, x.δ)
Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{S}}) where {T,S} =
    Dual{promote_type(T,S)}
Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T,S<:Number} =
    Dual{promote_type(T,S)}

We would like to define a variety of standard operations (unary and binary) for dual numbers. Rather than writing the same boilerplate code for every function, we write macros @dual_unary and @dual_binary that take the name of an operation and the formula for the derivative in terms of \(x, \delta x, y, \delta y\), where \(x\) and \(y\) are the first and second arguments, respectively.

macro dual_unary(op, formula)
    :(function $op(x :: Dual)
          x, δx = value(x), variation(x)
          Dual($op(x), $formula)
      end)
end

macro dual_binary(op, formula)
    :(function $op(x :: Dual, y :: Dual)
          x, δx = value(x), variation(x)
          y, δy = value(y), variation(y)
          Dual($op(x, y), $formula)
      end)
end

We overload the +, -, and * operators to work with dual numbers, using the usual rules of differentiation. We also overload both the left and right division operations and the exponentiation operation.

@dual_binary(Base.:+, δx + δy)
@dual_binary(Base.:-, δx - δy)
@dual_unary(Base.:-, -δx)
@dual_binary(Base.:*, δx*y + x*δy)
@dual_binary(Base.:/, (δx*y - x*δy)/y^2)
@dual_binary(Base.:\, (δy*x - y*δx)/x^2)
@dual_binary(Base.:^, x^y*(y*δx/x + log(x)*δy))

We provide a second version of the power function for when the exponent is a constant integer (as is often the case); this allows us to deal with negative values of \(x\) gracefully.

Base.:^(x :: Dual, n :: Integer) = Dual(x.val^n, n*x.val^(n-1)*x.δ)
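
To see why a separate integer-exponent method helps, note that the general power rule evaluates log(x), which throws an error for negative real x, while the integer rule does not. The following is a small sketch in plain floating point (no dual numbers involved); the variable names d_int and log_fails are chosen only for this illustration.

```julia
x, δx, n = -2.0, 1.0, 3

# Integer-exponent rule: d(x^n) = n*x^(n-1)*δx is fine for x < 0
d_int = n*x^(n-1)*δx

# The general rule needs log(x), which throws a DomainError for x < 0
log_fails = try
    log(x)
    false
catch err
    err isa DomainError
end
```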

For comparisons, we will only consider the value and ignore the variation.

Base.:(==)(x :: Dual, y :: Dual)  = x.val == y.val
Base.isless(x :: Dual, y :: Dual) = x.val < y.val

For convenience, we also write a handful of the standard functions that one learns to differentiate in a first calculus course.

@dual_unary(Base.abs,  δx*sign(x))
@dual_unary(Base.sqrt, δx/(2*sqrt(x)))
@dual_unary(Base.exp,  exp(x)*δx)
@dual_unary(Base.log,  δx/x) 
@dual_unary(Base.sin,  cos(x)*δx)
@dual_unary(Base.cos, -sin(x)*δx) 
@dual_unary(Base.asin, δx/sqrt(1-x^2)) 
@dual_unary(Base.acos,-δx/sqrt(1-x^2)) 

With these definitions in place, we can automatically differentiate through a variety of functions without writing any special code. For example, the following code differentiates the Haaland approximation to the Darcy-Weisbach friction factor in pipe flow and compares to a finite difference approximation:

let
    fhaaland(ϵ, D, Re) = 1.0/(-1.8*log( (ϵ/D/3.7)^1.11 + 6.9/Re ))^2
    δy = fhaaland(0.01, 1.0, Dual(3000, 1)).δ
    δy_fd = finite_diff(Re->fhaaland(0.01, 1.0, Re), 3000, h=1e-3)
    isapprox(δy, δy_fd, rtol=1e-6)
end
true
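
The helper finite_diff used above is not defined in this section. A minimal sketch of what such a helper might look like, assuming a centered difference with a keyword step h, is given below; the name finite_diff_sketch is hypothetical and is used to avoid suggesting that this is the definition used elsewhere in the text.

```julia
# Hypothetical stand-in for a finite difference helper:
# centered difference approximation to f'(x) with step h.
finite_diff_sketch(f, x; h=1e-6) = (f(x+h) - f(x-h))/(2h)

# Example: the derivative of sin at 0.3 should be close to cos(0.3)
dsin_fd = finite_diff_sketch(sin, 0.3, h=1e-3)
isapprox(dsin_fd, cos(0.3), rtol=1e-6)
```

The centered difference has error proportional to h^2 for smooth functions, which is why a fairly large step still agrees with the exact derivative to about six digits.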

11.1.2 Matrix computations

We can also use dual numbers inside of some of the linear algebra routines provided by Julia. For example, consider \(x = A^{-1} b\) where \(b\) is treated as constant; the following code automatically computes both \(A^{-1} b\) and \(-A^{-1} (\delta A) A^{-1} b\).

function test_dual_solve()
    Aval = [1.0 2.0; 3.0 4.0]
    δA   = [1.0 0.0; 0.0 0.0]
    b = [3.0; 4.0]
    A = Dual.(Aval, δA)
    x = A\b

    # Compare ordinary and variational parts to formulas
    value.(x) ≈ Aval\b &&
    variation.(x) ≈ -Aval\(δA*(Aval\b))
end

test_dual_solve()
true

While this type of automatic differentiation through a matrix operation works, it is relatively inefficient. For relatively expensive operations like matrix multiplication and linear solves, we can instead define operations that act at the matrix level. As an example, the following Julia code gives more efficient linear solves with matrices of dual numbers, using LU factorization (Chapter 16) as the basis for applying \(A^{-1}\) to a matrix or vector; the lu function comes from the LinearAlgebra standard library:

function Base.:\(AA :: AbstractMatrix{Dual{T}},
                 BB :: AbstractVecOrMat{Dual{S}}) where {T,S}
    A, δA = value.(AA), variation.(AA)
    B, δB = value.(BB), variation.(BB)
    F = lu(A)
    X = F\B
    Dual.(X, F\(δB-δA*X))
end

function Base.:\(AA :: AbstractMatrix{Dual{T}},
                 B :: AbstractVecOrMat{S}) where {T,S}
    A, δA = value.(AA), variation.(AA)
    F = lu(A)
    X = F\B
    Dual.(X, F\(-δA*X))
end

function Base.:\(A :: AbstractMatrix{T},
                 BB :: AbstractVecOrMat{Dual{S}}) where {T,S}
    B, δB = value.(BB), variation.(BB)
    F = lu(A)
    Dual.(F\B, F\δB)
end

test_dual_solve()
true
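
The formula behind these methods, \(\delta x = A^{-1}(\delta b - \delta A \, x)\), can be checked independently of any dual number machinery. The sketch below uses ordinary floating point matrices and a centered finite difference; the particular test data (including δb) and the step h are assumptions made for illustration.

```julia
using LinearAlgebra

A  = [1.0 2.0; 3.0 4.0]
δA = [1.0 0.0; 0.0 0.0]
b  = [3.0; 4.0]
δb = [0.5; -1.0]

# Variation of x = A\b via one factorization of A
F  = lu(A)
x  = F\b
δx = F\(δb - δA*x)

# Centered finite difference in the direction (δA, δb)
h  = 1e-6
xp = (A + h*δA)\(b + h*δb)
xm = (A - h*δA)\(b - h*δb)
δx_fd = (xp - xm)/(2h)

isapprox(δx, δx_fd, rtol=1e-6)
```

The factorization F is reused for both the value and the variation, which is the source of the efficiency gain over pushing dual numbers through a generic elimination routine.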

11.1.3 Special cases

There are other cases as well where automatic differentiation using dual numbers needs a little help. For example, consider the thin-plate spline function, which has a removable singularity at zero:

ϕtps(r) = r == 0.0 ? 0.0 : r^2*log(abs(r))

If we treat \(r\) as a dual number, the output for \(r = 0\) will be an ordinary floating point number, while the output for every other value of \(r\) will be a dual number. However, we can deal with this by writing a specialized version of the function for dual numbers.

ϕtps(r :: Dual) = r == 0.0 ? Dual(0.0, 0.0) : r^2*log(abs(r))

With this version, the function works correctly at both zero and nonzero dual number arguments:

ϕtps(Dual(0.0, 1.0)), ϕtps(Dual(1.0, 1.0))
(Dual{Float64}(0.0, 0.0), Dual{Float64}(0.0, 1.0))
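
The zero variation returned at \(r = 0\) reflects the fact that the derivative \(2r \log|r| + r\) tends to zero as \(r \rightarrow 0\). A quick floating point check of this claim, with hypothetical names ϕ_sketch and dϕ_sketch that stand apart from the definitions above, might look like:

```julia
# Thin-plate spline and its derivative away from zero
ϕ_sketch(r)  = r == 0.0 ? 0.0 : r^2*log(abs(r))
dϕ_sketch(r) = r == 0.0 ? 0.0 : 2r*log(abs(r)) + r

# Difference quotients at zero shrink as h decreases
quotients = [ϕ_sketch(h)/h for h in (1e-2, 1e-4, 1e-8)]

# Away from zero, compare with a centered difference
h = 1e-6
fd = (ϕ_sketch(1.0+h) - ϕ_sketch(1.0-h))/(2h)
isapprox(dϕ_sketch(1.0), fd, rtol=1e-6)
```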

In addition to difficulties with removable singularities, automatic differentiation systems may lose accuracy due to floating point effects even for functions that are well-behaved. We return to this in Chapter 9.

11.2 Forward and backward

In the previous section, we described automatic differentiation by tracking variations (dual numbers) through a computation, often known as forward mode differentiation. An alternate approach known as backward mode (or adjoint mode) differentiation involves tracking a different set of variables (dual variables).

11.2.1 An example

We consider again the sample function given in the previous section: \[ f(x_1, x_2) = \left( -1.8 \log \left( (x_1/3.7)^{1.11} + 6.9/x_2 \right) \right)^{-2}. \] We would usually write this concisely as

fhaaland0(x1, x2) = (-1.8*log( (x1/3.7)^1.11 + 6.9/x2 ))^-2

When we compute or manipulate such somewhat messy expressions (e.g. for differentiation), it is useful to split them into simpler subexpressions (as in the static single assignment (SSA) style introduced in Chapter 2). For example, we can rewrite \(f(x_1, x_2)\) in terms of seven intermediate expressions, and compute variations of each intermediate to get a variation of the final result. In code, this is

function fhaaland1(x1, x2)
    y1 = x1/3.7
    y2 = y1^1.11
    y3 = 6.9/x2
    y4 = y2+y3
    y5 = log(y4)
    y6 = -1.8*y5
    y7 = y6^-2
    function Df(δx1, δx2)
        δy1 = δx1/3.7
        δy2 = 1.11*y1^0.11 * δy1
        δy3 = -6.9/x2^2 * δx2
        δy4 = δy2+δy3
        δy5 = δy4/y4
        δy6 = -1.8*δy5
        δy7 = -2*y6^(-3)*δy6
    end
    y7, Df
end

The function Df inside fhaaland1 computes derivatives similarly to using dual numbers as in the previous section.

Another way to think of this computation is that we have solved for the intermediate \(y\) variables as a function of \(x\) from the relationship \[ G(x, y) = h(x,y) - y = 0, \] where the rows of \(h\) are the seven equations above, e.g. \(h_1(x,y) = x_1/3.7\). Differentiating this relation gives us \[ D_1 h(x,y) \, \delta x + \left( D_2 h(x,y) - I \right) \, \delta y = 0. \] The formula for the variations in the \(y\) variables can be thought of as coming from using forward substitution to solve the linear system \[ (D_2 h(x,y) -I) \delta y = -D_1 h(x,y) \delta x. \] That is, we compute the components of \(\delta y\) in order by observing that \(h_i(x,y)\) depends only on \(y_1, \ldots, y_{i-1}\) and writing \[ \delta y_i = D_1 h_i(x,y) \, \delta x + \sum_{j=1}^{i-1} \left( D_2 h_i(x,y) \right)_j \, \delta y_j. \]

A key observation is that we do not then use all of \(\delta y\); we only care about \[ f'(x) \delta x = w^* \delta y, \] where \(w^*\) is a functional that selects the desired output (here \(w^* = e_7^*\)). Putting the steps of the previous computation together, we have \[ f'(x) \, \delta x = w^* \left[ \left( D_2 h(x,y) - I \right)^{-1} \left( -D_1 h(x,y) \, \delta x \right) \right]. \] By associativity, we could also write this as \[ f'(x) \, \delta x = \left[ \left( -w^* (D_2 h(x,y) - I)^{-1} \right) D_1 h(x,y) \right] \, \delta x. \] Giving names to the parenthesized pieces of this latter equation, we have \[\begin{aligned} \bar{y}^* &= -w^* (D_2 h(x,y) - I)^{-1} \\ f'(x) &= \bar{y}^* D_1 h(x,y). \end{aligned}\] Where we solved the system for the variations \(\delta y\) by forward substitution, we solve the system for the dual variables \(\bar{y}^*\) by backward substitution, applying the formula \[ \bar{y}_j^* = w_j^* + \sum_{i=j+1}^n \bar{y}_i^* (D_2 h_i(x,y))_j \] for \(j = n, n-1, \ldots, 1\), where \(n\) is the number of intermediate variables (here \(n = 7\)). Because this computation can be written in terms of a solve with the adjoint matrix \((D_2 h(x,y)-I)^*\) using backward substitution, this method of computing derivatives is sometimes called adjoint mode or backward mode differentiation.
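
The two orderings of the computation can be compared numerically by forming the matrices explicitly for the example function. The following is a sketch for illustration only: the names D1h, D2h, w, and ybar are assumptions introduced here, and a practical implementation would use substitution on the sparse triangular structure rather than dense solves.

```julia
using LinearAlgebra

# Evaluate the seven intermediates at a sample point
x1, x2 = 0.01, 3000.0
y1 = x1/3.7;  y2 = y1^1.11;  y3 = 6.9/x2;  y4 = y2+y3
y5 = log(y4); y6 = -1.8*y5;  y7 = y6^-2

# D1h: partial derivatives of each h_i with respect to (x1, x2)
D1h = zeros(7, 2)
D1h[1,1] = 1/3.7
D1h[3,2] = -6.9/x2^2

# D2h: partial derivatives with respect to y (strictly lower triangular)
D2h = zeros(7, 7)
D2h[2,1] = 1.11*y1^0.11
D2h[4,2] = 1.0
D2h[4,3] = 1.0
D2h[5,4] = 1/y4
D2h[6,5] = -1.8
D2h[7,6] = -2*y6^-3

# Selector for the output y7
w = zeros(7); w[7] = 1.0

# Forward mode: one solve per input direction
Df_forward = [w' * ((D2h - I) \ (-D1h[:,k])) for k in 1:2]

# Backward mode: a single solve with the adjoint matrix
ybar = -((D2h - I)' \ w)
Df_backward = D1h' * ybar

isapprox(Df_forward, Df_backward, rtol=1e-10)
```

Forward mode needs one solve for each of the two inputs, while backward mode needs one solve for the single output; this is the trade-off that matters when the number of inputs is large.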

Translating these ideas into code for our example function, we have

function fhaaland2(x1, x2)
    y1 = x1/3.7
    y2 = y1^1.11
    y3 = 6.9/x2
    y4 = y2+y3
    y5 = log(y4)
    y6 = -1.8*y5
    y7 = y6^-2
    
    ȳ7 = 1
    ȳ6 = ȳ7 * (-2*y6^-3)
    ȳ5 = ȳ6 * (-1.8)
    ȳ4 = ȳ5/y4
    ȳ3 = ȳ4
    ȳ2 = ȳ4
    ȳ1 = ȳ2 * (1.11*y1^0.11)

    Df = [ȳ1/3.7  ȳ3*(-6.9/x2^2)]
    y7, Df
end

We now have three different methods for computing the derivative of \(f\), which we illustrate below (as well as checking that our computations agree with each other).

let
    # Compute f and the derivative using dual numbers
    f0a = fhaaland0(Dual(0.01, 1.0), 3000)
    f0b = fhaaland0(0.01, Dual(3000, 1))
    Df0 = [variation(f0a) variation(f0b)]

    # Compute using the SSA-based forward mode code
    f1, Df1fun = fhaaland1(0.01, 3000)
    Df1 = [Df1fun(1,0) Df1fun(0,1)]

    # Compute using the SSA-based backward mode code
    f2, Df2 = fhaaland2(0.01, 3000)

    # Compare to see that all give the same results
    f1 == f2 && f1 ≈ value(f0a) && f1 ≈ value(f0b) &&
    Df1[1] ≈ Df2[1] && Df0[1] ≈ Df2[1] &&
    Df1[2] ≈ Df2[2] && Df0[2] ≈ Df2[2]
end
true

We note that we only need to compute the \(\bar{y}\) variables once to get the derivative vector, whereas for forward mode we need to compute the \(\delta y\) variables twice.

In forward mode, the cost of computing the derivative in one direction \((\delta x_1, \delta x_2)\) is only about as great as the cost of evaluating \(f\). But if we wanted the derivative with respect to both \(x_1\) and \(x_2\), we would need to run at least the code in Df twice (and with the dual number implementation from the previous section, we would also run the evaluation of \(f\) twice).

11.3 A derivative example

Beyond writing language utilities, Julia macros are useful for writing embedded domain-specific languages (DSLs) for accomplishing particular tasks. In this setting, we are really writing a language interpreter embedded in Julia, using the Julia parse tree as an input and producing Julia code as output.

One example has to do with automatic differentiation of simple code expressions, where we are giving another interpretation to simple Julia expressions. This is a more ambitious example, and can be safely skipped over. However, it is also a useful example of a variety of features in Julia. Some types of automatic differentiation can be done in arguably simpler and more powerful ways without writing macros, and we will return to this in later chapters.

11.3.1 Normalization

We define a “simple” expression as one that involves only:

  • Literal nodes (including symbols),
  • Expr nodes of call type, including arithmetic and function calls,
  • Assignment statements,
  • Tuple constructors,
  • begin/end blocks.

We can test for simple expressions by recursing through an expression tree:

# Simple expression nodes
is_simple(e :: Expr) =
    e.head in [:call, :(=), :tuple, :block] &&
    all(is_simple.(e.args))

# Other nodes are literal, always simple
is_simple(e) = true
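
For readers less familiar with Julia's expression trees, it may help to look at the head and args fields that these tests inspect. The following sketch uses only built-in quoting; the expression itself and the names e and rhs are arbitrary illustrations.

```julia
e = :(y = f(x) + 1)

e.head        # the node type, here an assignment
e.args[1]     # left hand side symbol
rhs = e.args[2]

rhs.head      # a call node
rhs.args      # operator first, then the operands
```

Arithmetic such as f(x) + 1 is represented as a call node whose first argument is the operator symbol, which is why both arithmetic and function calls fall under the call type above.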

In addition to insisting that expressions are simple, we also want to simplify some peculiarities in Julia’s parsing of addition and multiplication. In many languages, an expression like 1 + 2 + 3 is parsed as two operations: (1 + 2) + 3. In Julia, this results in a single node. It will simplify our life to convert these types of nodes to binary nodes, assuming left associativity of the addition operation. We do this by recursively rewriting the expression tree to replace a particular type of operation node (a call to op) with a left-associated version of that same node:

leftassoc(op :: Symbol, e) = e
function leftassoc(op :: Symbol, e :: Expr)
    args = leftassoc.(op, e.args)
    if e.head == :call && args[1] == op
        foldl((x,y) -> Expr(:call, op, x, y), args[2:end])
    else
        Expr(e.head, args...)
    end
end

We are particularly interested in left associativity of addition and multiplication, so we write a single function for those operations:

leftassoc(e) = leftassoc(:+, leftassoc(:*, e))

We can see the effect by printing the parse of a particular example:

let
    e = :(a*b*c + 2 + 3)
    println("Before: $e")
    println("After:  $(leftassoc(e))")
end
Before: a * b * c + 2 + 3
After:  ((a * b) * c + 2) + 3

Finally, the Julia parser adds LineNumberNodes to blocks in order to aid with debugging. We will dispose of those here so that we can focus solely on processing expressions.

filter_line(e) = e
filter_line(e :: Expr) =
    Expr(e.head,
         filter(x->!(x isa LineNumberNode),
                filter_line.(e.args))...)

Putting everything together, we will normalize simple expressions by eliminating line numbers and re-associating addition and multiplication.

normalize_expr(e) = e
function normalize_expr(e :: Expr)
    @assert is_simple(e) "Expression must be simple"
    leftassoc(filter_line(e))
end

11.3.2 SSA

Simple expressions of the type that we have described can be converted to static single assignment (SSA) form, where each intermediate subexpression is assigned a unique local name (which we produce with a call to gensym). We represent a program in SSA form as a vector of (symbol, expression) pairs to be interpreted as “evaluate the expression and assign it to the symbol.” We emit SSA terms by appending them to the end of a list.

function ssa_generator()
    result = []

    # Add a single term to the SSA result
    function emit!(s, e)
        push!(result, (s,e))
        s
    end

    # Add several terms to the SSA result;
    #  s is the designated "final result"
    function emit!(s, l :: Vector)
        append!(result, l)
        s
    end

    # Emit a term with a new local name
    emit!(e) = emit!(gensym(), e)

    result, emit!
end

The to_ssa function returns a final result in SSA form. Each symbol is only assigned once. Each arithmetic, function call, and tuple constructor results in a new symbol. For assignments, we record a binding of the left hand side symbol to the result computed for the right hand side.

function to_ssa(e :: Expr)
    
    # Check and normalize the expression
    e = normalize_expr(e)

    # Set up results list and symbol table
    elist, emit! = ssa_generator()
    symbol_table = Dict{Symbol,Any}()

    # Functions for recursive processing
    process(e) = e
    process(e :: Symbol) = get(symbol_table, e, e)
    function process(e :: Expr)
        args = process.(e.args)
        if e.head == :block
            args[end]
        elseif e.head == :(=)
            symbol_table[e.args[1]] = args[2]
            args[2]
        else
            emit!(Expr(e.head, args...))
        end
    end

    process(e), elist
end

In case we want to convert a single leaf to SSA form, we define a second method:

function to_ssa(e)
    s = gensym()
    s, [(s,e)]
end

We also want a utility to convert from SSA back to ordinary Julia code:

from_ssa(sym, elist) =
    Expr(:block,
         [Expr(:(=), s, e) for (s,e) in elist]...,
         sym)
from_ssa (generic function with 1 method)

As is often the case, an example is useful to illustrate the code.

let
    s, elist = to_ssa(
        quote
            x = 10
            x = 10+x+1
            y = x*x
        end)
    from_ssa(s, elist)
end
quote
    var"##230" = 10 + 10
    var"##231" = var"##230" + 1
    var"##232" = var"##231" * var"##231"
    var"##232"
end
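
The (symbol, expression) representation can also be built by hand, which may make the data structure easier to picture. The sketch below uses readable names t1, t2, t3 in place of generated symbols (an assumption for illustration), assembles a block in the same way as the conversion utility, and evaluates it. Note that calling eval at top level defines t1, t2, and t3 as global variables.

```julia
# A hand-written SSA list mirroring the example output above
elist_sketch = [(:t1, :(10 + 10)),
                (:t2, :(t1 + 1)),
                (:t3, :(t2 * t2))]

# Assemble a block of assignments ending in the result symbol
code = Expr(:block,
            [Expr(:(=), s, e) for (s, e) in elist_sketch]...,
            :t3)

result = eval(code)
```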

11.3.3 Derivative function

To set up our differentiation function, we will need some differentiation rules. We store these in a dictionary where the keys are the names of operations and functions, and the values are the corresponding rules. Each rule takes expressions for the value of the function, the arguments, and the derivatives of the arguments, and then returns an expression for the derivative of the function. In some cases, we may want different methods associated with the same function, e.g. to distinguish between negation and subtraction or to provide a special case of the power function where the exponent is an integer constant.

deriv_minus(f,x,dx) = :(-$dx)
deriv_minus(f,x,y,dx,dy) = :($dx-$dy)
deriv_pow(f,x,n :: Int, dx, dn) = :($n*$x^$(n-1)*$dx)
deriv_pow(f,x,y,dx,dy) = :($f*($y/$x*$dx + $dy*log($x)))

deriv_rules = Dict(
    :+   => (f,x,y,dx,dy) -> :($dx+$dy),
    :*   => (f,x,y,dx,dy) -> :($dx*$y + $x*$dy),
    :/   => (f,x,y,dx,dy) -> :(($dx - $f*$dy)/$y),
    :-   => deriv_minus,
    :^   => deriv_pow,
    :log => (f,x,dx)    -> :($dx/$x),
    :exp => (f,x,dx)    -> :($f*$dx)
)
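
Rules written in terms of the function value f are easy to get subtly wrong, so a numerical spot check is worthwhile. The sketch below evaluates the quotient rule \((\delta x - f \, \delta y)/y\) and the general power rule \(f \cdot (y \, \delta x / x + \log(x) \, \delta y)\) as ordinary floating point formulas and compares against centered differences; the test point, the step h, and the helper fd are assumptions chosen for illustration.

```julia
x, y, dx, dy = 1.7, 0.6, 0.3, -0.2
h = 1e-6

# Centered directional difference of a binary operation
fd(op) = (op(x + h*dx, y + h*dy) - op(x - h*dx, y - h*dy))/(2h)

# Quotient rule in terms of f = x/y
f_div = x/y
d_div = (dx - f_div*dy)/y

# General power rule in terms of f = x^y
f_pow = x^y
d_pow = f_pow*(y/x*dx + dy*log(x))

isapprox(d_div, fd(/), rtol=1e-6) && isapprox(d_pow, fd(^), rtol=1e-6)
```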

Now we are in a position to write code to simultaneously evaluate the expression and the derivative (with respect to some symbol s). To do this, it is helpful to give a (locally defined) name to each subexpression and its derivative, and to produce code that is a sequence of assignments to those names. We accumulate those assignments into a list (elist). An internal process function with different methods for different types of objects is used to generate the assignments for the subexpression values and the associated derivatives.

function derivative(s :: Symbol, e :: Expr)

    # Convert input to SSA form, set up generator
    sym, elist = to_ssa(e)
    eresult, emit! = ssa_generator()
    
    # Derivatives of leaves, init with ds/ds = 1.
    deriv_table = Dict{Symbol,Any}()
    deriv_table[s] = 1
    deriv(e :: Symbol) = get(deriv_table, e, 0)
    deriv(e) = 0

    # Rules to generate differentiation code
    function derivcode(s, e :: Expr)
        if e.head == :call
            rule = deriv_rules[e.args[1]]
            dexpr = rule(s, e.args[2:end]...,
                         deriv.(e.args[2:end])...)
            emit!(to_ssa(dexpr)...)
        elseif e.head == :tuple
            emit!(Expr(:tuple, deriv.(e.args)...))
        else
            error("Unexpected expression of type $(e.head)")
        end
    end
    
    derivcode(s, e) = emit!(gensym(), deriv(e))
    
    # Produce code for results and derivatives
    for (s,e) in elist
        emit!(s,e)
        deriv_table[s] = derivcode(s, e)
    end

    # Add a tuple for return at the end (function + deriv)
    emit!(Expr(:tuple, sym, deriv_table[sym])), eresult
end

As an example, consider computing the derivative of \(mx + b\) with respect to \(x\):

from_ssa(derivative(:x, :(m*x+b))...)
quote
    var"##233" = m * x
    var"##235" = 0 * x
    var"##236" = m * 1
    var"##237" = var"##235" + var"##236"
    var"##234" = var"##233" + b
    var"##238" = var"##237" + 0
    var"##239" = (var"##234", var"##238")
    var"##239"
end

Giving more comprehensible variable names, this is equivalent to

y1 = m * x
dy1a = 0 * x
dy1b = m * 1
dy1 = dy1a + dy1b
y2 = y1 + b
dy2 = dy1 + 0
result = (y2, dy2)
result

This is correct, but it is also a very complicated-looking way to compute \(m\)! We therefore consider putting in some standard simplifications.

11.3.4 Simplification

There are many possible algebraic relationships that we can use to simplify derivatives. Some of these relationships, like multiplication by zero, are things that we believe are safe but that the compiler cannot safely do on our behalf¹. Others are things that the compiler could probably do on our behalf, but that we might want to do on our own to understand how these things work.

We will again use the method-matching rules of multiple dispatch to write our simplification rules as different methods under a common name. For example, the simplification rules for \(x + y\) (where the prior expression is saved as \(e\)) are:

# Rules to simplify e = x+y
simplify_add(e, x :: Number, y :: Number) = x+y
simplify_add(e, x :: Number, y :: Symbol) = x == 0 ? y : e
simplify_add(e, x :: Symbol, y :: Number) = y == 0 ? x : e
simplify_add(e, x, y) = e

The rules for minus (unary and binary) are similar in flavor. We will give the unary and binary versions both the same name, and distinguish between the two cases based on the number of arguments that we see.

# Rules to simplify e = -x
simplify_sub(e, x :: Number) = -x
simplify_sub(e, x) = e

# Rules to simplify e = x-y
simplify_sub(e, x :: Number, y :: Symbol) = x == 0 ? :(-$y) : e
simplify_sub(e, x :: Symbol, y :: Number) = y == 0 ? x : e
simplify_sub(e, x, y) = e

With multiplication, we will use the ordinary observation that anything times zero should be zero, ignoring the peculiarities of floating point.

# Rules to simplify e = x*y
simplify_mul(e, x :: Number, y :: Number) = x*y
simplify_mul(e, x :: Number, y :: Symbol) =
    if     x == 0  0
    elseif x == 1  y
    else           e
    end
simplify_mul(e, x :: Symbol, y :: Number) =
    if     y == 0  0
    elseif y == 1  x
    else           e
    end
simplify_mul(e, x, y) = e
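
The floating point peculiarities being ignored here are easy to exhibit. In the following sketch (plain Julia arithmetic, with variable names chosen only for illustration), multiplying zero by a non-finite value does not give zero, which is why this rewrite is something we opt into rather than something a compiler may assume.

```julia
z_inf = 0 * Inf       # not zero
z_nan = 0 * NaN       # not zero
z_big = 0 * 1e308     # zero for any finite value

(isnan(z_inf), isnan(z_nan), z_big == 0)
```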

Finally, we define rules for simplifying quotients and powers.

# Rules to simplify e = x/y
simplify_div(e, x :: Number, y :: Number) = x/y
simplify_div(e, x :: Symbol, y :: Number) = y == 1 ? x : e
simplify_div(e, x :: Number, y :: Symbol) = x == 0 ? 0 : e
simplify_div(e, x, y) = e

# Simplify powers
simplify_pow(e, x :: Symbol, n :: Number) =
    if     n == 1  x
    elseif n == 0  1
    else           e
    end
simplify_pow(e, x, n) = e

To simplify a line in an SSA assignment, we look up the appropriate rule function in a table (as we did with differentiation rules) and apply it.

simplify_rules = Dict(
    :+ => simplify_add,
    :- => simplify_sub,
    :* => simplify_mul,
    :/ => simplify_div,
    :^ => simplify_pow)

simplify_null(e, args...) = e
simplify(e) = e
simplify(e :: Expr) =
    if e.head == :call
        rule = get(simplify_rules, e.args[1], simplify_null)
        rule(e, e.args[2:end]...)
    else
        e
    end
simplify (generic function with 2 methods)

Putting everything together, our rule for simplifying code in SSA form is:

  • Look up whether a previous simplification replaced any arguments with constants or with other symbols. If a replacement was also replaced earlier, apply the rule recursively.
  • Apply the appropriate simplification rule.
  • If the expression now looks like (symbol, x) where x is a leaf (a constant or another symbol), make a record in the lookup table to replace future references to symbol with references to x.

The following code carries out this task.

function simplify_ssa(sym, elist)

    # Set up and add to replacement table
    replacements = Dict{Symbol,Any}()
    function replacement(s,e)
        replacements[s] = e
    end

    # Apply a replacement policy
    replace(e) = e
    replace(s :: Symbol) =
        if s in keys(replacements)
            replace(replacements[s])
        else
            s
        end

    # Simplify each term
    eresult, emit! = ssa_generator()
    for (s,e) in elist
        e = e isa Expr ? Expr(e.head, replace.(e.args)...) : e
        enew = simplify(e)
        if enew isa Number || enew isa Symbol
            replacement(s, enew)
        end
        emit!(s, enew)
    end
    replace(sym), eresult
end

Finally, we do dead code elimination, keeping only “live” computations on which the final result depends. Liveness is computed recursively: the final result is live, and anything that a live variable depends on is live. Because in SSA each expression depends only on previous results, we can compute liveness by traversing the list from back to front and updating what we know to be live as we go.

function prune_ssa(sym, elist)

    # The end symbol is live
    live = Set([sym])

    # Propagate liveness
    mark_live(s :: Symbol) = push!(live, s)
    mark_live(e :: Expr) = mark_live.(e.args)
    mark_live(e) = nothing
    
    # If the RHS is live, so is anything it depends on
    for (s,e) in reverse(elist)
        if s in live
            mark_live(e)
        end
    end

    # Return only the live expressions
    sym, filter!(p -> p[1] in live, elist)
end

Repeating our experiment with differentiating \(mx + b\) but including simplification and pruning, we get much terser generated code.

from_ssa(
    prune_ssa(
        simplify_ssa(
            derivative(:x, :(m*x+b))...)...)...)
quote
    var"##240" = m * x
    var"##241" = var"##240" + b
    var"##246" = (var"##241", m)
    var"##246"
end

11.3.5 The macro

Finally, we package the derivatives into a macro. The macro gets the derivative information from the derivative function and packages it into a block that at the end returns the final expression value and derivative as a tuple. We note that the namespace for macros is different from the namespace for functions, so it is fine to use the same name (derivative) for both our main function and for the macro that calls it. Note that we escape the full output in this case – we want to be able to use the external names, and we only write to unique local symbols.

macro derivative(s :: Symbol, e :: Expr)
    sym, elist = derivative(s, e)
    sym, elist = simplify_ssa(sym, elist)
    sym, elist = prune_ssa(sym, elist)
    esc(from_ssa(sym, elist))
end

We will test out the macro on a moderately messy expression that is slightly tedious to compute by hand, and compare our result to the (slightly tedious) hand computation.

let
    # An expression and a hand-computed derivative
    f(x) = -log(x^2 + 2*exp(x) + (x+1)/x)
    df(x) =
        -(2*x + 2*exp(x) -1/x^2)/
        (x^2 + 2*exp(x) + (x+1)/x)

    # Autodiff of the same expression
    g(x) = @derivative(x, -log(x^2 + 2*exp(x) + (x+1)/x))

    # Check that the results agree up to roundoff effects
    xtest = 2.3
    gx, dgx = g(xtest)
    gx ≈ f(xtest) && dgx ≈ df(xtest)
end
true
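
The pattern of a function that transforms expressions paired with a same-named macro that applies it can be seen in a much smaller setting. The following sketch is unrelated to differentiation; the name double_sketch is hypothetical, and the macro must be defined at top level before it is used.

```julia
# Ordinary function: transforms an expression into another expression
double_sketch(e) = :(2 * $e)

# Macro with the same name: applies the function at expansion time
# and escapes the result so that names refer to the caller's variables
macro double_sketch(e)
    esc(double_sketch(e))
end

doubled = let a = 3
    @double_sketch(a + 1)
end
```

Because the generated code only reads the caller's variable a and assigns nothing, escaping the whole expression is safe here, for the same reason given for the derivative macro above.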

  1. In IEEE floating point, 0 * Inf is a NaN, not a zero. The compiler does not know that we don’t expect infinities in the code, and so cannot safely eliminate products with 0.