Lecture 26: Alternatives to Autoregressive ML¶

CS4787/5777 — Principles of Large-Scale Machine Learning Systems¶

$\newcommand{\R}{\mathbb{R}}$

Recall: "classical" machine learning

  • Example-label pairs $(x,y)$
  • Example usually contains more information than the label
    • e.g. we might process a $1024 \times 1024 \times 3$ image only to produce a single label with $1000$ discrete options
  • This is not bad, but it would be nice to have more signal for learning

Modern causal autoregressive learning

  • Example is a sequence $(x_1, x_2, \ldots, x_n)$
  • Task is to predict $x_k$ given $(x_1, \ldots, x_{k-1})$ for all $k \in \{1,\ldots,n\}$
  • For probabilistic sequence modeling with $x_i \in V$ for discrete set $V$ (usually a vocabulary), this results in a density of the form $$\mathbf{P}(x_1, x_2, \ldots, x_n) = \prod_{k = 1}^n p_{w}(x_k \mid x_1, x_2, \ldots, x_{k-1})$$ where $p_w$ denotes the learned model as a function of the weights $w$.
  • Minimize the negative-log-likelihood of the data (cross-entropy loss).
  • Beneficial consequence: example and "labels" are the same
    • they contain the same amount of information
    • gives us a lot of "signal" for learning
    • e.g. if we process a sequence of length $1024$ then there are $1024$ things we are trying to predict
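To make that last point concrete, here is a minimal sketch of the causal autoregressive training loss, assuming a hypothetical `model` that maps a `(batch, k)` tensor of tokens to `(batch, k, vocab)` next-token logits (this interface is an assumption for illustration, not any particular library's API):

```python
import torch
import torch.nn.functional as F

def autoregressive_nll(model, tokens):
    # `model` is a hypothetical stand-in: logits at position k predict token k+1.
    # The "labels" are just the input sequence shifted by one, so every position
    # with a nonempty prefix contributes learning signal.
    logits = model(tokens[:, :-1])      # (batch, n-1, vocab)
    targets = tokens[:, 1:]             # (batch, n-1)
    # cross-entropy (negative log-likelihood) averaged over all predicted positions
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
```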

Drawback of causal autoregressive modeling

  • Sampling latency is high due to high memory bandwidth demands (see the "Memory Ops" breakdown in https://arxiv.org/pdf/2403.14123)
  • To sample a sequence of length $n$, we need to query the model $n$ times, and each query requires loading all of the weights $w$ from memory: $$\text{total mem bw required (not incl. activations)} = n \cdot d_{\text{model}} \cdot (\operatorname{bpw})$$ where $d_{\text{model}}$ is the number of model parameters (the "model size") and $\operatorname{bpw}$ is the bytes per weight of the (possibly compressed) model.

  • This becomes intractable if $n$ is very large.

  • Imagine we tried to do this to predict the pixels in a $1024 \times 1024$ image! Two big problems:

    • Attention is quadratic in $n$, and here $n$ would be $1024^2$.
    • The memory bandwidth required just to load the weights would be huge: $1024^2 \cdot d_{\text{model}} \cdot (\operatorname{bpw})$ bytes.
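To put rough numbers on these two problems, here is a back-of-the-envelope calculation with a made-up 7-billion-parameter model stored at 2 bytes per weight (the specific numbers are illustrative assumptions):

```python
n_params = 7e9        # number of model parameters (the "model size" d_model above)
bpw = 2               # bytes per weight, e.g. fp16/bf16
n = 1024 * 1024       # one token per pixel of a 1024 x 1024 image

weight_traffic = n * n_params * bpw              # bytes loaded just for the weights
print(f"weight traffic: {weight_traffic / 1e15:.0f} PB")     # ~15 petabytes
attention_pairs = n ** 2                         # pairwise attention scores per layer/head
print(f"attention pairs per layer: {attention_pairs:.2e}")   # ~1.1e12
```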

One Alternative: Diffusion Models¶

Motivating idea. Instead of modeling our distribution autoregressively, let's do the following

$$\mathbf{P}(x_1, x_2, \ldots, x_n) = \prod_{k = 1}^n p_{w}(x_k).$$

Now we can sample all the $x_i$ in parallel, since $x_k$ no longer depends on $x_{<k}$. This means we only need one load of the model weights to do this!

  • Memory bandwidth would just be $d_{\text{model}} \cdot (\operatorname{bpw})$

What's the downside of doing this? (Every position is now modeled as independent of every other, so the model cannot capture any of the dependencies that make the sequence coherent.)

Okay, so obviously we can't do that...¶

What if we did something slightly more clever, by augmenting our state space with an extra variable $z_k$ (same shape as $x_k$) which represents some "hint" or "partial" information about what might be at position $k$?

$$\mathbf{P}(x_1, x_2, \ldots, x_n) = \sum_{z_1,\ldots,z_n} \prod_{k = 1}^n p_{w}(x_k \mid z_1, \ldots, z_n) \cdot p_{w}(z_k).$$

Now the sampling process has two steps:

  • Draw $z_1, \ldots, z_n$ in parallel by querying the model
  • Draw $x_1, \ldots, x_n$ in parallel conditioned on $(z_1, \ldots, z_n)$
  • Memory bandwidth is now $2 \cdot d_{\text{model}} \cdot (\operatorname{bpw})$

We can do this across multiple steps too. Augmenting the space with $z_{k}^{(t)}$ for $k \in \{1,\ldots,n\}$ and $t \in \{1,\ldots,T-1\}$,

$$\mathbf{P}(x_1, x_2, \ldots, x_n) = \sum_{z} \prod_{k = 1}^n p_{w,T}(x_k \mid z_1^{(T-1)}, \ldots, z_n^{(T-1)}) \cdot p_{w,T-1}(z_k^{(T-1)} \mid z_1^{(T-2)}, \ldots, z_n^{(T-2)}) \cdots p_{w,2}(z_k^{(2)} \mid z_1^{(1)}, \ldots, z_n^{(1)}) \cdot p_{w,1}(z_k^{(1)}).$$

Now the sampling process has $T$ steps:

  • Draw $z_1^{(1)}, \ldots, z_n^{(1)}$ in parallel by querying the model
  • Draw $z_1^{(2)}, \ldots, z_n^{(2)}$ in parallel by querying the model, conditioned on $(z_1^{(1)}, \ldots, z_n^{(1)})$
  • Now we can discard $(z_1^{(1)}, \ldots, z_n^{(1)})$, since we don't need them anymore
  • Continue by drawing $z_1^{(3)}, \ldots, z_n^{(3)}$ in parallel by querying the model, conditioned on $(z_1^{(2)}, \ldots, z_n^{(2)})$
  • Repeat, eventually producing $z_1^{(T-1)}, \ldots, z_n^{(T-1)}$
  • Draw $x_1, \ldots, x_n$ in parallel conditioned on $(z_1^{(T-1)}, \ldots, z_n^{(T-1)})$
  • Memory bandwidth is now $T \cdot d_{\text{model}} \cdot (\operatorname{bpw})$
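A minimal sketch of this $T$-step sampling loop, assuming a hypothetical model interface in which `model(z, t)` returns per-position logits of shape `(n, vocab)` for $p_{w,t}(\cdot \mid z^{(t-1)})$ and `model(None, 1)` returns the logits of the unconditional first step $p_{w,1}$:

```python
import torch

def sample(model, T):
    # `model` is the hypothetical interface described above
    z = None
    for t in range(1, T + 1):
        logits = model(z, t)      # one full load of the weights per step
        # draw all n positions in parallel; the previous z^{(t-1)} is overwritten
        # here, since we no longer need it
        z = torch.distributions.Categorical(logits=logits).sample()
    return z    # the final draw is x_1, ..., x_n
```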

But this is all a little awkward. Some questions we'd need to answer:

  • Where do we get $z$ from for training? (if we need a value for $z$)
  • How do we minimize the loss efficiently? (we can't just sum over exponentially many values of $z$ to minimize the log-likelihood)

The ELBO¶

Very powerful tool in Bayesian machine learning. Let $X$ and $Z$ be any random variables, let $p$ be a joint distribution of $X$ and $Z$, and let $q$ be a (possibly different) distribution of $Z$ conditioned on $X$. Then the ELBO ("evidence lower bound") for an example $x$ is defined as

$$\operatorname{ELBO}(x; p, q) = \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log \frac{p(x,z)}{q(z|x)} \right].$$

It's a standard result that

$$\log p(x) \ge \operatorname{ELBO}(x; p, q).$$

Why? Observe that

\begin{align*} \log p(x) &= \log\left( \sum_{z} p(x,z) \right) \\&= \log\left( \sum_{z} q(z|x) \cdot \frac{p(x,z)}{q(z|x)} \right) \\&= \log\left( \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \frac{p(x,z)}{q(z|x)} \right] \right) \\&\ge \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( \frac{p(x,z)}{q(z|x)} \right) \right] \end{align*}

where this last step follows from Jensen's inequality. Also, obviously the NELBO ("negative ELBO")

$$-\log p(x) \le -\operatorname{ELBO}(x; p, q) = \operatorname{NELBO}(x; p, q)$$

is an upper bound on the loss we'd like to minimize. Main idea: minimize this upper bound!
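A quick numeric sanity check of the bound, using a made-up discrete joint distribution $p(x,z)$ and a made-up $q(z \mid x)$ (all values here are random illustrations):

```python
import numpy as np

rng = np.random.default_rng(0)
p_joint = rng.random((2, 3)); p_joint /= p_joint.sum()       # p(x, z): 2 values of x, 3 of z
q = rng.random((2, 3)); q /= q.sum(axis=1, keepdims=True)    # q(z | x)

x = 0
log_px = np.log(p_joint[x].sum())                            # exact evidence log p(x)
elbo = np.sum(q[x] * (np.log(p_joint[x]) - np.log(q[x])))    # E_{z~q(.|x)}[log p(x,z)/q(z|x)]
print(log_px, elbo)   # ELBO <= log p(x), with equality iff q(z|x) matches p(z|x)
assert log_px >= elbo
```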

Training with ELBO Recipe¶

When we want to minimize the negative log likelihood of some data $X$ but our model gives us a joint distribution of $X$ and $Z$ rather than just a distribution over $X$, we do the following:

  • Choose some conditional distribution $q(Z | X)$ — independent of the model weights
  • Draw $z$ from that $q$, given $x$
  • Let the loss of this example be $-\log \frac{p_w(x,z)}{q(z|x)}$ (an unbiased estimate of the NELBO)
  • Backprop and run gradient descent normally
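One training step of this recipe might look like the following minimal sketch, where `model.log_joint(x, z)` (returning $\log p_w(x,z)$), `q_sample`, and `q_log_prob` are hypothetical interfaces, not any particular library's API:

```python
import torch

def elbo_training_step(model, optimizer, x):
    # q_sample and q_log_prob are hypothetical stand-ins for a fixed choice of q
    z = q_sample(x)                                       # draw z ~ q(. | x); q is fixed, so no gradient flows through it
    loss = -(model.log_joint(x, z) - q_log_prob(z, x))    # single-sample NELBO estimate
    # since q does not depend on w, the q term only shifts the loss by a constant
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```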

Many problems still!

  • We blew up the state by a factor of $T$
  • Now the model is trying to predict these random samples $z$ which might not have much information about $x$
  • Just kicked the design can down the road: now we still need to pick $q$

We can fix a few of these problems by re-writing the loss¶

If we let $x$ denote $x_1, \ldots, x_n$ and $z^{(t)}$ denote $z^{(t)}_1, \ldots, z^{(t)}_n$, and we require that $q$ have a Markov-chain structure,

\begin{align*} \operatorname{ELBO}(x) &= \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log \frac{p(x,z)}{q(z|x)} \right] \\&= \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( \frac{ p_{w,T}(x \mid z^{(T-1)}) \cdot p_{w,T-1}(z^{(T-1)} \mid z^{(T-2)}) \cdots p_{w,2}(z^{(2)} \mid z^{(1)}) \cdot p_{w,1}(z^{(1)}) }{ q(z^{(T-1)} \mid z^{(T-2)}, x) \cdots q(z^{(2)} \mid z^{(1)}, x) \cdot q(z^{(1)} \mid x) }\right) \right] \\&= \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( p_{w,T}(x \mid z^{(T-1)}) \right) + \log\left( \frac{ p_{w,T-1}(z^{(T-1)} \mid z^{(T-2)}) }{ q(z^{(T-1)} \mid z^{(T-2)}, x) }\right) + \cdots + \log\left( \frac{ p_{w,1}(z^{(1)}) }{ q(z^{(1)} \mid x) }\right) \right] \\&= \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( p_{w,T}(x \mid z^{(T-1)}) \right) \right] + \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( \frac{ p_{w,T-1}(z^{(T-1)} \mid z^{(T-2)}) }{ q(z^{(T-1)} \mid z^{(T-2)}, x) }\right) \right] + \cdots + \mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( \frac{ p_{w,1}(z^{(1)}) }{ q(z^{(1)} \mid x) }\right) \right] \end{align*}

Key Trick: Evaluate only one of the terms of this sum, using subsampling (our old trick from the first lecture). A random term of this sum looks like:

$$\mathbf{E}_{z \sim q(\cdot \mid x)}\left[ \log\left( \frac{ p_{w,t}(z^{(t)} \mid z^{(t-1)}) }{ q(z^{(t)} \mid z^{(t-1)}, x) }\right) \right]$$

and we can rewrite this as

$$\mathbf{E}_{z^{(t-1)} \sim q(z^{(t-1)} \mid x)}\left[ \sum_{z^{(t)}} q(z^{(t)} \mid z^{(t-1)}, x) \log\left( \frac{ p_{w,t}(z^{(t)} \mid z^{(t-1)}) }{ q(z^{(t)} \mid z^{(t-1)}, x) }\right) \right]$$
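In code, evaluating one randomly chosen term might look like the following sketch, where `q_sample_marginal`, `q_step_probs`, and `model.log_probs` are hypothetical helpers (the first draws $z^{(t-1)} \sim q(\cdot \mid x)$, the second returns the vector $q(z^{(t)} \mid z^{(t-1)}, x)$, and the third returns $\log p_{w,t}(z^{(t)} \mid z^{(t-1)})$ for every value of $z^{(t)}$). For clarity it treats $z^{(t)}$ as a single discrete variable; with a per-position factorized $q$, the same sum is simply taken independently at each of the $n$ positions.

```python
import torch

def subsampled_elbo_term(model, x, T):
    # q_sample_marginal, q_step_probs, and model.log_probs are hypothetical helpers
    t = torch.randint(2, T, (1,)).item()        # pick a random "interior" timestep
    z_prev = q_sample_marginal(x, t - 1)        # z^{(t-1)} ~ q(. | x)
    q_probs = q_step_probs(z_prev, x, t)        # q(z^{(t)} | z^{(t-1)}, x)
    log_ratio = model.log_probs(z_prev, t) - torch.log(q_probs)
    return (q_probs * log_ratio).sum()          # the inner sum above, in closed form
```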

This gives us a full recipe for a discrete-time, discrete-space diffusion!¶

Continuous Diffusion¶

Most diffusion models used in practice work in continuous space.

  • Make the time coordinate $t$ continuous rather than discrete: instead of $T$ discrete steps, we make an infinite number of infinitesimally small changes.
  • The "noising" process $q$ is typically Brownian motion, and we think of the whole thing as a noising-denoising process in which the model $p$ learns to denoise images that had Gaussian noise added to them by $q$.

Handling the Quadratic Cost of Attention¶

We still haven't addressed the fact that attention scales as $\mathcal{O}(n^2)$ in the sequence length $n$.

Two main ways to solve this:

  • More efficient attention layers
    • Linear attention
    • Sliding window attention
  • Structured state-space models
    • A combination of recurrent and convolutional neural networks (RNNs and CNNs), which can be efficiently computed as either a recurrence (like an RNN) or a convolution (like a CNN). This means we can efficiently generate one token at a time for autoregressive modeling (like an RNN), but also process the whole sequence in parallel for efficient training (like a CNN); see the sketch below.
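A tiny numpy illustration of that recurrence/convolution duality, using a made-up linear state-space model (random matrices, not any particular published architecture):

```python
import numpy as np

# Linear SSM: h_t = A h_{t-1} + B x_t,  y_t = C h_t.
# Unrolling gives y_t = sum_{k=0}^{t} (C A^k B) x_{t-k}, a convolution with kernel K_k = C A^k B.
rng = np.random.default_rng(0)
d, n = 4, 16
A = 0.5 * np.eye(d) + 0.1 * rng.standard_normal((d, d))
B = rng.standard_normal((d, 1))
C = rng.standard_normal((1, d))
x = rng.standard_normal(n)

# Recurrent form: O(1) state per step, good for one-token-at-a-time generation.
h = np.zeros((d, 1)); y_rec = np.zeros(n)
for t in range(n):
    h = A @ h + B * x[t]
    y_rec[t] = (C @ h).item()

# Convolutional form: precompute the kernel once, then process the whole sequence in parallel.
K = np.array([(C @ np.linalg.matrix_power(A, k) @ B).item() for k in range(n)])
y_conv = np.array([np.dot(K[:t + 1], x[t::-1]) for t in range(n)])

assert np.allclose(y_rec, y_conv)   # both forms compute exactly the same outputs
```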