Forward-mode Autodiff: A quick demo

In [1]:
"""
    Dual(f, dfdt)

Forward-mode dual number. `f` holds a value and `dfdt` holds the derivative of that
value with respect to a single seed variable `t`. The methods below propagate `dfdt`
with the ordinary differentiation rules, so evaluating a generic function on
`Dual(x, 1.0)` returns both the function value and its derivative at `x`.
"""
struct Dual <: Number
    f::Float64
    dfdt::Float64
end

# Any real scalar mixed with a Dual is promoted to a Dual constant. Dispatching on
# `<:Real` (instead of only `Float64`/`Int64`) also covers `Int32` (the default `Int`
# on 32-bit machines, i.e. the type of the literal `3`), `Float32`, `Bool`, `Rational`.
Base.promote_rule(::Type{Dual}, ::Type{<:Real}) = Dual

"""
    Dual(f::Real)

Embed the real constant `f` as a `Dual` with zero derivative. This is also the path
`convert(Dual, f)` takes during promotion.
"""
Dual(f::Real) = Dual(Float64(f), 0.0)

# Sum and difference rules: (x ± y)' = x' ± y'.
Base.:+(x::Dual, y::Dual) = Dual(x.f + y.f, x.dfdt + y.dfdt)
Base.:-(x::Dual, y::Dual) = Dual(x.f - y.f, x.dfdt - y.dfdt)
Base.:-(y::Dual) = Dual(-y.f, -y.dfdt)

# Product rule: (x y)' = x' y + x y'.
Base.:*(x::Dual, y::Dual) = Dual(x.f * y.f, x.dfdt * y.f + x.f * y.dfdt)

# Quotient rule: (x / y)' = x' / y - x y' / y^2.
Base.:/(x::Dual, y::Dual) = Dual(x.f / y.f, x.dfdt / y.f - x.f * y.dfdt / y.f^2)

# Chain rule for elementary functions: g(u)' = g'(u) u'.
function Base.exp(x::Dual)
    ex = exp(x.f)  # exp is its own derivative; evaluate it once
    return Dual(ex, ex * x.dfdt)
end

Base.sin(x::Dual) = Dual(sin(x.f), cos(x.f) * x.dfdt)

Base.cos(x::Dual) = Dual(cos(x.f), -sin(x.f) * x.dfdt)

Consider the function

$$f(x) = \exp(\sin(x+3)^2).$$

By the chain rule, its derivative is

$$f'(x) = \exp(\sin(x+3)^2) \cdot 2 \sin(x+3) \cdot \cos(x+3).$$
In [2]:
"""
    f(x)

Evaluate ``f(x) = \\exp(\\sin(x + 3)^2)``. Only `+`, `*`, `sin` and `exp` are used, so
the same definition also runs on the custom number types used in this notebook.
"""
function f(x)
    s = sin(x + 3)
    s_squared = s * s
    return exp(s_squared)
end

"""
    dfdx(x)

Hand-derived derivative of `f` via the chain rule:
``f'(x) = \\exp(\\sin(x + 3)^2) \\cdot 2 \\sin(x + 3) \\cos(x + 3)``.
"""
function dfdx(x)
    outer = exp(sin(x + 3)^2)
    return outer * 2 * sin(x + 3) * cos(x + 3)
end
Out[2]:
dfdx (generic function with 1 method)
In [3]:
# evaluate f at the plain number x = 2
f(2)
Out[3]:
2.5081257587058756
In [4]:
# numerical derivative: one-sided (forward) difference with step h = 1e-5.
# Its truncation error is O(h), so it only agrees with the exact derivative
# below to about five significant digits.
(f(2 + 1e-5) - f(2))/1e-5
Out[4]:
-1.3644906946996824
In [5]:
# exact derivative from the hand-derived chain-rule formula
dfdx(2)
Out[5]:
-1.3644733615014137
In [6]:
# seed the derivative part with 1.0 (dx/dt = 1): the result holds f(2) and f'(2)
f(Dual(2.0,1.0))
Out[6]:
Dual(2.5081257587058756, -1.364473361501414)
In [7]:
"""
    g(x)

Vector-valued test function: element-wise `sin(c * x)` over the fixed coefficients
`[0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]` (with `x` combined by broadcasting).
"""
function g(x)
    coefficients = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]
    return @. sin(coefficients * x)
end
Out[7]:
g (generic function with 1 method)
In [8]:
# one forward pass: each output of g carries its own derivative with respect to x
g(Dual(2.0,1.0))
Out[8]:
8-element Array{Dual,1}:
                                 Dual(0.0, 0.0)
   Dual(0.8414709848078965, 0.2701511529340699)
  Dual(0.9092974268256817, -0.4161468365471424)
 Dual(-0.7568024953079282, -1.3072872417272239)
  Dual(-0.27941549819892586, 2.880510859951098)
    Dual(0.6569865987187891, 2.638657890201566)
  Dual(0.9893582466233818, -0.5820001352344542)
   Dual(0.4121184852417566, -4.100086178481046)

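The dual-number representation computed the derivative automatically: `f(Dual(2.0,1.0))` agrees with the hand-derived `dfdx(2)` up to floating-point rounding. It even handles vector-valued functions in a single pass, since every output of `g` carries its own derivative with respect to `x`.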

Reverse-mode Autodiff

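A deliberately naive, but simple, reverse-mode autodiff. Every operation on a `BkwdDiffable` records a closure that pushes derivatives back to its inputs. `reverse_autodiff` then runs those closures from the output back towards the inputs.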

In [9]:
"""
    BkwdDiffable(u, dhdx, order, back)

Node of a dynamically recorded computation graph for reverse-mode autodiff.

- `u`: the value computed at this node.
- `dhdx`: accumulator for the derivative of the objective `h` with respect to this
  node; filled in by `reverse_autodiff`.
- `order`: creation index taken from the global counter `bkwd_diff_node_order`.
  Every node is created after its inputs, so increasing `order` is a topological
  order of the graph.
- `back`: zero-argument closure that adds this node's `dhdx`, times the local
  partial derivatives, into its inputs' `dhdx` and returns those inputs.
"""
mutable struct BkwdDiffable <: Number
    u::Float64
    dhdx::Float64
    order::Int64
    back::Function
end

# Any real scalar mixed with a node is promoted to a constant leaf node. Dispatching
# on `<:Real` (not just `Float64`/`Int64`) also covers `Int32`, `Float32`, `Bool`, ...
Base.promote_rule(::Type{BkwdDiffable}, ::Type{<:Real}) = BkwdDiffable

# Counter stamping each newly created node with its creation index.
global bkwd_diff_node_order = 0

"""
    BkwdDiffable(x::Float64)

Create a leaf node holding `x` with zero accumulated derivative, a backward step that
does nothing, and the next creation index.
"""
function BkwdDiffable(x::Float64)
    global bkwd_diff_node_order
    order = bkwd_diff_node_order
    bkwd_diff_node_order += 1
    return BkwdDiffable(x, 0.0, order, () -> BkwdDiffable[])
end

"""
    BkwdDiffable(x::Real)

Create a leaf node from any real scalar by converting it to `Float64` first.
"""
BkwdDiffable(x::Real) = BkwdDiffable(Float64(x))

# Each operation computes its value eagerly, creates a result node `rv`, and attaches
# a backward closure implementing the chain rule: every input's `dhdx` receives
# `rv.dhdx` times the local partial derivative of the operation.

function Base.:+(x::BkwdDiffable, y::BkwdDiffable)
    rv = BkwdDiffable(x.u + y.u)
    function bkwd()
        x.dhdx += rv.dhdx
        y.dhdx += rv.dhdx
        return BkwdDiffable[x, y]
    end
    rv.back = bkwd
    return rv
end

function Base.:-(x::BkwdDiffable, y::BkwdDiffable)
    rv = BkwdDiffable(x.u - y.u)
    function bkwd()
        x.dhdx += rv.dhdx
        y.dhdx -= rv.dhdx
        return BkwdDiffable[x, y]
    end
    rv.back = bkwd
    return rv
end

function Base.:-(y::BkwdDiffable)
    rv = BkwdDiffable(-y.u)
    function bkwd()
        y.dhdx -= rv.dhdx
        return BkwdDiffable[y]
    end
    rv.back = bkwd
    return rv
end

function Base.:*(x::BkwdDiffable, y::BkwdDiffable)
    rv = BkwdDiffable(x.u * y.u)
    function bkwd()
        # product rule; for `z * z` both lines hit the same node, giving 2z
        x.dhdx += rv.dhdx * y.u
        y.dhdx += rv.dhdx * x.u
        return BkwdDiffable[x, y]
    end
    rv.back = bkwd
    return rv
end

function Base.:/(x::BkwdDiffable, y::BkwdDiffable)
    rv = BkwdDiffable(x.u / y.u)
    function bkwd()
        # quotient rule: ∂(x/y)/∂x = 1/y, ∂(x/y)/∂y = -x/y^2
        x.dhdx += rv.dhdx / y.u
        y.dhdx += -rv.dhdx * x.u / y.u^2
        return BkwdDiffable[x, y]
    end
    rv.back = bkwd
    return rv
end

function Base.exp(x::BkwdDiffable)
    rv = BkwdDiffable(exp(x.u))
    function bkwd()
        # exp is its own derivative: reuse the forward value instead of recomputing it
        x.dhdx += rv.dhdx * rv.u
        return BkwdDiffable[x]
    end
    rv.back = bkwd
    return rv
end

function Base.sin(x::BkwdDiffable)
    rv = BkwdDiffable(sin(x.u))
    function bkwd()
        x.dhdx += rv.dhdx * cos(x.u)
        return BkwdDiffable[x]
    end
    rv.back = bkwd
    return rv
end

function Base.cos(x::BkwdDiffable)
    rv = BkwdDiffable(cos(x.u))
    function bkwd()
        x.dhdx += -rv.dhdx * sin(x.u)
        return BkwdDiffable[x]
    end
    rv.back = bkwd
    return rv
end

"""
    reverse_autodiff(h::BkwdDiffable)

Back-propagate from the objective `h`. It seeds `h.dhdx = 1` and runs the backward
step of every node reachable from `h`. Afterwards each such node holds the derivative
of `h` with respect to itself in `dhdx`.

Nodes are processed newest-first, i.e. in decreasing `order`. A node is always
created after its inputs, so by the time a node is processed every consumer of it
has already added its contribution. Each backward step therefore runs exactly once,
with the node's final `dhdx`.

Only `h.dhdx` is (re)set. All other `dhdx` fields accumulate, so running this again
over nodes that already hold derivatives adds on top of them. Returns `nothing`.
"""
function reverse_autodiff(h::BkwdDiffable)
    h.dhdx = 1.0
    frontier = BkwdDiffable[h]
    while !isempty(frontier)
        # Take the most recently created node. Ordering by the value `u` instead
        # would not be topological and could propagate a node before all of its
        # consumers had contributed, double-counting derivatives.
        idx = argmax([node.order for node in frontier])
        newest = frontier[idx]
        deleteat!(frontier, idx)
        for input in newest.back()
            # dedupe by identity: an input used twice (e.g. `z * z`) is queued once
            any(queued -> queued === input, frontier) || push!(frontier, input)
        end
    end
    return nothing
end
Out[9]:
reverse_autodiff (generic function with 1 method)
In [10]:
# leaf node for the input x = 2
x = BkwdDiffable(2.0)
Out[10]:
BkwdDiffable(2.0, 0.0, 0, var"#1#2"())
In [11]:
# running the ordinary f on a node records its computation graph as it goes
fx = f(x)
Out[11]:
BkwdDiffable(2.5081257587058756, 0.0, 5, var"#bkwd#8"{BkwdDiffable,BkwdDiffable}(BkwdDiffable(0.9195357645382262, 0.0, 4, var"#bkwd#6"{BkwdDiffable,BkwdDiffable,BkwdDiffable}(BkwdDiffable(-0.9589242746631385, 0.0, 3, var"#bkwd#9"{BkwdDiffable,BkwdDiffable}(BkwdDiffable(5.0, 0.0, 2, var"#bkwd#3"{BkwdDiffable,BkwdDiffable,BkwdDiffable}(BkwdDiffable(2.0, 0.0, 0, var"#1#2"()), BkwdDiffable(3.0, 0.0, 1, var"#1#2"()), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(-0.9589242746631385, 0.0, 3, var"#bkwd#9"{BkwdDiffable,BkwdDiffable}(BkwdDiffable(5.0, 0.0, 2, var"#bkwd#3"{BkwdDiffable,BkwdDiffable,BkwdDiffable}(BkwdDiffable(2.0, 0.0, 0, var"#1#2"()), BkwdDiffable(3.0, 0.0, 1, var"#1#2"()), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#))), BkwdDiffable(#= circular reference @-2 =#)))
In [12]:
# forward value; same as f(2)
fx.u
Out[12]:
2.5081257587058756
In [13]:
# backward pass from fx: fills in dhdx on the nodes fx depends on
reverse_autodiff(fx)
In [14]:
# d fx / dx at x = 2; compare with dfdx(2)
x.dhdx
Out[14]:
-1.3644733615014137

This also makes it practical to compute gradients: a single backward pass yields the derivative of a scalar loss with respect to every input.

In [15]:
"""
    loss(w)

Scalar objective ``\\sin\\left(\\sum_i c_i e^{w_i}\\right)`` with fixed coefficients
`c = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]`; `w` is combined with `c` by broadcasting.
"""
function loss(w)
    coefficients = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]
    weighted = coefficients .* exp.(w)
    return sin(sum(weighted))
end
Out[15]:
loss (generic function with 1 method)
In [16]:
# eight separate leaf nodes, all with value 0 (the loop index is unused)
w = [BkwdDiffable(0.0) for i in 1:8]
Out[16]:
8-element Array{BkwdDiffable,1}:
  BkwdDiffable(0.0, 0.0, 6, var"#1#2"())
  BkwdDiffable(0.0, 0.0, 7, var"#1#2"())
  BkwdDiffable(0.0, 0.0, 8, var"#1#2"())
  BkwdDiffable(0.0, 0.0, 9, var"#1#2"())
 BkwdDiffable(0.0, 0.0, 10, var"#1#2"())
 BkwdDiffable(0.0, 0.0, 11, var"#1#2"())
 BkwdDiffable(0.0, 0.0, 12, var"#1#2"())
 BkwdDiffable(0.0, 0.0, 13, var"#1#2"())
In [17]:
# forward pass: builds the computation graph of the loss
lossw = loss(w);
In [18]:
# forward value of the loss at w = 0
lossw.u
Out[18]:
-0.34248061846961253
In [19]:
# same loss on plain Float64s, as a sanity check of the forward value
loss(zeros(8))
Out[19]:
-0.34248061846961253
In [20]:
# a single backward pass fills in dhdx for every entry of w
reverse_autodiff(lossw)
In [21]:
# w[i].dhdx now holds the partial derivative of the loss with respect to w[i]
w
Out[21]:
8-element Array{BkwdDiffable,1}:
                BkwdDiffable(0.0, 0.0, 6, var"#1#2"())
  BkwdDiffable(0.0, 0.469762446874128, 7, var"#1#2"())
  BkwdDiffable(0.0, 0.939524893748256, 8, var"#1#2"())
  BkwdDiffable(0.0, 1.879049787496512, 9, var"#1#2"())
 BkwdDiffable(0.0, 2.818574681244768, 10, var"#1#2"())
 BkwdDiffable(0.0, 3.288337128118896, 11, var"#1#2"())
 BkwdDiffable(0.0, 3.758099574993024, 12, var"#1#2"())
 BkwdDiffable(0.0, 4.227862021867152, 13, var"#1#2"())
In [22]:
# just check the 6th element numerically: forward difference (step 1e-5) along
# the 6th unit vector, to compare against w[6].dhdx in the next cell
e6 = zeros(8); e6[6] = 1.0;
(loss(zeros(8) + 1e-5 * e6) - loss(zeros(8))) / 1e-5
Out[22]:
3.288374546278616
In [23]:
# reverse-mode value of the same partial derivative
w[6].dhdx
Out[23]:
3.288337128118896
In [ ]: