# Forward-mode "dual number": a primal value paired with its derivative
# with respect to a single chosen input variable t.
# Subtyping Number lets Julia's promotion machinery mix Duals with reals.
struct Dual <: Number
# primal value f(t)
f::Float64
# derivative df/dt carried alongside the value
dfdt::Float64
end
# Mixed Dual/real arithmetic promotes to Dual; conversion goes through the
# one-argument Dual constructors (which set the derivative slot to 0.0).
Base.promote_rule(::Type{Dual},::Type{Float64}) = Dual
# Generalized from Int64 to any Integer (Int32, BigInt, ...) — conversion is
# already handled by the Dual(f::Integer) constructor.
Base.promote_rule(::Type{Dual},::Type{<:Integer}) = Dual
# Lift a plain Float64 to a constant Dual (zero derivative).
Dual(f::Float64) = Dual(f, 0.0)
# Lift an integer to a constant Dual (zero derivative).
Dual(f::Integer) = Dual(Float64(f), 0.0)
# Sum rule: (f + g)' = f' + g'.
Base.:+(x::Dual, y::Dual) = Dual(x.f + y.f, x.dfdt + y.dfdt)
# Difference rule: (f - g)' = f' - g'.
Base.:-(x::Dual, y::Dual) = Dual(x.f - y.f, x.dfdt - y.dfdt)
# Negation: (-f)' = -f'.
Base.:-(y::Dual) = Dual(-y.f, -y.dfdt)
# Product rule: (f * g)' = f' * g + f * g'.
Base.:*(x::Dual, y::Dual) = Dual(x.f * y.f, x.dfdt * y.f + x.f * y.dfdt)
# Quotient rule: (f / g)' = f' / g - f * g' / g^2.
Base.:/(x::Dual, y::Dual) = Dual(x.f / y.f, x.dfdt / y.f - x.f * y.dfdt / y.f^2)
# Chain rule for exp: (exp ∘ f)' = exp(f) * f'.
Base.exp(x::Dual) = Dual(exp(x.f), exp(x.f) * x.dfdt)
# Chain rule for sin: (sin ∘ f)' = cos(f) * f'.
Base.sin(x::Dual) = Dual(sin(x.f), cos(x.f) * x.dfdt)
# Chain rule for cos: (cos ∘ f)' = -sin(f) * f'.
Base.cos(x::Dual) = Dual(cos(x.f), -sin(x.f) * x.dfdt)
Consider the function
$$f(x) = \exp(\sin(x+3)^2).$$
The chain-rule derivative of this is
$$f'(x) = \exp(\sin(x+3)^2) \cdot 2 \sin(x+3) \cdot \cos(x+3).$$
function f(x)
# substitute z = sin(x + 3), so f = exp(z^2)
z = sin(x + 3)
return exp(z * z);
end
# Analytic derivative of f(x) = exp(sin(x + 3)^2), by the chain rule:
# f'(x) = exp(sin(x + 3)^2) * 2 sin(x + 3) cos(x + 3).
function dfdx(x)
    s = sin(x + 3)
    return exp(s^2) * 2 * s * cos(x + 3)
end
# evaluate f at x = 2
f(2)
# numerical derivative
(f(2 + 1e-5) - f(2))/1e-5
# analytic derivative (chain rule), for comparison
dfdx(2)
# forward-mode AD: seed dfdt = 1.0; the result's dfdt field is f'(2)
f(Dual(2.0,1.0))
# Apply sin element-wise to a fixed set of multiples of x.
# Because it is written with broadcasting, it also works when x is a Dual.
function g(x)
    coeffs = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]
    return sin.(coeffs .* x)
end
# forward-mode AD through a vector-valued function: each output element
# carries its own derivative with respect to x
g(Dual(2.0,1.0))
The "dual number" representation computed the derivative automatically! And because it broadcasts element-wise, we can even do it all at once for vector-valued functions.
Next: a deliberately naive, but simple, reverse-mode autodiff.
# A node in a dynamically built compute graph for reverse-mode autodiff.
# Mutable because the adjoint dhdx is accumulated in place during the
# backward sweep, and `back` is attached after construction.
mutable struct BkwdDiffable <: Number
# primal value of this node
u::Float64
# accumulated adjoint: derivative of the final objective h w.r.t. this node
dhdx::Float64
# creation-order index from the global counter; later-created nodes are newer
order::Int64
# backward rule: propagates this node's adjoint to its inputs and returns them
back::Function
end
# Mixed real/BkwdDiffable arithmetic promotes to BkwdDiffable; conversion
# goes through the one-argument constructors (constant node, zero adjoint).
Base.promote_rule(::Type{BkwdDiffable},::Type{Float64}) = BkwdDiffable
# Generalized from Int64 to any Integer — conversion is already handled by
# the BkwdDiffable(x::Integer) constructor.
Base.promote_rule(::Type{BkwdDiffable},::Type{<:Integer}) = BkwdDiffable
# Mutable global counter that stamps each new node with its creation order.
# NOTE(review): global mutable state — fine for a demo, not thread-safe.
global bkwd_diff_node_order = 0
# Build a leaf (constant/input) node: value x, zero adjoint, no parents.
# Stamps the node with the next creation-order index from the global counter.
function BkwdDiffable(x::Float64)
global bkwd_diff_node_order
order = bkwd_diff_node_order
bkwd_diff_node_order += 1;
# default backward rule: nothing to propagate, no parent nodes
return BkwdDiffable(x,0.0,order,() -> BkwdDiffable[])
end
# Lift an integer input to a Float64-valued leaf node.
BkwdDiffable(x::Integer) = BkwdDiffable(Float64(x))
# Record x + y in the graph. Addition passes the adjoint through unchanged
# to both operands.
function Base.:+(x::BkwdDiffable, y::BkwdDiffable)
    node = BkwdDiffable(x.u + y.u)
    node.back = function ()
        x.dhdx += node.dhdx
        y.dhdx += node.dhdx
        return BkwdDiffable[x, y]
    end
    return node
end
# Record x - y. The adjoint flows to x with sign +1 and to y with sign -1.
function Base.:-(x::BkwdDiffable, y::BkwdDiffable)
    node = BkwdDiffable(x.u - y.u)
    node.back = function ()
        x.dhdx += node.dhdx
        y.dhdx -= node.dhdx
        return BkwdDiffable[x, y]
    end
    return node
end
# Record -y. The adjoint is negated on the way back.
function Base.:-(y::BkwdDiffable)
    node = BkwdDiffable(-y.u)
    node.back = function ()
        y.dhdx -= node.dhdx
        return BkwdDiffable[y]
    end
    return node
end
# Record x * y. Product rule: each operand receives the adjoint scaled by
# the other operand's value.
function Base.:*(x::BkwdDiffable, y::BkwdDiffable)
    node = BkwdDiffable(x.u * y.u)
    node.back = function ()
        x.dhdx += node.dhdx * y.u
        y.dhdx += node.dhdx * x.u
        return BkwdDiffable[x, y]
    end
    return node
end
# Record x / y. Quotient rule: d/dx = 1/y, d/dy = -x/y^2.
function Base.:/(x::BkwdDiffable, y::BkwdDiffable)
    node = BkwdDiffable(x.u / y.u)
    node.back = function ()
        x.dhdx += node.dhdx / y.u
        y.dhdx += -node.dhdx * x.u / y.u^2
        return BkwdDiffable[x, y]
    end
    return node
end
# Record exp(x). Chain rule: the adjoint is scaled by exp(x.u).
function Base.exp(x::BkwdDiffable)
    node = BkwdDiffable(exp(x.u))
    node.back = function ()
        x.dhdx += node.dhdx * exp(x.u)
        return BkwdDiffable[x]
    end
    return node
end
# Record sin(x). Chain rule: the adjoint is scaled by cos(x.u).
function Base.sin(x::BkwdDiffable)
    node = BkwdDiffable(sin(x.u))
    node.back = function ()
        x.dhdx += node.dhdx * cos(x.u)
        return BkwdDiffable[x]
    end
    return node
end
# Record cos(x). Chain rule: the adjoint is scaled by -sin(x.u).
function Base.cos(x::BkwdDiffable)
    node = BkwdDiffable(cos(x.u))
    node.back = function ()
        x.dhdx += -node.dhdx * sin(x.u)
        return BkwdDiffable[x]
    end
    return node
end
# Backward sweep: given the objective node h, accumulate dh/d(node) into the
# dhdx field of every node in the compute graph that produced h.
function reverse_autodiff(h::BkwdDiffable)
    # seed: the derivative of the objective with respect to itself is 1
    h.dhdx = 1.0
    frontier = BkwdDiffable[h]
    while !isempty(frontier)
        # Process nodes in reverse creation order, so a node's backward rule
        # only runs after all nodes that consume it have been processed and
        # its adjoint is fully accumulated. The sort key must be the creation
        # counter `order` — sorting by the value `u` (as before) picks an
        # arbitrary node and can fire a backward rule too early, producing
        # wrong gradients on general graphs.
        sort!(frontier; by = (x) -> x.order)
        newest = pop!(frontier)
        # run the backward step and enqueue its inputs (union! deduplicates)
        to_add_to_frontier = newest.back()
        union!(frontier, to_add_to_frontier)
    end
end
# forward pass: build the compute graph for f at x = 2
x = BkwdDiffable(2.0)
fx = f(x)
# primal value, f(2)
fx.u
# backward pass: accumulates df/dx into x.dhdx
reverse_autodiff(fx)
# should match dfdx(2) and the numerical estimate above
x.dhdx
This also makes it feasible to compute gradients!
# Toy scalar loss: a fixed-coefficient weighted sum of exp.(w), passed
# through sin. Works on Float64 vectors and on BkwdDiffable vectors alike.
function loss(w)
    coeffs = [0, 0.5, 1.0, 2.0, 3.0, 3.5, 4.0, 4.5]
    return sin(sum(coeffs .* exp.(w)))
end
# one graph node per weight, all initialized to 0.0
w = [BkwdDiffable(0.0) for i in 1:8]
lossw = loss(w);
# primal value of the loss at w = 0
lossw.u
# sanity check against the plain-Float64 evaluation
loss(zeros(8))
# backward pass: fills in w[i].dhdx = d(loss)/d(w[i]) for every i
reverse_autodiff(lossw)
w
# just check the 6th element numerically
e6 = zeros(8); e6[6] = 1.0;
(loss(zeros(8) + 1e-5 * e6) - loss(zeros(8))) / 1e-5
w[6].dhdx