11 Automatic Differentiation

- Automatic differentiation: forward, backward, adjoint

11.1 Dual numbers

Computing with variations is not only useful for pen-and-paper computations; it is also the basis for automatic differentiation with so-called dual numbers. We describe some of the fundamentals of working with dual numbers below. A more full-featured implementation of dual numbers is given in the ForwardDiff.jl package.

11.1.1 Scalar computations

A dual number is a pair \((x, \delta x)\) consisting of a value \(x\) and a variation \(\delta x\). As above, we can think of this as a function \(\tilde{x}(\epsilon) = x + \epsilon \, \delta x + o(\epsilon)\). For ordinary scalar numbers, we represent this with a Julia structure.

struct Dual{T<:Number} <: Number
    val :: T  # Value
    δ   :: T  # Variation
end

value(x :: Dual) = x.val
variation(x :: Dual) = x.δ
We want to allow constants, which are dual numbers with a zero variation. We can construct these directly or convert them from other numbers.
Dual{T}(x :: T) where {T} = Dual{T}(x, zero(T))
Base.convert(::Type{Dual{T}}, x :: S) where {T,S<:Number} =
Dual{T}(x, zero(T))
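For example, both paths should produce a dual number with zero variation:

Dual{Float64}(2.0)          # should give Dual{Float64}(2.0, 0.0)
convert(Dual{Float64}, 3)   # should give Dual{Float64}(3.0, 0.0)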
We also want to be able to convert between types of Julia dual numbers, and we want promotion rules for doing arithmetic that involves dual numbers together with other types of numbers.
Base.convert(::Type{Dual{T}}, x :: Dual{S}) where {T,S} =
Dual{T}(x.val, x.δ)
Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{S}}) where {T,S} =
    Dual{promote_type(T,S)}
Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T,S<:Number} =
    Dual{promote_type(T,S)}
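As a quick check, promoting a dual number together with an ordinary number should lift the ordinary number to a constant dual:

promote(Dual{Float64}(1.0, 2.0), 3)
# should give (Dual{Float64}(1.0, 2.0), Dual{Float64}(3.0, 0.0))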
We would like to define a variety of standard operations (unary and binary) for dual numbers. Rather than writing the same boilerplate code for every function, we write macros @dual_unary
and @dual_binary
that take the name of an operation and the formula for the derivative in terms of \(x, \delta x, y, \delta y\), where \(x\) and \(y\) are the first and second arguments, respectively.
macro dual_unary(op, formula)
    :(function $op(x :: Dual)
          x, δx = value(x), variation(x)
          Dual($op(x), $formula)
      end)
end
macro dual_binary(op, formula)
    :(function $op(x :: Dual, y :: Dual)
          x, δx = value(x), variation(x)
          y, δy = value(y), variation(y)
          Dual($op(x, y), $formula)
      end)
end
We overload the +
, -
, and *
operators to work with dual numbers, using the usual rules of differentiation. We also overload both the left and right division operations and the exponentiation operation.
@dual_binary(Base.:+, δx + δy)
@dual_binary(Base.:-, δx - δy)
@dual_unary(Base.:-, -δx)
@dual_binary(Base.:*, δx*y + x*δy)
@dual_binary(Base.:/, (δx*y - x*δy)/y^2)
@dual_binary(Base.:\, (δy*x - y*δx)/x^2)
@dual_binary(Base.:^, x^y*(y*δx/x + log(x)*δy))
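As a small sanity check, we can differentiate a polynomial at a point; the mixed Dual/Int arithmetic below relies on the promotion rules defined earlier:

let x = Dual(2.0, 1.0)       # seed the variation δx = 1 to get d/dx
    y = x*x + 3*x
    value(y), variation(y)   # should give (10.0, 7.0); indeed d/dx (x^2 + 3x) = 7 at x = 2
end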
We provide a second version of the power function for when the exponent is a constant integer (as is often the case); this allows us to deal with negative values of \(x\) gracefully.
Base.:^(x :: Dual, n :: Integer) = Dual(x.val^n, n*x.val^(n-1)*x.δ)
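A negative base illustrates why this special case matters: the generic rule would take the log of a negative value, while the integer rule works fine.

Dual(-2.0, 1.0)^3   # should give Dual{Float64}(-8.0, 12.0)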
For comparisons, we will only consider the value and ignore the variation.
Base.:(==)(x :: Dual, y :: Dual) = x.val == y.val
Base.isless(x :: Dual, y :: Dual) = x.val < y.val
For convenience, we also write a handful of the standard functions that one learns to differentiate in a first calculus course.
@dual_unary(Base.abs, δx*sign(x))
@dual_unary(Base.sqrt, δx/sqrt(x))
@dual_unary(Base.exp, exp(x)*δx)
@dual_unary(Base.log, δx/x)
@dual_unary(Base.sin, cos(x)*δx)
@dual_unary(Base.cos, -sin(x)*δx)
@dual_unary(Base.asin, δx/sqrt(1-x^2))
@dual_unary(Base.acos,-δx/sqrt(1-x^2))
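These rules compose automatically via the chain rule; for example, the derivative of \(\exp(\sin x)\) at \(x = 0\) is \(\cos(0) \exp(\sin 0) = 1\), and we should see the same from the dual-number computation:

variation(exp(sin(Dual(0.0, 1.0))))   # should give 1.0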
With these definitions in place, we can automatically differentiate through a variety of functions without writing any special code. For example, the following code differentiates the Haaland approximation to the Darcy-Weisbach friction factor in pipe flow and compares to a finite difference approximation:
let
    fhaaland(ϵ, D, Re) = 1.0/(-1.8*log( (ϵ/D/3.7)^1.11 + 6.9/Re ))^2
    δy = fhaaland(0.01, 1.0, Dual(3000, 1)).δ
    δy_fd = finite_diff(Re->fhaaland(0.01, 1.0, Re), 3000, h=1e-3)
    isapprox(δy, δy_fd, rtol=1e-6)
end
true
11.1.2 Matrix computations
We can also use dual numbers inside of some of the linear algebra routines provided by Julia. For example, consider \(x = A^{-1} b\) where \(b\) is treated as constant; the following code automatically computes both \(A^{-1} b\) and \(-A^{-1} (\delta A) A^{-1} b\).
function test_dual_solve()
    Aval = [1.0 2.0; 3.0 4.0]
    δA = [1.0 0.0; 0.0 0.0]
    b = [3.0; 4.0]
    A = Dual.(Aval, δA)
    x = A\b

    # Compare ordinary and variational parts to formulas
    value.(x) ≈ Aval\b &&
    variation.(x) ≈ -Aval\(δA*(Aval\b))
end
test_dual_solve()
true
While this type of automatic differentiation through a matrix operation works, it is relatively inefficient. We can also define operations that act at a matrix level for relatively expensive operations like matrix multiplication and linear solves. We give as an example code in Julia for more efficient linear solves with matrices of dual numbers, using LU factorization (Chapter 16) as the basis for applying \(A^{-1}\) to a matrix or vector:
function Base.:\(AA :: AbstractMatrix{Dual{T}},
                 BB :: AbstractVecOrMat{Dual{S}}) where {T,S}
    A, δA = value.(AA), variation.(AA)
    B, δB = value.(BB), variation.(BB)
    F = lu(A)
    X = F\B
    Dual.(X, F\(δB-δA*X))
end

function Base.:\(AA :: AbstractMatrix{Dual{T}},
                 B :: AbstractVecOrMat{S}) where {T,S}
    A, δA = value.(AA), variation.(AA)
    F = lu(A)
    X = F\B
    Dual.(X, F\(-δA*X))
end

function Base.:\(A :: AbstractMatrix{T},
                 BB :: AbstractVecOrMat{Dual{S}}) where {T,S}
    B, δB = value.(BB), variation.(BB)
    F = lu(A)
    Dual.(F\B, F\δB)
end
test_dual_solve()
true
11.1.3 Special cases
There are other cases as well where automatic differentiation using dual numbers needs a little help. For example, consider the thin-plate spline function, which has a removable singularity at zero:
ϕtps(r) = r == 0.0 ? 0.0 : r^2*log(abs(r))
If we treat \(r\) as a dual number, the output for \(r = 0\) will be an ordinary floating point number, while the output for every other value of \(r\) will be a dual number. However, we can deal with this by writing a specialized version of the function for dual numbers.
ϕtps(r :: Dual) = r == 0.0 ? Dual(0.0, 0.0) : r^2*log(r)
With this version, the function works correctly at both zero and nonzero dual number arguments:
ϕtps(Dual(0.0, 1.0)), ϕtps(Dual(1.0, 1.0))
(Dual{Float64}(0.0, 0.0), Dual{Float64}(0.0, 1.0))
In addition to difficulties with removable singularities, automatic differentiation systems may lose accuracy due to floating point effects even for functions that are well-behaved. We return to this in Chapter 9.
11.2 Forward and backward
In the previous section, we described automatic differentiation by tracking variations (dual numbers) through a computation; this is often known as forward mode differentiation. An alternate approach, known as backward mode (or adjoint mode) differentiation, involves tracking a different set of variables (dual variables) through the computation.
11.2.1 An example
We consider again the sample function given in the previous section: \[ f(x_1, x_2) = \left( -1.8 \log\left( (x_1/3.7)^{1.11} + 6.9/x_2 \right) \right)^{-2}. \] We would usually write this concisely as
fhaaland0(x1, x2) = (-1.8*log( (x1/3.7)^1.11 + 6.9/x2 ))^-2
When we compute or manipulate such somewhat messy expressions (e.g. for differentiation), it is useful to split them into simpler subexpressions (as in the static single assignment (SSA) style introduced in Chapter 2). For example, we can rewrite \(f(x_1, x_2)\) in terms of seven intermediate expressions, and compute variations of each intermediate to get a variation of the final result. In code, this is
function fhaaland1(x1, x2)
    y1 = x1/3.7
    y2 = y1^1.11
    y3 = 6.9/x2
    y4 = y2+y3
    y5 = log(y4)
    y6 = -1.8*y5
    y7 = y6^-2
    function Df(δx1, δx2)
        δy1 = δx1/3.7
        δy2 = 1.11*y1^0.11 * δy1
        δy3 = -6.9/x2^2 * δx2
        δy4 = δy2+δy3
        δy5 = δy4/y4
        δy6 = -1.8*δy5
        δy7 = -2*y6^(-3)*δy6
    end
    y7, Df
end
The function Df
inside fhaaland1
computes derivatives similarly to using dual numbers as in the previous section.
Another way to think of this computation is that we have solved for the intermediate \(y\) variables as a function of \(x\) from the relationship \[ G(x, y) = h(x,y) - y = 0, \] where the rows of \(h\) are the seven equations above, e.g. \(h_1(x,y) = x_1/3.7\). Differentiating this relation gives us \[ D_1 h(x,y) \, \delta x + \left( D_2 h(x,y) - I \right) \, \delta y = 0. \] The formula for the variations in the \(y\) variables can be thought of as coming from using forward substitution to solve the linear system \[ (D_2 h(x,y) - I) \, \delta y = -D_1 h(x,y) \, \delta x. \] That is, we compute the components of \(\delta y\) in order by observing that \(h_i(x,y)\) depends only on \(y_1, \ldots, y_{i-1}\) and writing \[ \delta y_i = D_1 h_i(x,y) \, \delta x + \sum_{j=1}^{i-1} \left( D_2 h_i(x,y) \right)_j \, \delta y_j. \]
A key observation is that we do not then use all of \(\delta y\); we only care about \[ f'(x) \delta x = w^* \delta y, \] where \(w^*\) is a functional that selects the desired output (here \(w^* = e_7^*\)). Putting the steps of the previous computation together, we have \[ f'(x) \, \delta x = w^* \left[ \left( D_2 h(x,y) - I \right)^{-1} \left( -D_1 h(x,y) \, \delta x \right) \right]. \] By associativity, we could also write this as \[ f'(x) \, \delta x = \left[ \left( -w^* (D_2 h(x,y) - I)^{-1} \right) D_1 h(x,y) \right] \, \delta x. \] Giving names to the parenthesized pieces of this latter equation, we have \[\begin{aligned} \bar{y}^* &= -w^* (D_2 h(x,y) - I)^{-1} \\ f'(x) &= \bar{y}^* D_1 h(x,y). \end{aligned}\] Where we solved the system for the variations \(\delta y\) by forward substitution, we solve the system for the dual variables \(\bar{y}^*\) by backward substitution, applying the formula \[ \bar{y}_j^* = w_j^* + \sum_{i=j+1}^n \bar{y}_i^* (D_2 h_i(x,y))_j \] for \(j = n, n-1, \ldots, 1\). Because this computation can be written in terms of a solve with the adjoint matrix \((D_2 h(x,y)-I)^*\) using backward substitution, this method of computing derivatives is sometimes called adjoint mode or backward mode differentiation.
Translating these ideas into code for our example function, we have
function fhaaland2(x1, x2)
    y1 = x1/3.7
    y2 = y1^1.11
    y3 = 6.9/x2
    y4 = y2+y3
    y5 = log(y4)
    y6 = -1.8*y5
    y7 = y6^-2

    ȳ7 = 1
    ȳ6 = ȳ7 * (-2*y6^-3)
    ȳ5 = ȳ6 * (-1.8)
    ȳ4 = ȳ5/y4
    ȳ3 = ȳ4
    ȳ2 = ȳ4
    ȳ1 = ȳ2 * (1.11*y1^0.11)

    Df = [ȳ1/3.7 ȳ3*(-6.9/x2^2)]
    y7, Df
end
We now have three different methods for computing the derivative of \(f\), which we illustrate below (as well as checking that our computations agree with each other).
let
    # Compute f and the derivative using dual numbers
    f0a = fhaaland0(Dual(0.01, 1.0), 3000)
    f0b = fhaaland0(0.01, Dual(3000, 1))
    Df0 = [variation(f0a) variation(f0b)]

    # Compute using the SSA-based forward mode code
    f1, Df1fun = fhaaland1(0.01, 3000)
    Df1 = [Df1fun(1,0) Df1fun(0,1)]

    # Compute using the SSA-based backward mode code
    f2, Df2 = fhaaland2(0.01, 3000)

    # Compare to see that all give the same results
    f1 == f2 && f1 ≈ value(f0a) && f1 ≈ value(f0b) &&
    Df1[1] ≈ Df2[1] && Df0[1] ≈ Df2[1] &&
    Df1[2] ≈ Df2[2] && Df0[2] ≈ Df2[2]
end
true
We note that we only need to compute the \(\bar{y}\) variables once to get the derivative vector, whereas for forward mode we need to compute the \(\delta y\) variables twice.
The cost of computing the derivative in one direction \((\delta x_1, \delta x_2)\) is only about as great as the cost of evaluating \(f\). But if we wanted the derivative with respect to both \(x_1\) and \(x_2\), we would need to run at least the code in Df
twice (and with the dual number implementation from the previous section, we would also run the evaluation of \(f\) twice).
11.3 A derivative example
Beyond writing language utilities, Julia macros are useful for writing embedded domain-specific languages (DSLs) for accomplishing particular tasks. In this setting, we are really writing a language interpreter embedded in Julia, using the Julia parse tree as an input and producing Julia code as output.
One example has to do with automatic differentiation of simple code expressions, where we are giving another interpretation to simple Julia expressions. This is a more ambitious example, and can be safely skipped over. However, it is also a useful example of a variety of features in Julia. Some types of automatic differentiation can be done in arguably simpler and more powerful ways without writing macros, and we will return to this in later chapters.
11.3.1 Normalization
We define a “simple” expression as one that involves only:

- Literal nodes (including symbols),
- Expr nodes of call type, including arithmetic and function calls,
- Assignment statements,
- Tuple constructors,
- begin/end blocks.
We can test for simple expressions by recursing through an expression tree
# Simple expression nodes
is_simple(e :: Expr) =
e.head in [:call, :(=), :tuple, :block] &&
all(is_simple.(e.args))
# Other nodes are literal, always simple
is_simple(e) = true
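For example (hypothetical inputs), an assignment built from calls is simple, while a for loop is not:

is_simple(:(x = 2*sin(y) + 1)), is_simple(:(for i in 1:3 end))
# should give (true, false)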
In addition to insisting that expressions are simple, we also want to simplify some peculiarities in Julia’s parsing of addition and multiplication. In many languages, an expression like 1 + 2 + 3
is parsed as two operations: (1 + 2) + 3
. In Julia, this results in a single node. It will simplify our life to convert these types of nodes to binary nodes, assuming left associativity of the addition operation. We do this by recursively rewriting the expression tree to replace a particular type of operation node (named by op) with a left-associated version of that same node:
leftassoc(op :: Symbol, e) = e
function leftassoc(op :: Symbol, e :: Expr)
    args = leftassoc.(op, e.args)
    if e.head == :call && args[1] == op
        foldl((x,y) -> Expr(:call, op, x, y), args[2:end])
    else
        Expr(e.head, args...)
    end
end
We are particularly interested in left associativity of addition and multiplication, so we write a single function for those operations
leftassoc(e) = leftassoc(:+, leftassoc(:*, e))
We can see the effect by printing the parse of a particular example:
let
e = :(a*b*c + 2 + 3)
println("Before: $e")
println("After: $(leftassoc(e))")
end
Before: a * b * c + 2 + 3
After: ((a * b) * c + 2) + 3
Finally, the Julia parser adds LineNumberNodes
to blocks in order to aid with debugging. We will dispose of those here so that we can focus solely on processing expressions.
filter_line(e) = e
filter_line(e :: Expr) =
Expr(e.head,
filter(x->!(x isa LineNumberNode),
filter_line.(e.args))...)
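For example, a quoted block picks up a LineNumberNode that this utility strips away:

filter_line(quote x+1 end).args   # should give Any[:(x + 1)]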
Putting everything together, we will normalize simple expressions by eliminating line numbers and re-associating addition and multiplication.
normalize_expr(e) = e
function normalize_expr(e :: Expr)
@assert is_simple(e) "Expression must be simple"
leftassoc(filter_line(e))
end
11.3.2 SSA
Simple expressions of the type that we have described can be converted to static single assignment (SSA) form, where each intermediate subexpression is assigned a unique local name (which we produce with a call to gensym
). We represent a program in SSA form as a vector of (symbol, expression)
pairs to be interpreted as “evaluate the expression and assign it to the symbol.” We emit SSA terms by appending them to the end of a list.
function ssa_generator()
    result = []

    # Add a single term to the SSA result
    function emit!(s, e)
        push!(result, (s,e))
        s
    end

    # Add several terms to the SSA result;
    # s is the designated "final result"
    function emit!(s, l :: Vector)
        append!(result, l)
        s
    end

    # Emit a term with a new local name
    emit!(e) = emit!(gensym(), e)

    result, emit!
end
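For example (a hypothetical use outside of to_ssa), we can emit a couple of terms by hand:

let
    result, emit! = ssa_generator()
    a = emit!(:(x + 1))       # emit under a fresh name; returns that name
    emit!(:y, :(2 * $a))      # emit under a chosen name
    result                    # two (symbol, expression) pairs, in order
end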
The to_ssa function returns the final result in SSA form. Each symbol is assigned only once. Each arithmetic operation, function call, and tuple constructor results in a new symbol. For assignments, we record a binding of the left-hand side symbol to the result computed for the right-hand side.
function to_ssa(e :: Expr)

    # Check and normalize the expression
    e = normalize_expr(e)

    # Set up results list and symbol table
    elist, emit! = ssa_generator()
    symbol_table = Dict{Symbol,Any}()

    # Functions for recursive processing
    process(e) = e
    process(e :: Symbol) = get(symbol_table, e, e)
    function process(e :: Expr)
        args = process.(e.args)
        if e.head == :block
            args[end]
        elseif e.head == :(=)
            symbol_table[e.args[1]] = args[2]
            args[2]
        else
            emit!(Expr(e.head, args...))
        end
    end

    process(e), elist
end
In case we want to convert a single leaf to SSA form, we define a second method:
function to_ssa(e)
    s = gensym()
    s, [(s, e)]
end
We also want a utility to convert from SSA form back to ordinary Julia code:

from_ssa(sym, elist) =
    Expr(:block,
         [Expr(:(=), s, e) for (s,e) in elist]...,
         sym)
from_ssa (generic function with 1 method)
As is often the case, an example is useful to illustrate the code.
let
    s, elist = to_ssa(quote
        x = 10
        x = 10+x+1
        y = x*x
    end)
    from_ssa(s, elist)
end
quote
var"##230" = 10 + 10
var"##231" = var"##230" + 1
var"##232" = var"##231" * var"##231"
var"##232"
end
11.3.3 Derivative function
To set up our differentiation function, we will need some differentiation rules. We store these in a dictionary whose keys are the names of operations and functions and whose values are the corresponding rules. Each rule takes expressions for the value of the function, the arguments, and the derivatives of the arguments, and returns an expression for the derivative of the function. In some cases, we may want different methods associated with the same function, e.g. to distinguish between negation and subtraction or to provide a special case of the power function where the exponent is an integer constant.
deriv_minus(f,x,dx) = :(-$dx)
deriv_minus(f,x,y,dx,dy) = :($dx-$dy)
deriv_pow(f,x,n :: Int, dx, dn) = :($n*$x^$(n-1)*$dx)
deriv_pow(f,x,y,dx,dy) = :($f*($y*$dx/$x + log($x)*$dy))

deriv_rules = Dict(
    :+ => (f,x,y,dx,dy) -> :($dx+$dy),
    :* => (f,x,y,dx,dy) -> :($dx*$y + $x*$dy),
    :/ => (f,x,y,dx,dy) -> :(($dx - $f*$dy)/$y),
    :- => deriv_minus,
    :^ => deriv_pow,
    :log => (f,x,dx) -> :($dx/$x),
    :exp => (f,x,dx) -> :($f*$dx)
)
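For example, applying the product-rule entry directly to placeholder symbols shows the kind of expression it builds:

deriv_rules[:*](:f, :x, :y, :dx, :dy)   # should give :(dx * y + x * dy)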
Now we are in a position to write code to simultaneously evaluate the expression and the derivative (with respect to some symbol s
). To do this, it is helpful to give a (locally defined) name to each subexpression and its derivative, and to produce code that is a sequence of assignments to those names. We accumulate those assignments into a list (elist
). An internal process
function with different methods for different types of objects is used to generate the assignments for the subexpression values and the associated derivatives.
function derivative(s :: Symbol, e :: Expr)

    # Convert input to SSA form, set up generator
    sym, elist = to_ssa(e)
    eresult, emit! = ssa_generator()

    # Derivatives of leaves, init with ds/ds = 1.
    deriv_table = Dict{Symbol,Any}()
    deriv_table[s] = 1
    deriv(e :: Symbol) = get(deriv_table, e, 0)
    deriv(e) = 0

    # Rules to generate differentiation code
    function derivcode(s, e :: Expr)
        if e.head == :call
            rule = deriv_rules[e.args[1]]
            dexpr = rule(s, e.args[2:end]...,
                         deriv.(e.args[2:end])...)
            emit!(to_ssa(dexpr)...)
        elseif e.head == :tuple
            emit!(Expr(:tuple, deriv.(e.args)...))
        else
            error("Unexpected expression of type $(e.head)")
        end
    end
    derivcode(s, e) = emit!(gensym(), deriv(e))

    # Produce code for results and derivatives
    for (s,e) in elist
        emit!(s,e)
        deriv_table[s] = derivcode(s, e)
    end

    # Add a tuple for return at the end (function + deriv)
    emit!(Expr(:tuple, sym, deriv_table[sym])), eresult
end
As an example, consider computing the derivative of \(mx + b\) with respect to \(x\)
from_ssa(derivative(:x, :(m*x+b))...)
quote
var"##233" = m * x
var"##235" = 0 * x
var"##236" = m * 1
var"##237" = var"##235" + var"##236"
var"##234" = var"##233" + b
var"##238" = var"##237" + 0
var"##239" = (var"##234", var"##238")
var"##239"
end
Giving more comprehensible variable names, this is equivalent to
y1 = m * x
dy1a = 0 * x
dy1b = m * 1
dy1 = dy1a + dy1b
y2 = y1 + b
dy2 = dy1 + 0
result = (y2, dy2)
result
This is correct, but it is also a very complicated-looking way to compute \(m\)! We therefore consider putting in some standard simplifications.
11.3.4 Simplification
There are many possible algebraic relationships that we can use to simplify derivatives. Some of these relationships, like multiplication by zero, are things that we believe are safe but that the compiler cannot safely apply on our behalf.¹ Others are things that the compiler could probably do on our behalf, but that we might want to do on our own to understand how these things work.
We will again use multiple dispatch to write our simplification rules as different methods under a common name. For example, the simplification rules for \(x + y\) (where the prior expression is saved as \(e\)) are:
# Rules to simplify e = x+y
simplify_add(e, x :: Number, y :: Number) = x+y
simplify_add(e, x :: Number, y :: Symbol) = x == 0 ? y : e
simplify_add(e, x :: Symbol, y :: Number) = y == 0 ? x : e
simplify_add(e, x, y) = e
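For example, an addition with a zero constant collapses, while a purely symbolic sum is left alone:

simplify_add(:(0 + y), 0, :y), simplify_add(:(x + y), :x, :y)
# should give (:y, :(x + y))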
The rules for minus (unary and binary) are similar in flavor. We will give the unary and binary versions both the same name, and distinguish between the two cases based on the number of arguments that we see.
# Rules to simplify e = -x
simplify_sub(e, x :: Number) = -x
simplify_sub(e, x) = e
# Rules to simplify e = x-y
simplify_sub(e, x :: Number, y :: Symbol) = x == 0 ? :(-$y) : e
simplify_sub(e, x :: Symbol, y :: Number) = y == 0 ? x : e
simplify_sub(e, x, y) = e
With multiplication, we will use the ordinary observation that anything times zero should be zero, ignoring the peculiarities of floating point.
# Rules to simplify e = x*y
simplify_mul(e, x :: Number, y :: Number) = x*y
simplify_mul(e, x :: Number, y :: Symbol) =
if x == 0 0
elseif x == 1 y
else e
end
simplify_mul(e, x :: Symbol, y :: Number) =
if y == 0 0
elseif y == 1 x
else e
end
simplify_mul(e, x, y) = e
Finally, we define rules for simplifying quotients and powers.
# Rules to simplify e = x/y
simplify_div(e, x :: Number, y :: Number) = x/y
simplify_div(e, x :: Symbol, y :: Number) = y == 1 ? x : e
simplify_div(e, x :: Number, y :: Symbol) = x == 0 ? 0 : e
simplify_div(e, x, y) = e
# Simplify powers
simplify_pow(e, x :: Symbol, n :: Number) =
if n == 1 x
elseif n == 0 1
else e
end
simplify_pow(e, x, n) = e
To simplify a line in an SSA assignment, we look up the appropriate rule function in a table (as we did with differentiation rules) and apply it.
simplify_rules = Dict(
    :+ => simplify_add,
    :- => simplify_sub,
    :* => simplify_mul,
    :/ => simplify_div,
    :^ => simplify_pow)
simplify_null(e, args...) = e
simplify(e) = e
simplify(e :: Expr) =
    if e.head == :call
        rule = get(simplify_rules, e.args[1], simplify_null)
        rule(e, e.args[2:end]...)
    else
        e
    end
simplify (generic function with 2 methods)
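For example, multiplying by one and raising to the first power both simplify away:

simplify(:(1 * z)), simplify(:(z ^ 1))   # should give (:z, :z)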
Putting everything together, our rule for simplifying code in SSA form is:
- Look up whether a previous simplification replaced any arguments with constants or with other symbols. If a replacement was itself replaced earlier, apply the rule recursively.
- Apply the appropriate simplification rule.
- If the expression now looks like (symbol, x) where x is a leaf (a constant or another symbol), make a record in the lookup table to replace future references to symbol with references to x.
The following code carries out this task.
function simplify_ssa(sym, elist)

    # Set up and add to replacement table
    replacements = Dict{Symbol,Any}()
    function replacement(s,e)
        replacements[s] = e
    end

    # Apply a replacement policy
    replace(e) = e
    replace(s :: Symbol) =
        if s in keys(replacements)
            replace(replacements[s])
        else
            s
        end

    # Simplify each term
    eresult, emit! = ssa_generator()
    for (s,e) in elist
        e = e isa Expr ? Expr(e.head, replace.(e.args)...) : e
        enew = simplify(e)
        if enew isa Number || enew isa Symbol
            replacement(s, enew)
        end
        emit!(s, enew)
    end

    replace(sym), eresult
end
Finally, we do dead code elimination, keeping only “live” computations on which the final result depends. Liveness is computed recursively: the final result is live, and anything that a live variable depends on is live. Because in SSA each expression depends only on previous results, we can compute liveness by traversing the list from back to front and updating what we know to be live as we go.
function prune_ssa(sym, elist)

    # The end symbol is live
    live = Set([sym])

    # Propagate liveness
    mark_live(s :: Symbol) = push!(live, s)
    mark_live(e :: Expr) = mark_live.(e.args)
    mark_live(e) = nothing

    # If a result is live, so is anything its expression depends on
    for (s,e) in reverse(elist)
        if s in live
            mark_live(e)
        end
    end

    # Return only the live expressions
    sym, filter!(p -> p[1] in live, elist)
end
Repeating our experiment with differentiating \(mx + b\), but now including simplification and pruning, we get much terser generated code.
from_ssa(
prune_ssa(
simplify_ssa(
derivative(:x, :(m*x+b))...)...)...)
quote
var"##240" = m * x
var"##241" = var"##240" + b
var"##246" = (var"##241", m)
var"##246"
end
11.3.5 The macro
Finally, we package the derivatives into a macro. The macro gets the derivative information from the derivative
function and packages it into a block that at the end returns the final expression value and derivative as a tuple. We note that the namespace for macros is different from the namespace for functions, so it is fine to use the same name (derivative
) for both our main function and for the macro that calls it. Note that we escape the full output in this case – we want to be able to use the external names, and we only write to unique local symbols.
macro derivative(s :: Symbol, e :: Expr)
    sym, elist = derivative(s, e)
    sym, elist = simplify_ssa(sym, elist)
    sym, elist = prune_ssa(sym, elist)
    esc(from_ssa(sym, elist))
end
We will test out the macro on a moderately messy expression that is slightly tedious to compute by hand, and compare our result to the (slightly tedious) hand computation.
let
    # An expression and a hand-computed derivative
    f(x) = -log(x^2 + 2*exp(x) + (x+1)/x)
    df(x) =
        -(2*x + 2*exp(x) - 1/x^2)/
         (x^2 + 2*exp(x) + (x+1)/x)

    # Autodiff of the same expression
    g(x) = @derivative(x, -log(x^2 + 2*exp(x) + (x+1)/x))

    # Check that the results agree up to roundoff effects
    xtest = 2.3
    gx, dgx = g(xtest)
    gx ≈ f(xtest) && dgx ≈ df(xtest)
end
true
¹ In IEEE floating point, 0 * Inf is a NaN, not a zero. The compiler does not know that we don’t expect infinities in the code, and so cannot safely eliminate products with 0.