Lecture 27: Alternatives to the Transformer¶

CS4787/5777 — Principles of Large-Scale Machine Learning Systems¶

$\newcommand{\R}{\mathbb{R}}$

Continuing from Last Time: Continuous Diffusion¶

First, let's make the space continuous. Recall that last time we had a loss like the following (ignoring the start and end of the process):

$$\ell = \mathbf{E}_{t \sim \operatorname{Unif}\{1,T\}} \mathbf{E}_{z^{(t-1)} \sim q(Z^{(t-1)} \mid x)}\left[ \sum_{z^{(t)}} q(z^{(t)} \mid z^{(t-1)}, x) \log\left( \frac{ q(z^{(t)} \mid z^{(t-1)}, x) }{ p_{w,t}(z^{(t)} \mid z^{(t-1)}) }\right) \right].$$

If we split this out in terms of individual tokens $k \in \{1,\ldots,n\}$, we get

$$\ell = \mathbf{E}_{t \sim \operatorname{Unif}\{1,T\}} \mathbf{E}_{z^{(t-1)} \sim q(Z^{(t-1)} \mid x)}\left[ \sum_{k=1}^n \sum_{z_k^{(t)}} q_k(z_k^{(t)} \mid z^{(t-1)}, x) \log\left( \frac{ q_k(z_k^{(t)} \mid z^{(t-1)}, x) }{ p_{w,t,k}(z_k^{(t)} \mid z^{(t-1)}) }\right) \right].$$

If we want to make this continuous, we replace the inner sum over $z_k^{(t)}$ with an integral:

$$\ell = \mathbf{E}_{t \sim \operatorname{Unif}\{1,T\}} \mathbf{E}_{z^{(t-1)} \sim q(Z^{(t-1)} \mid x)}\left[ \sum_{k=1}^n \int_{z_k^{(t)}} q_k(z_k^{(t)} \mid z^{(t-1)}, x) \log\left( \frac{ q_k(z_k^{(t)} \mid z^{(t-1)}, x) }{ p_{w,t,k}(z_k^{(t)} \mid z^{(t-1)}) }\right) \; {d z_k^{(t)}} \right].$$

Now suppose that $q_k(z_k^{(t)} \mid z^{(t-1)}, x)$ and $p_{w,t,k}(z_k^{(t)} \mid z^{(t-1)})$ are multivariate Gaussian distributions with the same fixed covariance $\sigma^2 I$. Let $\nu$ denote the mean of $q_k$ and $\mu$ the mean of $p_{w,t,k}$ (both implicitly depend on $k$ and $z^{(t-1)}$, and $\nu$ also depends on $x$). Then this becomes

$$\ell = \mathbf{E}_{t \sim \operatorname{Unif}\{1,T\}} \mathbf{E}_{z^{(t-1)} \sim q(Z^{(t-1)} \mid x)}\left[ \sum_{k=1}^n \frac{1}{2 \sigma^2} \| \mu - \nu \|^2 \right].$$

Or, more explicitly

$$\ell = \mathbf{E}_{t \sim \operatorname{Unif}\{1,T\}} \mathbf{E}_{z^{(t-1)} \sim q(Z^{(t-1)} \mid x)}\left[ \sum_{k=1}^n \frac{1}{2 \sigma^2} \left\| \mathbf{E}_{z_k^{(t)} \sim q_k(z_k^{(t)} \mid z^{(t-1)}, x)}[z_k^{(t)}] - \mathbf{E}_{z_k^{(t)} \sim p_{w,t,k}(z_k^{(t)} \mid z^{(t-1)})}[z_k^{(t)}] \right\|^2 \right].$$
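The step to the squared-distance form uses the closed-form KL divergence between two Gaussians with the same covariance $\sigma^2 I$. Writing $q_k = \mathcal{N}(\nu, \sigma^2 I)$ and $p_{w,t,k} = \mathcal{N}(\mu, \sigma^2 I)$, the normalizing constants cancel in the log-ratio, and the cross term has mean zero under $q_k$:

\begin{align*} \mathbf{E}_{z \sim q_k}\left[ \log \frac{q_k(z)}{p_{w,t,k}(z)} \right] &= \mathbf{E}_{z \sim q_k}\left[ \frac{\|z - \mu\|^2 - \|z - \nu\|^2}{2\sigma^2} \right] \\ &= \mathbf{E}_{z \sim q_k}\left[ \frac{2 \langle z - \nu, \nu - \mu \rangle + \|\nu - \mu\|^2}{2\sigma^2} \right] \\ &= \frac{1}{2\sigma^2} \|\mu - \nu\|^2. \end{align*}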

We can make this continuous in time by letting the time-parameter become continuous too.

  • Instead of $z_k^{(t)}$ being a multivariate Gaussian in the $t$ dimension (where $t$ takes on a finite number of values), now it is a Gaussian process

We can reparameterize and scale this in the limit so the loss looks something like

$$\ell(x) = \mathbf{E}_{t \sim \operatorname{Unif}[0,1]} \mathbf{E}_{z^{(t)} \sim q_t( \cdot \mid x)}\left[ r(t) \cdot \sum_{k=1}^n \left\| \hat x_{w,t,k}( z^{(t)} ) - x_k \right\|^2 \right]$$

where $\hat x_{w,t,k}$ denotes the output of the learned model (its prediction of the clean token $x_k$ given the noisy input $z^{(t)}$ at "time" $t$), and $r: [0,1] \rightarrow \mathbb{R}_+$ is a weighting function, determined by the variance schedule, that captures the effect of the different variances $\sigma^2$ in the limit.

This now becomes very easy to train!
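Training is easy because each step only needs a sampled $t$, a noisy $z^{(t)}$, and a squared error. Below is a minimal NumPy sketch of a Monte Carlo estimate of $\ell(x)$. It assumes a hypothetical forward process $z^{(t)} = \sqrt{\alpha(t)}\,x + \sqrt{1-\alpha(t)}\,\varepsilon$ with $\alpha(1) = 1$ (clean data at $t = 1$, matching this lecture's convention); the schedule `alpha`, the weighting `r`, and the `model` callable are illustrative stand-ins, not a specific published parameterization.

```python
import numpy as np


def continuous_diffusion_loss(x, model, r, alpha, num_samples=1, rng=None):
    """Monte Carlo estimate of ell(x) = E_t E_{z ~ q_t(. | x)}[ r(t) * sum_k ||xhat_k - x_k||^2 ].

    x      : array of shape (n, d), one row per token x_k.
    model  : hypothetical callable (z, t) -> xhat with the same shape as x.
    r      : hypothetical weighting function r(t) >= 0.
    alpha  : hypothetical signal-level schedule with alpha(1) = 1, so t = 1 is clean
             data (this lecture's convention, the reverse of most papers).
    """
    rng = np.random.default_rng() if rng is None else rng
    total = 0.0
    for _ in range(num_samples):
        t = rng.uniform(0.0, 1.0)                    # t ~ Unif[0, 1)
        eps = rng.standard_normal(x.shape)           # Gaussian noise
        a = alpha(t)
        z = np.sqrt(a) * x + np.sqrt(1.0 - a) * eps  # z^(t) ~ q_t(. | x) under the assumed schedule
        x_hat = model(z, t)                          # model's prediction of x from z^(t)
        total += r(t) * np.sum((x_hat - x) ** 2)     # weighted squared error summed over tokens k
    return total / num_samples


# Usage: a model that always predicts zero has loss exactly r * ||x||^2 = 14.25 here.
x = np.array([[1.0, -2.0], [0.5, 3.0]])
zero_model = lambda z, t: np.zeros_like(z)
loss = continuous_diffusion_loss(x, zero_model, r=lambda t: 1.0, alpha=lambda t: t,
                                 num_samples=8, rng=np.random.default_rng(0))
print(loss)  # 14.25
```

In an actual training loop, one would differentiate this estimate with respect to the model's parameters $w$ and take an SGD step.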

In the limit, the inference process becomes a stochastic differential equation.

One thing to note: I've presented these results like this for simplicity, but most of the diffusion model literature has larger $t$ coordinates corresponding to more noise, with the data $x$ at the smallest time coordinate. So don't get confused if you see that convention.

Why am I telling you about this?¶

Continuous-time diffusion exposes an important tradeoff.

In the forward pass (inference) of continuous-time diffusion, now there are infinitely many "steps" of the process.

  • This seems to defeat the systems purpose of diffusion! Before, we needed $n$ (the sequence length) forward passes, but now we seem to need $\infty$.

But we can approximately solve the forward (inference) stochastic differential equation with a finite number of steps.

There are many different ways to approximate a differential equation with a finite number of steps.

  • E.g., Euler's method (or, for stochastic differential equations, the Euler-Maruyama method)

Choosing a different "sampler" and a different number of steps lets us trade off computational cost against accuracy (fidelity to the learned process).

  • Generally, the computational cost of sampling from a diffusion model is linear in the number of steps of the approximate method used for inference.
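To make the cost/accuracy tradeoff concrete, here is a minimal NumPy sketch of the Euler-Maruyama method (the SDE analog of Euler's method) for $dX = f(X, t)\,dt + g(t)\,dW$. The `drift` and `diffusion` arguments are hypothetical stand-ins; in a real diffusion sampler the drift would be computed from the learned model, so each step would cost one model forward pass.

```python
import numpy as np


def euler_maruyama(drift, diffusion, x0, num_steps, t0=0.0, t1=1.0, rng=None):
    """Approximately simulate dX = drift(X, t) dt + diffusion(t) dW on [t0, t1].

    `drift` is a hypothetical stand-in for the learned model: it is called once
    per step, so the cost is linear in num_steps.
    Returns the final state and the number of drift (i.e. model) evaluations.
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = (t1 - t0) / num_steps
    x = np.array(x0, dtype=float)
    t = t0
    num_evals = 0
    for _ in range(num_steps):
        dW = np.sqrt(dt) * rng.standard_normal(x.shape)  # Brownian increment ~ N(0, dt I)
        x = x + drift(x, t) * dt + diffusion(t) * dW      # one Euler-Maruyama step
        num_evals += 1
        t += dt
    return x, num_evals


# Usage: with zero noise this reduces to Euler's method on dx/dt = -x,
# whose exact solution at t = 1 is exp(-1) * x0.
exact = np.exp(-1.0)
for steps in (10, 100, 1000):
    x1, evals = euler_maruyama(lambda x, t: -x, lambda t: 0.0, [1.0], steps)
    print(steps, evals, abs(x1[0] - exact))
```

Each step calls `drift` exactly once, so the cost grows linearly with `num_steps`, while on this toy equation the error shrinks roughly in proportion to the step size.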

Handling the Quadratic Cost of Attention: Alternatives to the Transformer¶

We still haven't addressed the fact that attention scales as $\mathcal{O}(n^2)$ in the sequence length $n$.

Two main ways to solve this:

  • More efficient attention layers
    • Linear attention
    • Sliding window attention
  • State-space models
    • A combination of recurrent and convolutional neural networks (RNNs and CNNs) that can be computed efficiently either as a recurrence (like an RNN) or as a convolution (like a CNN). This means we can efficiently generate one token at a time for autoregressive modeling (like an RNN) but also process the whole sequence in parallel for efficient training (like a CNN).

Recall: What was the benefit of causal-attention transformers?¶

  • Can handle sequences of any length
  • Compute graph depth does not grow with length (unlike an RNN)
  • Can add a new token to the sequence without reprocessing the whole sequence

Linear Attention¶

Recall: the attention operation is

$$\operatorname{Attention}(Q,K,V) = \operatorname{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V$$

Linearized attention replaces this with

$$\operatorname{LinearizedAttention}(Q,K,V)_{ij} = \frac{ \sum_{l=1}^n \langle \phi(Q_i), \phi(K_l) \rangle V_{lj} }{\sum_{l=1}^n \langle \phi(Q_i), \phi(K_l) \rangle }$$

where $Q_i$ and $K_l$ denote the $i$th and $l$th rows (tokens) of $Q$ and $K$, and $\phi: \mathbb{R}^{d_k} \rightarrow \mathbb{R}^C$ is a feature map of some non-negative similarity kernel, i.e., for all $x$ and $y$, $\langle \phi(x), \phi(y) \rangle \ge 0$. If we let $K_{\phi} \in \mathbb{R}^{n \times C}$ denote the matrix obtained by applying $\phi$ independently to each token (row) of $K$, and similarly for $Q_{\phi}$, then, with $\mathbf{1} \in \mathbb{R}^n$ the all-ones vector,

$$\operatorname{LinearizedAttention}(Q,K,V)_{ij} = \frac{ e_i^T (Q_{\phi} K_{\phi}^T V) e_j }{ e_i^T Q_{\phi} K_{\phi}^T \mathbf{1} }.$$

How many FLOPs does this require to compute? How does that compare to SDPA (scaled dot-product attention)?
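As a hint toward the FLOP question, here is a minimal NumPy sketch that evaluates the linearized-attention formula in two orders. It assumes the elementwise feature map $\phi(x) = \operatorname{elu}(x) + 1$ (one choice used in the linear-attention literature, giving $C = d_k$) and non-causal attention, as in the formula. Forming $Q_\phi K_\phi^T$ first costs $\mathcal{O}(n^2 (C + d_v))$ FLOPs, while reassociating as $Q_\phi (K_\phi^T V)$ costs $\mathcal{O}(n C d_v)$, linear in $n$. SDPA has no such factorization and must form the full $n \times n$ score matrix.

```python
import numpy as np


def phi(x):
    # Assumed feature map: phi(x) = elu(x) + 1, which is strictly positive,
    # so <phi(q), phi(k)> >= 0. Here C = d_k because phi acts elementwise.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def linearized_attention_naive(Q, K, V):
    """Evaluate the formula left to right: forms the n x n matrix Q_phi K_phi^T."""
    Qp, Kp = phi(Q), phi(K)
    S = Qp @ Kp.T                          # (n, n): O(n^2 C) FLOPs, O(n^2) memory
    ones = np.ones(K.shape[0])
    return (S @ V) / (S @ ones)[:, None]   # numerator O(n^2 d_v), denominator O(n^2)


def linearized_attention_fast(Q, K, V):
    """Same result, reassociated as Q_phi (K_phi^T V): no n x n matrix is formed."""
    Qp, Kp = phi(Q), phi(K)
    KV = Kp.T @ V                          # (C, d_v): O(n C d_v) FLOPs
    z = Kp.T @ np.ones(K.shape[0])         # (C,): O(n C) FLOPs
    return (Qp @ KV) / (Qp @ z)[:, None]   # O(n C d_v) FLOPs


# Usage: both orderings agree up to floating-point rounding.
rng = np.random.default_rng(0)
n, d_k, d_v = 6, 4, 3
Q, K, V = rng.standard_normal((n, d_k)), rng.standard_normal((n, d_k)), rng.standard_normal((n, d_v))
print(np.allclose(linearized_attention_naive(Q, K, V), linearized_attention_fast(Q, K, V)))  # True
```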

State-Space Models¶

Consider a discrete-time process with input $u_t \in \mathbb{R}^m$, state $x_t \in \mathbb{R}^{n}$, and output $y_t \in \mathbb{R}^p$ (here $n$ is the state dimension, not the sequence length, which we now call $T$), given by an initial state $x_0$ and the update

\begin{align*} x_t &= A x_{t-1} + B u_t \\ y_t &= C x_t. \end{align*}

parameterized by matrices $A \in \mathbb{R}^{n \times n}$, $B \in \mathbb{R}^{n \times m}$ and $C \in \mathbb{R}^{p \times n}$. We can interpret this as a (linear) function mapping tensors $U \in \mathbb{R}^{T \times m}$ to $Y \in \mathbb{R}^{T \times p}$ for any desired length $T$.
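A direct NumPy sketch of this recurrence, assuming $x_0 = 0$ by default (the random matrices in the usage are illustrative only). `ssm_scan` processes a whole length-$T$ sequence with one sequential step per token, and `ssm_step` shows that adding one token only requires carrying the current state $x_{t-1}$.

```python
import numpy as np


def ssm_scan(A, B, C, U, x0=None):
    """Map inputs U (T, m) to outputs Y (T, p) via x_t = A x_{t-1} + B u_t, y_t = C x_t."""
    x = np.zeros(A.shape[0]) if x0 is None else np.asarray(x0, dtype=float)
    Y = np.empty((U.shape[0], C.shape[0]))
    for t in range(U.shape[0]):   # T sequential steps: the compute graph depth grows with T
        x = A @ x + B @ U[t]      # O(n^2 + m n) FLOPs
        Y[t] = C @ x              # O(p n) FLOPs
    return Y


def ssm_step(A, B, C, x_prev, u_t):
    """Process one new token: only the O(n) state x_prev is carried between calls."""
    x_t = A @ x_prev + B @ u_t
    return x_t, C @ x_t


# Usage: streaming one token at a time reproduces the full-sequence outputs.
rng = np.random.default_rng(0)
n, m, p, T = 4, 2, 3, 10
A = 0.5 * rng.standard_normal((n, n)) / np.sqrt(n)
B, C = rng.standard_normal((n, m)), rng.standard_normal((p, n))
U = rng.standard_normal((T, m))
Y = ssm_scan(A, B, C, U)
x = np.zeros(n)
Y_stream = []
for u in U:
    x, y = ssm_step(A, B, C, x, u)
    Y_stream.append(y)
print(np.allclose(Y, np.array(Y_stream)))  # True
```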

What is the cost in FLOPs of computing this for a sequence of length $T$?

$$\operatorname{FLOPs} = \mathcal{O}((n^2 + mn + np) T)$$

What is the cost in FLOPs of adding one new token to a sequence? How much memory do we need to store to enable this, beyond the input/output tensors?

$$\operatorname{FLOPs} = \mathcal{O}(n^2 + mn + np)$$

$$\operatorname{MEM} = \mathcal{O}(n)$$

This memory use is better than a transformer's, which needs memory (for the KV cache) proportional to the number of tokens $T$.

But this still has the RNN problem of having a compute graph depth that grows with the sequence length!

Can we do this faster? Let's look at the scalar case...¶

\begin{align*} x_t &= a x_{t-1} + b u_t \\ y_t &= c x_t. \end{align*}

Here, taking $x_0 = 0$, we get $y_t = c \sum_{k=1}^t a^{t-k} b u_k$; if we stack the $y_t$ and $u_t$ into vectors $y, u \in \mathbb{R}^T$, this looks like

$$y = bc \begin{bmatrix} 1 & 0 & 0 & 0 & \ldots \\ a & 1 & 0 & 0 & \ldots \\ a^2 & a & 1 & 0 & \ldots \\ a^3 & a^2 & a & 1 & \ldots \\ \vdots & \vdots & \vdots & \vdots & \ddots \\ \end{bmatrix} u.$$

That is, this is just multiplication by a (lower-triangular) Toeplitz matrix, a matrix in which each left-to-right descending diagonal is constant. We can also interpret this as a causal convolution of $u$ with the kernel $(bc, abc, a^2 bc, \ldots)$, which is what a CNN does.
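The three views of the scalar SSM (recurrence, Toeplitz matrix, convolution) can be checked against each other numerically. A minimal NumPy sketch, assuming $x_0 = 0$ and arbitrary illustrative values of $a$, $b$, $c$, and $u$:

```python
import numpy as np


def scalar_ssm_recurrence(a, b, c, u):
    """RNN view: x_t = a x_{t-1} + b u_t, y_t = c x_t, starting from x_0 = 0."""
    x, y = 0.0, np.empty(len(u))
    for t, u_t in enumerate(u):
        x = a * x + b * u_t
        y[t] = c * x
    return y


def scalar_ssm_toeplitz(a, b, c, u):
    """Matrix view: y = bc * M u with M[i, j] = a^(i - j) for i >= j and 0 above the diagonal."""
    T = len(u)
    i, j = np.indices((T, T))
    M = np.tril(a ** np.maximum(i - j, 0).astype(float))  # lower-triangular Toeplitz matrix
    return b * c * (M @ u)                                # dense product: O(T^2) FLOPs


def scalar_ssm_convolution(a, b, c, u):
    """CNN view: causal convolution of u with the kernel (bc, abc, a^2 bc, ...)."""
    kernel = b * c * a ** np.arange(len(u))
    return np.convolve(u, kernel)[: len(u)]               # keep the first T (causal) outputs


u = np.array([1.0, 2.0, -1.0, 0.5, 3.0])
y_rnn = scalar_ssm_recurrence(0.9, 2.0, 0.5, u)
print(np.allclose(y_rnn, scalar_ssm_toeplitz(0.9, 2.0, 0.5, u)))     # True
print(np.allclose(y_rnn, scalar_ssm_convolution(0.9, 2.0, 0.5, u)))  # True
```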

We can multiply by a $T \times T$ Toeplitz matrix in log-linear time, $\mathcal{O}(T \log T)$, via the fast Fourier transform (FFT).
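A minimal NumPy sketch of the FFT approach, using `numpy.fft.rfft` and `numpy.fft.irfft`. Zero-padding both sequences to length $2T$ turns the FFT's circular convolution into the ordinary causal one, which is the same as multiplying by the lower-triangular Toeplitz matrix; the parameter values are illustrative assumptions.

```python
import numpy as np


def causal_conv_fft(u, kernel):
    """Causal convolution y[t] = sum_{s <= t} kernel[t - s] u[s] in O(T log T) via the FFT.

    Zero-padding to length 2T avoids circular wrap-around, so the result matches
    multiplication by the T x T lower-triangular Toeplitz matrix built from kernel.
    """
    T = len(u)
    L = 2 * T
    U_f = np.fft.rfft(u, n=L)
    K_f = np.fft.rfft(kernel, n=L)
    return np.fft.irfft(U_f * K_f, n=L)[:T]


# Usage: the scalar SSM kernel (bc, abc, a^2 bc, ...) applied via FFT matches the recurrence.
a, b, c = 0.9, 2.0, 0.5
rng = np.random.default_rng(0)
u = rng.standard_normal(1000)
kernel = b * c * a ** np.arange(len(u))
y_fft = causal_conv_fft(u, kernel)

x, y_rec = 0.0, np.empty(len(u))
for t in range(len(u)):
    x = a * x + b * u[t]
    y_rec[t] = c * x
print(np.allclose(y_fft, y_rec))  # True
```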

State Space Models used as layers in a neural network¶

We can use an SSM layer where we would otherwise use a transformer attention block.

  • But we need to be careful, because this simple SSM layer is linear (so the network needs nonlinearities elsewhere)

The compute cost of inference (generating one token at a time) is linear in the sequence length, not quadratic.

The compute cost during training is log-linear in the sequence length (because we treat the layer as a convolution and compute it via the FFT).

The memory needed to store the state is constant (the analog of the KV cache does not grow with the sequence length).
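A minimal sketch of an SSM used as a layer and computed both ways: in training mode as per-channel FFT convolutions over the whole sequence, and in inference mode one token at a time with only an $\mathcal{O}(d)$ state. The diagonal per-channel parameterization (vectors `a`, `b`, `c` of length `d`) and the ReLU nonlinearity are assumptions chosen for illustration, not a specific published architecture.

```python
import numpy as np


def ssm_layer_train(U, a, b, c):
    """Training mode: per-channel scalar SSMs applied to the whole sequence at once.

    U has shape (T, d); a, b, c have shape (d,) (an assumed diagonal, per-channel
    parameterization). Each channel is a causal convolution computed with the FFT,
    so the cost is O(d T log T) and all positions are processed in parallel.
    """
    T = U.shape[0]
    L = 2 * T                                                  # zero-pad to avoid circular wrap-around
    kernels = (b * c) * a[None, :] ** np.arange(T)[:, None]    # kernels[s, j] = b_j c_j a_j^s
    U_f = np.fft.rfft(U, n=L, axis=0)
    K_f = np.fft.rfft(kernels, n=L, axis=0)
    pre = np.fft.irfft(U_f * K_f, n=L, axis=0)[:T]             # linear SSM output, shape (T, d)
    return np.maximum(pre, 0.0)                                # assumed pointwise nonlinearity (ReLU)


def ssm_layer_step(state, u_t, a, b, c):
    """Inference mode: consume one new token with O(d) work; the (d,) state is all we store."""
    state = a * state + b * u_t
    return state, np.maximum(c * state, 0.0)


rng = np.random.default_rng(0)
T, d = 64, 8
a = rng.uniform(0.5, 0.99, size=d)
b, c = rng.standard_normal(d), rng.standard_normal(d)
U = rng.standard_normal((T, d))
Y_train = ssm_layer_train(U, a, b, c)

state, Y_gen = np.zeros(d), []
for u_t in U:
    state, y_t = ssm_layer_step(state, u_t, a, b, c)
    Y_gen.append(y_t)
print(np.allclose(Y_train, np.array(Y_gen)))  # True
```

The two modes compute the same function, so a model trained in the parallel FFT mode can be run token by token at inference time.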

But there's no free lunch! There are tradeoffs with accuracy, which vary by task.
