A Curious Learning Task

Suppose we want to fit the following dataset.

In [4]:
using Plots
using LinearAlgebra
using Statistics
using Interact

# Select the GR backend for Plots.
gr()
┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1278
┌ Info: Precompiling Interact [c601a237-2ae4-5e1e-952c-7a85b0c7eef1]
└ @ Base loading.jl:1278

Unable to load WebIO. Please make sure WebIO works for your Jupyter client. For troubleshooting, please see the WebIO/IJulia documentation.

Out[4]:
Plots.GRBackend()
In [5]:
"""
    gen_data(n::Integer)

Draw `n` noisy samples of the target curve `cos(2x + 20x^2)`.

Inputs are drawn with `rand` (uniform on `[0, 1)`), and each output has Gaussian noise
with standard deviation `0.1` added.

# Arguments
- `n`: number of samples to draw; any `Integer` type is accepted.

# Returns
A tuple `(xs, ys)` of two `Vector{Float64}`, each of length `n`.
"""
function gen_data(n::Integer)
    xs = rand(n)
    noise = 0.1 * randn(n)
    ys = @. cos(2 * xs + 20 * xs^2) + noise
    return (xs, ys)
end

# Draw independent training and test samples of the same size.
n = 256
xs, ys = gen_data(n)
xs_test, ys_test = gen_data(n);
In [6]:
# Show both samples on one set of axes; the last call yields the plot that is displayed.
scatter(xs, ys, label="training data")
scatter!(xs_test, ys_test, label="test data")
Out[6]:

How will a linear model perform on this task?

Let's look at linear regression.

$$\min_{a,b} \sum_{i=1}^n (ax_i + b - y_i)^2.$$
In [7]:
# Stack the inputs with a row of ones so the intercept b is fitted together with the
# slope a. The row length is taken from `xs` itself rather than from the global `n`.
Xs_homogeneous = vcat(xs', ones(1, length(xs)));

# Solve the normal equations (X X') [a, b] = X y with `\` instead of forming an explicit
# inverse, which does more work and loses accuracy when X X' is poorly conditioned.
(a, b) = (Xs_homogeneous * Xs_homogeneous') \ (Xs_homogeneous * ys);
In [8]:
# Draw the fitted line over the training data, evaluated at the sorted inputs so the
# line is traced from left to right.
scatter(xs, ys, label="training data", title="Linear Regression")
plot!(
    sort(xs),
    a * sort(xs) .+ b;
    label="linreg (a=$(round(a;digits=3)), b=$(round(b;digits=3)))",
)
Out[8]:
In [9]:
# Mean squared error of the fitted line on the training sample and on the test sample.

training_loss = mean((a * x + b - y)^2 for (x, y) in zip(xs, ys))
test_loss = mean((a * x + b - y)^2 for (x, y) in zip(xs_test, ys_test))

println("training loss = $training_loss")
println("test loss = $test_loss")
training loss = 0.5132400881269441
test loss = 0.4925173858425717

Can we do better?

One way to do this is to use a more sophisticated model. One such model is the piecewise linear model.

In [10]:
"""
    piecewise_linear_eval(x, zs)

Evaluate at `x` the piecewise linear function that interpolates the keypoints `zs`.

# Arguments
- `x`: point at which to evaluate.
- `zs`: keypoints as `(x, y)` tuples. Segments are formed between consecutive entries,
  and the first segment whose closed x-interval contains `x` is used, so the keypoints
  are assumed to be ordered by x.

# Returns
The interpolated value, or `0.0` when no segment contains `x` (which includes the case
where `zs` has fewer than two keypoints).
"""
function piecewise_linear_eval(x::Real, zs::AbstractVector{<:Tuple{Real,Real}})
    # Stop one short of the last keypoint: each iteration reads zs[i] and zs[i + 1].
    for i in firstindex(zs):(lastindex(zs) - 1)
        (x0, y0) = zs[i]
        (x1, y1) = zs[i + 1]
        if x0 <= x <= x1
            # Two keypoints sharing an x form a zero-width segment; return its first
            # value instead of computing 0/0.
            x1 == x0 && return y0
            return ((x - x0) / (x1 - x0)) * (y1 - y0) + y0
        end
    end
    return 0.0
end

"""
    piecewise_linear_widget(xs, ys)

Build an interactive widget for fitting a piecewise linear model to `(xs, ys)` by hand.

Six keypoints are controlled by x/y sliders; two more are fixed at `(0.0, 1.0)` and
`(1.0, -1.0)`. The keypoints, the mean squared error over the data, and the plot are all
derived from the slider values through `Interact.@map`.

# Arguments
- `xs`, `ys`: data coordinates; they must have matching indices.

# Returns
The value of the final `@layout!` call on the widget.
"""
function piecewise_linear_widget(xs::Array{Float64,1}, ys::Array{Float64,1})
    # Initial slider positions: fixed x values, random y values sorted in decreasing order.
    xinits = [0.36, 0.5, 0.62, 0.74, 0.82, 0.95];
    yinits = sort(rand(6); rev=true);
    x1 = slider(0.0:0.01:1.0, label="x1", value=xinits[1]);
    y1 = slider(-1.2:0.01:1.2, label="y1", value=yinits[1]);
    x2 = slider(0.0:0.01:1.0, label="x2", value=xinits[2]);
    y2 = slider(-1.2:0.01:1.2, label="y2", value=yinits[2]);
    x3 = slider(0.0:0.01:1.0, label="x3", value=xinits[3]);
    y3 = slider(-1.2:0.01:1.2, label="y3", value=yinits[3]);
    x4 = slider(0.0:0.01:1.0, label="x4", value=xinits[4]);
    y4 = slider(-1.2:0.01:1.2, label="y4", value=yinits[4]);
    x5 = slider(0.0:0.01:1.0, label="x5", value=xinits[5]);
    y5 = slider(-1.2:0.01:1.2, label="y5", value=yinits[5]);
    x6 = slider(0.0:0.01:1.0, label="x6", value=xinits[6]);
    y6 = slider(-1.2:0.01:1.2, label="y6", value=yinits[6]);
    # Sort the keypoints so consecutive entries form the segments that
    # piecewise_linear_eval walks through.
    keypoints = Interact.@map sort([(0.0,1.0), (&x1,&y1), (&x2,&y2), (&x3,&y3), (&x4,&y4), (&x5,&y5), (&x6,&y6), (1.0,-1.0)]);
    # Iterate over the arguments' own indices rather than the global `n`, so the error
    # always covers exactly the data passed in.
    err = Interact.@map mean((piecewise_linear_eval(xs[i], &keypoints) - ys[i])^2 for i in eachindex(xs, ys));
    plt = Interact.@map begin
        scatter(xs, ys; label="training data", title="Piecewise Linear Model (err=$(round(&err; digits=4)))", xlim=(0.0,1.0), ylim=(-1.3,1.3), legend=:bottomleft);
        plot!([z[1] for z in &keypoints], [z[2] for z in &keypoints]; label="piecewise linear", linewidth = 2, markershape = :square, markersize=5)
    end
    wdg = Widget([
                "x1" => x1, "y1" => y1,
                "x2" => x2, "y2" => y2,
                "x3" => x3, "y3" => y3,
                "x4" => x4, "y4" => y4,
                "x5" => x5, "y5" => y5,
                "x6" => x6, "y6" => y6], output = plt)
    @layout! wdg hbox(plt, vbox(hbox(:x1, :y1), hbox(:x2, :y2), hbox(:x3, :y3), hbox(:x4, :y4), hbox(:x5, :y5), hbox(:x6, :y6))) ## custom layout: by default things are stacked vertically
end
Out[10]:
piecewise_linear_widget (generic function with 1 method)
In [11]:
# Build the widget for the training data; as the last expression it is the cell's output.
piecewise_linear_widget(xs, ys)
Out[11]:

This error is MUCH lower than what we got from the linear regression model!

But how can we learn this?

Problem: the way we have parameterized this model is not continuous in its parameters, so we cannot fit it directly with a gradient-based method!

To solve this problem, note that we can always represent a piecewise linear function as a sum of shifted and scaled ReLU functions. The ReLU function (REctified Linear Unit) is defined as

$$\operatorname{ReLU}(x) = \begin{cases}x & \text{if } x \ge 0 \\ 0 & \text{if } x < 0\end{cases} = \max(x, 0).$$

We can visualize this as follows.

In [12]:
"""
    ReLU(x::Real)

Rectified linear unit: return `x` when it is positive and zero otherwise.

The zero is taken with `zero(x)`, so the result has the same type as the argument.
"""
function ReLU(x::Real)
    return max(x, zero(x))
end

# Plot ReLU on a grid over [-1, 1] with spacing 0.01.
us = collect(-1.0:0.01:1.0)
plot(us, map(ReLU, us), linewidth=3, label="", title="ReLU Function")
Out[12]:
In [13]:
# Illustration: add up five ReLUs, each with a random sign, random slope and random shift.

us = collect(-1.0:0.01:1.0)
plot(
    us,
    sum(rand([-1.0, 1.0]) * ReLU.(randn() * us .+ randn()) for _ in 1:5);
    linewidth=3,
    label="",
    title="Sum of Random ReLUs",
)
Out[13]:

We can try to parameterize our model as the sum of ReLU functions as follows.

$$h_{a,b,w}(x) = w_1 \cdot \operatorname{ReLU}(a_1 \cdot x + b_1) + w_2 \cdot \operatorname{ReLU}(a_2 \cdot x + b_2) + \cdots = \sum_{i=1}^d w_i \cdot \operatorname{ReLU}(a_i \cdot x + b_i).$$

This is guaranteed to be continuous in the parameters $a, b, w \in \mathbb{R}^d$. (Why?)

We can train this using SGD. To compute the gradient with respect to a loss function, observe that

$$\frac{\partial}{\partial w_i} \frac{1}{2} \left( h_{a,b,w}(x) - y \right)^2 = \left( h_{a,b,w}(x) - y \right) \cdot \operatorname{ReLU}(a_i \cdot x + b_i)$$

$$\frac{\partial}{\partial a_i} \frac{1}{2} \left( h_{a,b,w}(x) - y \right)^2 = \left( h_{a,b,w}(x) - y \right) \cdot w_i \cdot \operatorname{ReLU}'(a_i \cdot x + b_i) \cdot x$$

$$\frac{\partial}{\partial b_i} \frac{1}{2} \left( h_{a,b,w}(x) - y \right)^2 = \left( h_{a,b,w}(x) - y \right) \cdot w_i \cdot \operatorname{ReLU}'(a_i \cdot x + b_i)$$
In [14]:
"""
    dReLU(x::Float64)

Derivative used for ReLU: `1.0` when `x` is strictly positive and `0.0` otherwise, so the
value at `x == 0.0` is `0.0`.
"""
function dReLU(x::Float64)
    if x > 0.0
        return 1.0
    end
    return 0.0
end

"""
    heval(x, a, b, w)

Evaluate the sum-of-ReLUs model `sum(w[i] * ReLU(a[i] * x + b[i]))` at the scalar `x`.

`a`, `b` and `w` hold the per-unit slopes, offsets and output weights; the assertions
check that all three have the same length.
"""
function heval(x::Float64, a::Array{Float64,1}, b::Array{Float64,1}, w::Array{Float64,1})
    d = length(a)
    @assert(length(b) == d)
    @assert(length(w) == d)

    # Activation of every unit at x, then the weighted sum over units.
    hidden = ReLU.(a .* x .+ b)
    return sum(w .* hidden)
end

"""
    grad(x, y, a, b, w)

Gradient of the per-example loss `(h(x) - y)^2 / 2` for the sum-of-ReLUs model, taken
with respect to the slopes `a`, the offsets `b` and the output weights `w`.

# Returns
A three-element vector `[da, db, dw]` whose entries each have the same length as `a`.
"""
function grad(x::Float64, y::Float64, a::Array{Float64,1}, b::Array{Float64,1}, w::Array{Float64,1})
    d = length(a)
    @assert(length(b) == d)
    @assert(length(w) == d)

    # Pre-activations of all units and the residual of the model's prediction.
    axb = a .* x .+ b
    dh = sum(w .* ReLU.(axb)) - y

    # The slope gradient is the offset gradient scaled by x, so build it from db.
    db = dh .* w .* dReLU.(axb)
    da = db .* x
    dw = dh .* ReLU.(axb)

    return [da, db, dw]
end
Out[14]:
grad (generic function with 1 method)
In [15]:
# Model width and initial parameters: slopes are random signs, offsets are uniform on
# [-1, 1), and output weights are standard normal draws.
d = 1024
a = rand([1.0, -1.0], d)
b = 2 * rand(d) .- 1
w = randn(d)

# Step size that scales each gradient in the SGD loop.
alpha = 1e-4;
In [16]:
# Plot the untrained model on a fine grid over [0, 1] against the training data.

us = collect(0.0:0.001:1.0)
scatter(xs, ys, label="training data", title="Piecewise Linear Model")
plot!(us, map(u -> heval(u, a, b, w), us), label="piecewise linear")
Out[16]:
In [17]:
# Train with SGD: every step uses the gradient at one randomly chosen training example.

for _ in 1:10000
    # Rebind the top-level parameters explicitly. Without this declaration the loop only
    # updates them in interactive sessions; run as a script, the assignments would
    # create loop-local variables and the first read of `a` would fail.
    global a, b, w
    i = rand(1:n)
    (da, db, dw) = grad(xs[i], ys[i], a, b, w)
    a = a .- alpha .* da
    b = b .- alpha .* db
    w = w .- alpha .* dw
end

# Mean squared error on each sample, iterating over the data itself rather than 1:n.
err = mean((heval(x, a, b, w) - y)^2 for (x, y) in zip(xs, ys));
test_err = mean((heval(x, a, b, w) - y)^2 for (x, y) in zip(xs_test, ys_test));
us = collect(0.0:0.001:1.0);
scatter(xs, ys; label="training data", title="Piecewise Linear Model (err=$(round(err;digits=4)), test=$(round(test_err;digits=4)))");
plot!(us, [heval(u,a,b,w) for u in us]; label="piecewise linear")
Out[17]: