# Forward-mode Autodiff: A quick demo

In [1]:
"""
    Dual

A dual number for forward-mode autodiff: carries a value `f` together with
the derivative `dfdt` of that value with respect to the input variable.
Subtyping `Number` lets Duals flow through generic numeric code via
promotion and dispatch.
"""
struct Dual <: Number
f::Float64
dfdt::Float64
end

# Mixed arithmetic between a Dual and a machine Float64/Int64 promotes the
# machine number to a Dual (treated as a constant, i.e. zero derivative).
function Base.promote_rule(::Type{Dual}, ::Type{Float64})
    return Dual
end

function Base.promote_rule(::Type{Dual}, ::Type{Int64})
    return Dual
end

# Lift a plain Float64 into a Dual constant (zero derivative).
Dual(f::Float64) = Dual(f, 0.0)

# Lift any integer into a Dual constant by routing through the
# Float64 constructor (which sets the derivative to zero).
Dual(f::Integer) = Dual(Float64(f))

# Sum rule: (x + y)' = x' + y'.
Base.:+(a::Dual, b::Dual) = Dual(a.f + b.f, a.dfdt + b.dfdt)

# Difference rule: (x - y)' = x' - y'.
Base.:-(a::Dual, b::Dual) = Dual(a.f - b.f, a.dfdt - b.dfdt)

# Negation rule: (-y)' = -y'.
Base.:-(a::Dual) = Dual(-a.f, -a.dfdt)

# Product rule: (x * y)' = x' * y + x * y'.
Base.:*(a::Dual, b::Dual) = Dual(a.f * b.f, a.dfdt * b.f + a.f * b.dfdt)

# Quotient rule: (x / y)' = x'/y - x * y' / y^2.
Base.:/(a::Dual, b::Dual) = Dual(a.f / b.f, a.dfdt / b.f - a.f * b.dfdt / b.f^2)

# Chain rule through exp: d/dt exp(u) = exp(u) * u'.
function Base.exp(a::Dual)
    v = exp(a.f)
    return Dual(v, v * a.dfdt)
end

# Chain rule through sin: d/dt sin(u) = cos(u) * u'.
Base.sin(a::Dual) = Dual(sin(a.f), cos(a.f) * a.dfdt)

# Chain rule through cos: d/dt cos(u) = -sin(u) * u'.
Base.cos(a::Dual) = Dual(cos(a.f), -sin(a.f) * a.dfdt)


Consider the function

$$f(x) = \exp(\sin(x+3)^2).$$

The chain-rule derivative of this is

$$f'(x) = \exp(\sin(x+3)^2) \cdot 2 \sin(x+3) \cdot \cos(x+3).$$
In [2]:
"""
    f(x)

Demo function `f(x) = exp(sin(x + 3)^2)`, written generically so it
accepts Duals (or any Number) as well as ordinary floats.
"""
function f(x)
    s = sin(x + 3)
    return exp(s * s)
end

# Hand-derived chain-rule derivative of f, used to check the autodiff:
# f'(x) = exp(sin(x+3)^2) * 2 sin(x+3) cos(x+3).
function dfdx(x)
    return exp(sin(x + 3)^2) * 2 * sin(x + 3) * cos(x + 3)
end

Out[2]:
dfdx (generic function with 1 method)
In [3]:
f(2)

Out[3]:
2.5081257587058756
In [4]:
# numerical derivative
(f(2 + 1e-5) - f(2))/1e-5

Out[4]:
-1.3644906946996824
In [5]:
dfdx(2)

Out[5]:
-1.3644733615014137
In [6]:
f(Dual(2.0,1.0))

Out[6]:
Dual(2.5081257587058756, -1.364473361501414)
In [7]:
# Vector-valued demo: element-wise sin of a fixed coefficient vector
# scaled by x. Broadcasting means a Dual input yields a vector of Duals,
# i.e. all eight derivatives in one forward pass.
function g(x)
    coeffs = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]
    return sin.(coeffs .* x)
end

Out[7]:
g (generic function with 1 method)
In [8]:
g(Dual(2.0,1.0))

Out[8]:
8-element Array{Dual,1}:
Dual(0.0, 0.0)
Dual(0.8414709848078965, 0.2701511529340699)
Dual(0.9092974268256817, -0.4161468365471424)
Dual(-0.7568024953079282, -1.3072872417272239)
Dual(-0.27941549819892586, 2.880510859951098)
Dual(0.6569865987187891, 2.638657890201566)
Dual(0.9893582466233818, -0.5820001352344542)
Dual(0.4121184852417566, -4.100086178481046)

The "dual number" representation computed the derivative automatically! And we can even do it all-at-once for vector-valued functions!

# Reverse-mode Autodiff

A really dumb, but simple, reverse-mode autodiff.

In [9]:
"""
    BkwdDiffable

A node in a dynamically built compute graph for reverse-mode autodiff.

Fields:
- `u`:     forward value of this node
- `dhdx`:  accumulated adjoint d(objective)/d(this node), filled in by
           the backward pass (mutated in place, hence `mutable struct`)
- `order`: creation index; a node's inputs always have smaller `order`,
           giving a topological ordering of the graph
- `back`:  closure that propagates this node's `dhdx` to its inputs and
           returns those inputs
"""
mutable struct BkwdDiffable <: Number
u::Float64
dhdx::Float64
order::Int64
back::Function
end

# Mixed arithmetic with machine Float64/Int64 promotes them to graph nodes.
Base.promote_rule(::Type{BkwdDiffable},::Type{Float64}) = BkwdDiffable
Base.promote_rule(::Type{BkwdDiffable},::Type{Int64}) = BkwdDiffable

# Global creation counter used to timestamp nodes (the `order` field).
# NOTE(review): non-const mutable global — fine for a demo, slow in general.
global bkwd_diff_node_order = 0

# Wrap a raw Float64 as a leaf node: zero adjoint, a fresh creation
# timestamp from the global counter, and a no-op backward step
# (leaves have no inputs to propagate to).
function BkwdDiffable(x::Float64)
    global bkwd_diff_node_order
    stamp = bkwd_diff_node_order
    bkwd_diff_node_order += 1
    return BkwdDiffable(x, 0.0, stamp, () -> BkwdDiffable[])
end

# Integer leaves are converted to Float64 first.
BkwdDiffable(x::Integer) = BkwdDiffable(Float64(x))

# Addition node. The backward step passes the output adjoint through
# unchanged to both inputs and reports them for further processing.
function Base.:+(x::BkwdDiffable, y::BkwdDiffable)
    out = BkwdDiffable(x.u + y.u)
    function pullback()
        x.dhdx += out.dhdx
        y.dhdx += out.dhdx
        return BkwdDiffable[x, y]
    end
    out.back = pullback
    return out
end

# Subtraction node: adjoint flows to x with sign +1 and to y with sign -1.
function Base.:-(x::BkwdDiffable, y::BkwdDiffable)
    out = BkwdDiffable(x.u - y.u)
    function pullback()
        x.dhdx += out.dhdx
        y.dhdx -= out.dhdx
        return BkwdDiffable[x, y]
    end
    out.back = pullback
    return out
end

# Unary negation node: adjoint flows through with its sign flipped.
function Base.:-(y::BkwdDiffable)
    out = BkwdDiffable(-y.u)
    function pullback()
        y.dhdx -= out.dhdx
        return BkwdDiffable[y]
    end
    out.back = pullback
    return out
end

# Multiplication node: each input's adjoint is the output adjoint
# scaled by the other input's forward value (product rule).
function Base.:*(x::BkwdDiffable, y::BkwdDiffable)
    out = BkwdDiffable(x.u * y.u)
    function pullback()
        x.dhdx += out.dhdx * y.u
        y.dhdx += out.dhdx * x.u
        return BkwdDiffable[x, y]
    end
    out.back = pullback
    return out
end

# Division node (quotient rule): d/dx (x/y) = 1/y, d/dy (x/y) = -x/y^2.
function Base.:/(x::BkwdDiffable, y::BkwdDiffable)
    out = BkwdDiffable(x.u / y.u)
    function pullback()
        x.dhdx += out.dhdx / y.u
        y.dhdx += -out.dhdx * x.u / y.u^2
        return BkwdDiffable[x, y]
    end
    out.back = pullback
    return out
end

# exp node: local derivative is exp(x.u) itself.
function Base.exp(x::BkwdDiffable)
    out = BkwdDiffable(exp(x.u))
    function pullback()
        x.dhdx += out.dhdx * exp(x.u)
        return BkwdDiffable[x]
    end
    out.back = pullback
    return out
end

# sin node: local derivative is cos(x.u).
function Base.sin(x::BkwdDiffable)
    out = BkwdDiffable(sin(x.u))
    function pullback()
        x.dhdx += out.dhdx * cos(x.u)
        return BkwdDiffable[x]
    end
    out.back = pullback
    return out
end

# cos node: local derivative is -sin(x.u).
function Base.cos(x::BkwdDiffable)
    out = BkwdDiffable(cos(x.u))
    function pullback()
        x.dhdx += -out.dhdx * sin(x.u)
        return BkwdDiffable[x]
    end
    out.back = pullback
    return out
end

# pass in the objective we want to differentiate with respect to
function reverse_autodiff(h::BkwdDiffable)
h.dhdx = 1.0;
frontier = BkwdDiffable[h]
while length(frontier) != 0
# sort the active compute graph elements to find the newest one
sort!(frontier; by = (x) -> x.u)
# compute the backward step on this element
end
end

Out[9]:
reverse_autodiff (generic function with 1 method)
In [10]:
x = BkwdDiffable(2.0)

Out[10]:
BkwdDiffable(2.0, 0.0, 0, var"#1#2"())
In [11]:
fx = f(x)

Out[11]:
BkwdDiffable(2.5081257587058756, 0.0, 5, var"#bkwd#8"{BkwdDiffable,BkwdDiffable}(BkwdDiffable(0.9195357645382262, 0.0, 4, var"#bkwd#6"{BkwdDiffable,BkwdDiffable,BkwdDiffable}(BkwdDiffable(-0.9589242746631385, 0.0, 3, var"#bkwd#9"{BkwdDiffable,BkwdDiffable}(BkwdDiffable(5.0, 0.0, 2, var"#bkwd#3"{BkwdDiffable,BkwdDiffable,BkwdDiffable}(BkwdDiffable(2.0, 0.0, 0, var"#1#2"()), BkwdDiffable(3.0, 0.0, 1, var"#1#2"()), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(-0.9589242746631385, 0.0, 3, var"#bkwd#9"{BkwdDiffable,BkwdDiffable}(BkwdDiffable(5.0, 0.0, 2, var"#bkwd#3"{BkwdDiffable,BkwdDiffable,BkwdDiffable}(BkwdDiffable(2.0, 0.0, 0, var"#1#2"()), BkwdDiffable(3.0, 0.0, 1, var"#1#2"()), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#)))
In [12]:
fx.u

Out[12]:
2.5081257587058756
In [13]:
reverse_autodiff(fx)

In [14]:
x.dhdx

Out[14]:
-1.3644733615014137

This also makes it feasible to compute gradients!

In [15]:
# Scalar objective over an 8-vector: sin of a weighted sum of exp.(w).
# Written generically, so a Vector{BkwdDiffable} flows through and the
# whole gradient comes out of one backward pass.
function loss(w)
    coeffs = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]
    return sin(sum(coeffs .* exp.(w)))
end

Out[15]:
loss (generic function with 1 method)
In [16]:
w = [BkwdDiffable(0.0) for i in 1:8]

Out[16]:
8-element Array{BkwdDiffable,1}:
BkwdDiffable(0.0, 0.0, 6, var"#1#2"())
BkwdDiffable(0.0, 0.0, 7, var"#1#2"())
BkwdDiffable(0.0, 0.0, 8, var"#1#2"())
BkwdDiffable(0.0, 0.0, 9, var"#1#2"())
BkwdDiffable(0.0, 0.0, 10, var"#1#2"())
BkwdDiffable(0.0, 0.0, 11, var"#1#2"())
BkwdDiffable(0.0, 0.0, 12, var"#1#2"())
BkwdDiffable(0.0, 0.0, 13, var"#1#2"())
In [17]:
lossw = loss(w);

In [18]:
lossw.u

Out[18]:
-0.34248061846961253
In [19]:
loss(zeros(8))

Out[19]:
-0.34248061846961253
In [20]:
reverse_autodiff(lossw)

In [21]:
w

Out[21]:
8-element Array{BkwdDiffable,1}:
BkwdDiffable(0.0, 0.0, 6, var"#1#2"())
BkwdDiffable(0.0, 0.469762446874128, 7, var"#1#2"())
BkwdDiffable(0.0, 0.939524893748256, 8, var"#1#2"())
BkwdDiffable(0.0, 1.879049787496512, 9, var"#1#2"())
BkwdDiffable(0.0, 2.818574681244768, 10, var"#1#2"())
BkwdDiffable(0.0, 3.288337128118896, 11, var"#1#2"())
BkwdDiffable(0.0, 3.758099574993024, 12, var"#1#2"())
BkwdDiffable(0.0, 4.227862021867152, 13, var"#1#2"())
In [22]:
# just check the 6th element numerically
e6 = zeros(8); e6[6] = 1.0;
(loss(zeros(8) + 1e-5 * e6) - loss(zeros(8))) / 1e-5

Out[22]:
3.288374546278616
In [23]:
w[6].dhdx

Out[23]:
3.288337128118896
In [ ]: