Lecture 10: Neural Networks and Transformers¶
CS4787/5777 — Principles of Large-Scale Machine Learning Systems¶
$\newcommand{\R}{\mathbb{R}}$ $\newcommand{\norm}[1]{\left\|#1\right\|}$ $\newcommand{\Exv}[1]{\mathbf{E}\left[#1\right]}$ $\newcommand{\Prob}[1]{\mathbf{P}\left(#1\right)}$ $\newcommand{\Var}[1]{\operatorname{Var}\left(#1\right)}$ $\newcommand{\Abs}[1]{\left|#1\right|}$
import torch
import transformers
import matplotlib
from matplotlib import pyplot
import time
matplotlib.rcParams.update({'font.size': 14, 'figure.figsize': (6.0, 6.0)})
torch.set_grad_enabled(False)
<torch.autograd.grad_mode.set_grad_enabled at 0x1044ab8b0>
Review: Linear models and neural networks.¶
From the homework and your past machine learning course, you should all be familiar with the notion of a linear model hypothesis class. For example, for multinomial logistic regression, we had the hypothesis class $$h_W(x) = \operatorname{softmax}(Wx).$$ This is a specific example of a more general linear model of the form $$h_W(x) = \sigma(Wx)$$ for some inputs $x \in \R^d$, matrix $W \in \R^{D \times d}$, and function $\sigma: \R^D \rightarrow \R^D$. Many important methods in machine learning use linear model hypothesis classes, including linear regression, logistic regression, and SVM.
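As a concrete illustration, here is a minimal sketch of such a linear model in PyTorch; the dimensions and the random weight matrix below are made up purely for illustration.
# a minimal sketch of the linear model h_W(x) = softmax(Wx); d, D, W, and x are made up
d, D = 8, 3                       # input dimension and number of classes (hypothetical)
W = torch.randn(D, d)             # model parameters
x = torch.randn(d)                # a single example
h = torch.softmax(W @ x, dim=0)   # prediction: a distribution over the D classes
print(h, h.sum())                 # entries are nonnegative and sum to 1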
One naive way that we can combine two hypothesis classes is by stacking or layering them. If I have one class of hypotheses $h^{(1)}_{W_1}$ that maps from $\R^{d_0}$ to $\R^{d_1}$ and a second class of hypotheses $h^{(2)}_{W_2}$ that maps from $\R^{d_1}$ to $\R^{d_2}$, then I can form the layered hypothesis class $$h_{W_1,W_2}(x) = h^{(2)}_{W_2}(h^{(1)}_{W_1}(x))$$ that results from first applying $h^{(1)}$ and then applying $h^{(2)}$. Intuitively, we're first having $h^{(1)}$ make a prediction and then using the result of that prediction as an input to $h^{(2)}$ to make our final prediction. If both our constituent hypothesis classes are linear models, we can write this out more explicitly as $$h_{W_1,W_2}(x) = \sigma_2(W_2 \cdot \sigma_1(W_1 x)).$$ Of course, we don't need to limit ourselves to layering just two linear classifiers. We could layer as many as we want. For example, if we had $\mathcal{L}$ total layers, then our hypothesis would look like $$h_{W_1,W_2,\ldots,W_{\mathcal{L}}}(x) = \sigma_{\mathcal{L}}(W_{\mathcal{L}} \cdot \sigma_{\mathcal{L}-1}(W_{\mathcal{L}-1} \cdots \sigma_2(W_2 \cdot \sigma_1(W_1 x)) \cdots)).$$ We can write this out more generally and explicitly in terms of a recurrence relation. \begin{align*} o_0 &= x && \text{Typical runtime cost:}\\ \forall l \in \{1,\ldots,\mathcal{L}\}, \hspace{1em} a_l &= W_l \cdot o_{l-1} + b_l && \fbox{\rule[2em]{10em}{0pt}} \\ \forall l \in \{1,\ldots,\mathcal{L}\}, \hspace{1em} o_l &= \sigma_l(a_l) && \fbox{\rule[2em]{10em}{0pt}} \\ h_{W_1,b_1,W_2,b_2,\ldots,W_{\mathcal{L}},b_{\mathcal{L}}}(x) &= o_{\mathcal{L}}. && \fbox{\rule[2em]{10em}{0pt}} \end{align*} where $a_l, o_l \in \R^{d_l}$, and here we've also added an explicit bias parameter $b_l \in \R^{d_l}$ to each layer. This type of model is called a multilayer perceptron (MLP), artificial neural network (ANN), or deep neural network (DNN). (Specifically, it's a type of deep neural network called a feedforward neural network.) Here, the functions $\sigma_l$ are called the activation functions and are almost always chosen to operate independently along each dimension; that is (with abuse of notation) $$\left( \sigma_l(x) \right)_i = \sigma_l(x_i).$$ Note that this is not true for the softmax, but it's true about pretty much every other major activation function.
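To make the recurrence concrete, here's a minimal sketch of the forward pass of such an MLP; the layer widths are made up, and we (arbitrarily) assume ReLU activations for the hidden layers and a softmax at the output.
# a minimal sketch of the MLP recurrence above; the widths and activation choices are assumptions
widths = [8, 16, 16, 4]           # d_0, d_1, ..., d_L (hypothetical)
Ws = [torch.randn(widths[l+1], widths[l]) for l in range(len(widths) - 1)]
bs = [torch.zeros(widths[l+1]) for l in range(len(widths) - 1)]

def mlp(x):
    o = x                                      # o_0 = x
    for l, (W, b) in enumerate(zip(Ws, bs)):
        a = W @ o + b                          # a_l = W_l o_{l-1} + b_l
        o = torch.relu(a) if l + 1 < len(Ws) else torch.softmax(a, dim=0)  # o_l = sigma_l(a_l)
    return o                                   # h(x) = o_L

print(mlp(torch.randn(widths[0])))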
Variants of neural networks:¶
- Residual neural networks add skip (shortcut) connections, so that a block's input is added directly to its output and the block only needs to learn a residual correction.
- Convolutional neural networks restrict some of the linear transformations $W_l$ to be members of some subset of linear transformations, typically convolutions with some filter.
- Recurrent neural networks repeat the same layers to process a sequence.
- Transformers use attention blocks to process sequences and spatially/temporally structured data in a unified way.
Transformers¶
Designed to process sequential data, but can generalize to any sort of structured data.
Represents an example as a matrix in $\R^{n \times d}$ where $n$ is the sequence length (a.k.a. $n$ "tokens") and $d$ is the representation dimension. A mini-batch of $B$ examples is then a tensor in $\R^{B \times n \times d}$.
Most characteristic layer: attention layer (more formally, "Scaled Dot-Product Attention"). Given input activation matrices $Q \in \mathbb{R}^{n \times d_k}$ (the "query" matrix), $K \in \mathbb{R}^{n \times d_k}$ (the "key" matrix), and $V \in \mathbb{R}^{n \times d_v}$ (the "value" matrix), the attention layer outputs $$\operatorname{Attention}(Q,K,V) = \operatorname{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V,$$ where this softmax applies along the rows of the matrix (i.e. each row of $\operatorname{softmax}( \cdot )$ sums to $1$). You can think of this as a "soft" or "weighted" lookup. This formulation lets every token (every sequence element) look up into every other one: if we want to restrict this, we can use an attention mask $M \in \mathbb{R}^{n \times n}$, usually with elements in $\{-\infty, 0\}$, and set $$\operatorname{MaskedAttention}(Q,K,V) = \operatorname{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} + M \right) V.$$ This "zeros out" the entries of $\operatorname{softmax}( \cdot )$ for which $M_{ij} = -\infty$.
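Here is a minimal sketch of (masked) scaled dot-product attention following the formula above; the sizes $n$, $d_k$, and $d_v$ below are made up for illustration.
# a minimal sketch of (masked) scaled dot-product attention
def attention(Q, K, V, M=None):
    # Q: (n, d_k), K: (n, d_k), V: (n, d_v), optional M: (n, n) with entries in {-inf, 0}
    scores = Q @ K.t() / (Q.shape[-1] ** 0.5)   # (n, n) matrix of scaled dot products
    if M is not None:
        scores = scores + M                     # masked entries become -inf
    weights = torch.softmax(scores, dim=-1)     # each row sums to 1
    return weights @ V                          # (n, d_v): a weighted lookup into V

n, d_k, d_v = 5, 8, 8
print(attention(torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_v)).shape)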
Multiple attention layers are combined together to form a multi-head attention layer. Such a layer with $h$ "heads" takes as input tensors $Q \in \mathbb{R}^{n \times h \times d_k}$, $K \in \mathbb{R}^{n \times h \times d_k}$, and $V \in \mathbb{R}^{n \times h \times d_v}$, and outputs a tensor of size $(n \times h \times d_v)$ such that $$\operatorname{MultiHeadAttention}(Q,K,V)_{:,i,:} = \operatorname{MaskedAttention}(Q_{:,i,:},K_{:,i,:},V_{:,i,:});$$ that is, it's just $h$ attention layers running in parallel along the head dimension.
A typical multi-head attention block with representation dimension $d$ and number of heads $h$ (where $h$ evenly divides $d$) has $d_k = d_v = d/h$ and is parameterized by four matrices: $W_K \in \mathbb{R}^{d \times d}$, $W_Q \in \mathbb{R}^{d \times d}$, $W_V \in \mathbb{R}^{d \times d}$ and $W_O \in \mathbb{R}^{d \times d}$. Given input $X \in \mathbb{R}^{n \times d}$, it outputs $$\operatorname{MultiHeadAttention}(X W_Q^T, X W_K^T, X W_V^T) W_O^T$$ where here we reshape $\operatorname{MultiHeadAttention}$ to operate on matrices like $Q \in \mathbb{R}^{n \times hd_k}$ rather than on tensors.
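A minimal sketch of such a block, reusing the single-head attention function sketched above and splitting the $d$-dimensional projections into $h$ heads of width $d/h$; the dimensions and weights below are made up.
# a minimal sketch of a multi-head attention block with projections W_Q, W_K, W_V, W_O
def multi_head_attention_block(X, W_Q, W_K, W_V, W_O, h):
    n, d = X.shape
    d_head = d // h
    Q, K, V = X @ W_Q.t(), X @ W_K.t(), X @ W_V.t()    # each (n, d)
    # split into h heads of width d/h and run attention independently on each head
    heads = [attention(Q[:, i*d_head:(i+1)*d_head],
                       K[:, i*d_head:(i+1)*d_head],
                       V[:, i*d_head:(i+1)*d_head]) for i in range(h)]
    return torch.cat(heads, dim=-1) @ W_O.t()          # (n, d)

d, h, n = 16, 4, 5
projs = [torch.randn(d, d) for _ in range(4)]          # W_Q, W_K, W_V, W_O (hypothetical)
print(multi_head_attention_block(torch.randn(n, d), *projs, h).shape)  # torch.Size([5, 16])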
Typical transformers alternate attention layers with multi-layer perceptrons (MLPs) operating independently along token positions. E.g. for Llama, $$\operatorname{MLP}(X; W_{\text{up}}, W_{\text{gate}}, W_{\text{down}}) = ((X W_{\text{up}}^T) \odot \operatorname{SiLU}(X W_{\text{gate}}^T)) W_{\text{down}}^T.$$
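A minimal sketch of this gated MLP, with made-up dimensions ($d$ is the representation dimension and d_ff the hidden width):
# a minimal sketch of the Llama-style gated MLP above; d, d_ff, and the weights are made up
def llama_mlp(X, W_up, W_gate, W_down):
    return ((X @ W_up.t()) * torch.nn.functional.silu(X @ W_gate.t())) @ W_down.t()

d, d_ff, n = 16, 64, 5
W_up, W_gate, W_down = torch.randn(d_ff, d), torch.randn(d_ff, d), torch.randn(d, d_ff)
print(llama_mlp(torch.randn(n, d), W_up, W_gate, W_down).shape)  # torch.Size([5, 16])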
Position¶
How do transformers handle positions? Three main approaches:
- Absolute positional encoding/embedding
- Relative positional encoding (e.g. Rotary Position Embedding RoPE)
- Causal attention masks
An absolute positional embedding combines a token embedding with an embedding of its position, e.g. by adding or concatenating them (that is, the number of distinct possible "inputs" to the transformer is equal to the vocabulary size times the number of positions).
A relative positional embedding modifies the attention mechanism to take the position into account. Main idea: let $U$ be some fixed orthogonal matrix (rotation matrix, $U^{-1} = U^T$). For a query row vector $q$ at position $m$ and key row vector $k$ at position $n$, modify them by doing $$\hat q = q U^m \;\;\text{and}\;\; \hat k = k U^n.$$ Then, when we take their dot product in the attention operator, we'll get $$\hat q (\hat k)^T = q U^m (k U^n)^T = q U^m (U^n)^T k^T = q U^m U^{-n} k^T = q U^{m-n} k^T.$$ That is, this entry of $Q K^T$ is transformed in a way that depends only on the relative positional difference between the key and query vectors!
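We can check this identity numerically with a small sketch, using an arbitrary $2 \times 2$ rotation matrix as $U$ and made-up positions $m$ and $n$:
# a numerical check of q U^m (k U^n)^T = q U^(m-n) k^T; the angle and positions are arbitrary
import math
theta = 0.3
U = torch.tensor([[math.cos(theta), -math.sin(theta)],
                  [math.sin(theta),  math.cos(theta)]])   # a 2x2 rotation matrix
q, k = torch.randn(1, 2), torch.randn(1, 2)               # query and key row vectors
m, n = 7, 3                                               # their (hypothetical) positions
q_hat = q @ torch.linalg.matrix_power(U, m)
k_hat = k @ torch.linalg.matrix_power(U, n)
print(q_hat @ k_hat.t())                                  # left-hand side
print(q @ torch.linalg.matrix_power(U, m - n) @ k.t())    # right-hand side: same value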
A causal attention mask also modifies the attention mechanism, letting each query token depend only on key tokens at its own position or in the past. This is critical for efficient autoregressive modeling! (Why?)
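For instance, a causal mask for a length-$n$ sequence puts $-\infty$ strictly above the diagonal; a minimal sketch (it could be passed as the mask $M$ in the attention sketch above):
# a minimal sketch of a causal attention mask: query i may attend only to keys j <= i
n = 5
M = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)   # -inf strictly above the diagonal
print(M)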
Layer normalization¶
Deep neural networks can have problems during training because of higher-order effects that arise from the composition of many layers. Small changes made to earlier layers can have a major impact on what happens later in the network. This makes training difficult, because we often want to set the step size to be very small to damp out these higher-order effects, but this also makes training slow. One way to address this is layer normalization, which reparameterizes a deep neural network to reduce these effects.
Suppose that we are concerned with a single (vector) activation in the network. Let $x \in \R^d$ denote the vector of activations of the original network. RMS normalization (one variant of layer normalization) replaces this $x$ with $$\mathsf{RMSNorm}(x; w) = \frac{x}{\sqrt{\epsilon + \frac{1}{d} \sum_{i=1}^d x_i^2}} \odot w = \frac{x}{\sqrt{\epsilon + \frac{1}{d} \norm{x}^2}} \odot w$$ where $\epsilon$ is some small fixed positive number (e.g. $10^{-6}$) used to avoid dividing by zero, and $w \in \R^d$ is a weight vector of the RMSNorm layer. We're basically just rescaling the signal by its root-mean-square magnitude, then reweighting each coordinate by $w$.
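A minimal sketch of RMSNorm following the formula above (with an arbitrary $\epsilon$ and made-up dimensions):
# a minimal sketch of RMSNorm; eps, d, and the weights are made up for illustration
def rms_norm(x, w, eps=1e-6):
    return x / torch.sqrt(eps + (x * x).mean()) * w   # divide x by its root-mean-square, then scale by w

d = 8
x, w = torch.randn(d), torch.ones(d)
print(rms_norm(x, w))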
Another, much weirder type of normalization is batch normalization, which normalizes across the batch dimension instead.
Autoregressive models: the language-modeling head.¶
Maps the final hidden state at each position to a prediction over the vocabulary for the next token in the sequence.
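A minimal sketch of what this head computes, with made-up dimensions (here $d$ is the hidden size and $V$ the vocabulary size):
# a minimal sketch of a language-modeling head; d, V, n, and the weights are made up
d, V, n = 16, 100, 5
W_lm = torch.randn(V, d)                                 # the head's weight matrix
H = torch.randn(n, d)                                    # final hidden states for n tokens
next_token_probs = torch.softmax(H @ W_lm.t(), dim=-1)   # (n, V): next-token distribution at each position
print(next_token_probs.shape)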
The Key Feature of Autoregressive Transformers: the KV Cache¶
Since models are causal and MLP layers are per-token, later token predictions depend on earlier ones only through the attention mechanism.
The attention mechanism depends on past tokens only through the $K$ and $V$ tensors.
To make future predictions during inference, we only need to save the $K$ and $V$ values from the past! We can get rid of all the other intermediates.
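A minimal sketch of the idea (not the transformers library's actual cache implementation): each time a new token arrives, we append its key and value rows to the cache, and the new query attends over everything cached so far. All shapes below are made up for illustration.
# a minimal sketch of a KV cache for a single attention head
d_k, d_v = 8, 8
K_cache = torch.zeros(0, d_k)   # cached keys from past tokens
V_cache = torch.zeros(0, d_v)   # cached values from past tokens

def attend_one_step(q_new, k_new, v_new):
    global K_cache, V_cache
    K_cache = torch.cat([K_cache, k_new])             # grow the cache by one row
    V_cache = torch.cat([V_cache, v_new])
    scores = q_new @ K_cache.t() / (d_k ** 0.5)       # (1, t) for t cached tokens
    return torch.softmax(scores, dim=-1) @ V_cache    # (1, d_v)

for t in range(4):   # simulate generating four tokens one at a time
    out = attend_one_step(torch.randn(1, d_k), torch.randn(1, d_k), torch.randn(1, d_v))
print(out.shape, K_cache.shape)   # torch.Size([1, 8]) torch.Size([4, 8])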
Putting it all together. The Llama-3 Model: Demo.¶
model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name,torch_dtype='auto',local_files_only=True)
model
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((4096,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
Grouped-Query Attention¶
Reduces the cost of the KV cache by using each $K$ and $V$ vector multiple times across multiple queries.
Makes the $Q$ tensor larger than the $K$ and $V$ tensors.
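A minimal sketch of the idea, with made-up head counts (for reference, Llama-3.1-8B itself uses 32 query heads and 8 key/value heads, which is why k_proj and v_proj above map 4096 down to 1024):
# a minimal sketch of grouped-query attention: several query heads share each KV head
n, d_head, n_q_heads, n_kv_heads = 5, 4, 8, 2        # all made up for illustration
Q = torch.randn(n, n_q_heads, d_head)
K = torch.randn(n, n_kv_heads, d_head)               # the KV cache only needs these smaller tensors
V = torch.randn(n, n_kv_heads, d_head)
# repeat each KV head so it is shared by n_q_heads // n_kv_heads query heads
K_exp = K.repeat_interleave(n_q_heads // n_kv_heads, dim=1)   # (n, n_q_heads, d_head)
V_exp = V.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
scores = torch.einsum('qhd,khd->hqk', Q, K_exp) / (d_head ** 0.5)
out = torch.einsum('hqk,khd->qhd', torch.softmax(scores, dim=-1), V_exp)
print(out.shape)   # torch.Size([5, 8, 4])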
prompt = 'Please tell me all the ways Sun Wukong became immortal in Journey to the West.'
input = tokenizer(f'<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>', return_tensors='pt')
input
{'input_ids': tensor([[128000, 128006, 882, 128007, 271, 5618, 3371, 757, 682,
279, 5627, 8219, 468, 3178, 647, 6244, 60214, 304,
43680, 311, 279, 4410, 13, 128009, 128006, 78191, 128007]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1]])}
tokenizer.batch_decode(input['input_ids'].t())
['<|begin_of_text|>', '<|start_header_id|>', 'user', '<|end_header_id|>', '\n\n', 'Please', ' tell', ' me', ' all', ' the', ' ways', ' Sun', ' W', 'uk', 'ong', ' became', ' immortal', ' in', ' Journey', ' to', ' the', ' West', '.', '<|eot_id|>', '<|start_header_id|>', 'assistant', '<|end_header_id|>']
output = model.generate(**input, max_length=256)
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
output
tensor([[128000, 128006, 882, 128007, 271, 5618, 3371, 757, 682,
279, 5627, 8219, 468, 3178, 647, 6244, 60214, 304,
43680, 311, 279, 4410, 13, 128009, 128006, 78191, 128007,
271, 32, 2294, 3488, 922, 832, 315, 8620, 17649,
596, 1455, 28530, 5885, 2268, 11439, 311, 279, 11670,
11775, 43680, 311, 279, 4410, 11, 8219, 468, 3178,
647, 11, 1101, 3967, 439, 279, 58937, 6342, 11,
6244, 60214, 1555, 264, 10824, 315, 813, 1866, 13736,
11, 24632, 4455, 11, 323, 279, 56650, 315, 5370,
409, 1385, 13, 5810, 527, 279, 1401, 5627, 568,
17427, 4998, 76052, 1473, 16, 13, 3146, 59204, 505,
264, 9998, 96618, 8219, 468, 3178, 647, 574, 9405,
505, 264, 9998, 304, 264, 16700, 26457, 11, 902,
574, 1071, 311, 387, 279, 1121, 315, 264, 24632,
25885, 13, 1115, 19018, 7342, 13160, 1461, 439, 264,
1694, 449, 24674, 13736, 323, 18000, 627, 17, 13,
3146, 43066, 287, 279, 21594, 35257, 44187, 96618, 1666,
264, 3995, 39803, 11, 8219, 468, 3178, 647, 11352,
264, 24632, 14098, 430, 11938, 1461, 4998, 76052, 13,
578, 14098, 11, 3967, 439, 279, 21594, 35257, 44187,
320, 19171, 2663, 279, 330, 38120, 12, 3968, 3372,
44187, 1, 477, 330, 38120, 12, 3968, 3372, 44187,
315, 279, 18288, 61269, 4063, 574, 1071, 311, 6782,
279, 28591, 315, 279, 4330, 5540, 25, 7732, 11,
4027, 11, 9578, 11, 9501, 11, 323, 3090, 627,
18, 13, 3146, 38030, 449, 279, 23860, 1132, 3804,
71, 32973, 96618, 8219, 468, 3178, 647, 574, 3010,
11352, 555, 279, 23860, 1132, 3804, 71, 32973, 11,
264, 47841, 7491, 889]])
tokenizer.decode(output[0])
'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nPlease tell me all the ways Sun Wukong became immortal in Journey to the West.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA great question about one of Chinese literature\'s most beloved characters!\n\nAccording to the classic novel Journey to the West, Sun Wukong, also known as the Monkey King, became immortal through a combination of his own powers, magical events, and the blessings of various deities. Here are the key ways he achieved immortality:\n\n1. **Born from a stone**: Sun Wukong was born from a stone in a mountain cave, which was said to be the result of a magical phenomenon. This unusual birth marked him as a being with extraordinary powers and abilities.\n2. **Consuming the Five Elements Fruit**: As a young monkey, Sun Wukong discovered a magical fruit that granted him immortality. The fruit, known as the Five Elements Fruit (also called the "Five-Flavor Fruit" or "Five-Flavor Fruit of the Golden Lotus"), was said to contain the essence of the five elements: wood, fire, earth, metal, and water.\n3. **Training with the Patriarch Subhuti**: Sun Wukong was later discovered by the Patriarch Subhuti, a Buddhist master who'
Appendix/Review: Overfitting and Underfitting and Neural Networks.¶
- Underfitting informally means that the training error is high.
- Overfitting informally means that the difference between the test error and the training error is high.
- Capacity of a model informally refers to the ability of the model to fit a wide range of possible functions.
- Models with high capacity tend to overfit.
- Models with low capacity tend to underfit.
- The representational capacity of a parameterized class of models informally refers to the extent to which for a wide range of possible functions, some model in the class approximates that function well.
- Deep neural networks have very high representational capacity.
- In fact, they're universal approximators.
- The effective capacity of a parameterized class of models given a specific learning algorithm with a specific amount of data refers to the extent to which for a wide range of possible functions, the model in the class produced by the learning algorithm can approximate that function well.
- For convex optimization problems (what we've studied so far), all the algorithms we've studied converge to the global optimum, so effective capacity will be equal to representational capacity.
- On the other hand, changing the optimization algorithm for a deep learning task can alter the effective capacity. Even changing the hyperparameters of an algorithm can have this effect.
Take away point: In addition to thinking about how changes to a model or algorithm affect optimization parameters like $n$, $d$, and $\kappa$, we also need to reason about how these methods interact with the capacity of the model, especially when we're training a model with a non-convex loss.