Lecture 11: Neural Networks and Transformers, Continued.¶
Matrix-matrix multiply as the core op of deep learning.¶
CS4787/5777 — Principles of Large-Scale Machine Learning Systems¶
$\newcommand{\R}{\mathbb{R}}$ $\newcommand{\norm}[1]{\left\|#1\right\|}$ $\newcommand{\Exv}[1]{\mathbf{E}\left[#1\right]}$ $\newcommand{\Prob}[1]{\mathbf{P}\left(#1\right)}$ $\newcommand{\Var}[1]{\operatorname{Var}\left(#1\right)}$ $\newcommand{\Abs}[1]{\left|#1\right|}$
import torch
import transformers
import matplotlib
from matplotlib import pyplot
import time
import math
# matplotlib.rcParams.update({'font.size': 14, 'figure.figsize': (6.0, 6.0)})
torch.set_grad_enabled(False)
<torch.autograd.grad_mode.set_grad_enabled at 0x104dcde70>
Overfitting and Underfitting and Neural Networks.¶
- Underfitting informally means that the training error is high.
- Overfitting informally means that the difference between the test error and the training error is high.
- Capacity of a model informally refers to the ability of the model to fit a wide range of possible functions.
- Models with high capacity tend to overfit.
- Models with low capacity tend to underfit.
- The representational capacity of a parameterized class of models informally refers to the extent to which, for a wide range of possible functions, some model in the class approximates that function well.
- Deep neural networks have very high representational capacity.
- In fact, they're universal approximators.
- The effective capacity of a parameterized class of models, given a specific learning algorithm and a specific amount of data, refers to the extent to which, for a wide range of possible functions, the model in the class produced by the learning algorithm can approximate that function well.
- For convex optimization problems (e.g. linear regression, logistic regression) all the algorithms we've studied converge to the global optimum, so effective capacity will be equal to representational capacity.
- On the other hand, changing the optimization algorithm for a deep learning task can alter the effective capacity. Even changing the hyperparameters of an algorithm can have this effect.
Takeaway point: In addition to thinking about how changes to a model or algorithm affect problem parameters like $n$, $d$, and $\kappa$ that govern optimization, we also need to reason about how these methods interact with the capacity of the model, especially when we're training a model with a non-convex loss.
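To make under- and overfitting concrete, here is a minimal sketch (not part of the lecture code; the toy data and degrees are chosen arbitrarily) that fits polynomials of increasing degree to a small noisy sample of a sine curve. A low-degree fit tends to underfit (high training error), while a very high-degree fit tends to overfit (low training error but a larger train-test gap).
# hedged sketch: underfitting vs. overfitting with polynomial regression on toy data
torch.manual_seed(0)
x_train, x_test = torch.rand(20, 1), torch.rand(200, 1)
f = lambda x: torch.sin(2 * math.pi * x)
y_train = f(x_train) + 0.1 * torch.randn_like(x_train)
y_test = f(x_test) + 0.1 * torch.randn_like(x_test)

def poly_features(x, degree):
    # feature matrix with columns [1, x, x^2, ..., x^degree]
    return torch.cat([x**k for k in range(degree + 1)], dim=1)

for degree in [1, 3, 15]:
    Phi = poly_features(x_train, degree)
    w = torch.linalg.lstsq(Phi, y_train).solution   # least-squares fit of the coefficients
    train_mse = (Phi @ w - y_train).square().mean().item()
    test_mse = (poly_features(x_test, degree) @ w - y_test).square().mean().item()
    print(f'degree {degree:2d}: train MSE {train_mse:.4f}, test MSE {test_mse:.4f}')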
N = 4096
A = torch.randn(N,N)
X = torch.randn(N,N)
Ymv = torch.zeros(N,N)
Ymm = torch.zeros(N,N)
# Y = X @ A.t()
start = time.time()
for i in range(N):
    # Y[i,:] = A @ X[i,:]
    torch.mv(A, X[i,:], out=Ymv[i,:])
stop = time.time()
elapsed_matvec = stop-start
print(f'took {elapsed_matvec} seconds to do {N} matrix-vector multiplies')
took 4.805057048797607 seconds to do 4096 matrix-vector multiplies
start = time.time()
torch.mm(X, A.t(), out=Ymm)
stop = time.time()
elapsed_matmul = stop-start
print(f'took {elapsed_matmul} seconds to do one matrix-matrix multiply')
took 0.06293106079101562 seconds to do one matrix-matrix multiply
# the results are basically the same! not exactly the same though...
(Ymv - Ymm).square().sum() / Ymm.square().sum()
tensor(1.3464e-12)
print(f'speedup: {elapsed_matvec/elapsed_matmul}')
speedup: 76.3542992665333
Matrix-matrix multiply is much faster than matrix-vector multiply!¶
Let's check this effect for some operations with the same number of FLOPs.
Recall that for a matrix multiply of a $\R^{m \times n}$ by $\R^{n \times p}$ matrix, the number of scalar multiplications needed is $mnp$. Let's keep the inner (reduction) dimension fixed and vary the batch size.
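Note that if we split $X$ into $N/B$ batches of size $B$ and multiply each $B \times N$ batch by the $N \times N$ matrix $A^T$, the total work is $(N/B) \cdot (B \cdot N \cdot N) = N^3$ scalar multiplications no matter what $B$ is, so any runtime differences come from how well the hardware is utilized, not from extra arithmetic. A quick sanity check of that count (a sketch, not part of the original timing code):
# the total number of scalar multiplies is independent of the batch size B
N = 8192
for B in [1, 64, 8192]:
    num_matmuls = N // B                  # number of (B x N) @ (N x N) products
    mults_per_matmul = B * N * N          # m*n*p with m = B, n = N, p = N
    print(f'B={B:5d}: {num_matmuls * mults_per_matmul:.3e} total multiplies')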
N = 8192
A = torch.randn(N,N)
X = torch.randn(N,N)
Y = torch.zeros(N,N)
Bs = [2**i for i in range(13)]
elapsed_times = []
for B in Bs:
    X = X.view(-1,B,N)
    Y = Y.view(-1,B,N)
    torch.cpu.synchronize() # not necessary for cpu
    start = time.time()
    for i in range(X.shape[0]):
        torch.mm(X[i,:,:], A.t(), out=Y[i,:,:])
    torch.cpu.synchronize()
    stop = time.time()
    elapsed_times.append(stop-start)
pyplot.loglog(Bs, elapsed_times, marker="o")
pyplot.xlabel('batch dimension (B)')
pyplot.ylabel('elapsed time')
pyplot.title('Time for (N/B) NxN by NxB matmuls');
pyplot.loglog(Bs, [N**3/t for t in elapsed_times], marker="o")
pyplot.xlabel('batch dimension (B)')
pyplot.ylabel('multiplies per second (FLOPS)')
pyplot.title('FLOPS of NxN by NxB matmuls');
Is the same thing true across devices?¶
N = 8192
A = torch.randn(N,N,device='mps')
X = torch.randn(N,N,device='mps')
Y = torch.zeros(N,N,device='mps')
Bs = [2**i for i in range(13)]
elapsed_times_mps = []
for B in Bs:
    X = X.view(-1,B,N)
    Y = Y.view(-1,B,N)
    torch.mps.synchronize()
    start = time.time()
    for i in range(X.shape[0]):
        torch.mm(X[i,:,:], A.t(), out=Y[i,:,:])
    torch.mps.synchronize()
    stop = time.time()
    elapsed_times_mps.append(stop-start)
pyplot.loglog(Bs, [N**3/t for t in elapsed_times], marker="o", label='cpu')
pyplot.loglog(Bs, [N**3/t for t in elapsed_times_mps], marker="o", label='mps')
pyplot.xlabel('batch dimension (B)')
pyplot.ylabel('multiplies per second (FLOPS)')
pyplot.legend()
pyplot.title('FLOPS of NxN by NxB matmuls');
# batched matrix multiply: X has shape (32, 32, 32) and Y has shape (32, 15),
# so Y is broadcast against the leading batch dimension of X
X = torch.randn(32,32,32)
Y = torch.randn(32,15)
X @ Y;
# permute swaps the last two axes by changing strides (no data is copied),
# so the same expression now contracts over a different axis of the original X
X = X.permute(0,2,1)
X @ Y;
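A quick check of what that permute did (a sketch, not in the original notebook): permute only swaps strides, so X keeps the same shape and data but is no longer contiguous in memory.
# same shape, but the last two strides are swapped and the tensor is no longer contiguous
print(X.shape, X.stride(), X.is_contiguous())  # expected: torch.Size([32, 32, 32]) (1024, 1, 32) False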
Continuing from last time: The Llama-3 Model Demo.¶
model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name,torch_dtype='auto',local_files_only=True)
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
model
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((4096,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
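One way to connect this printout to the model's name (a sketch, not in the original notebook) is to count the parameters of the modules listed above, which should come out to roughly 8 billion for Llama-3.1-8B.
# total parameter count: embeddings + 32 decoder layers + lm_head, in billions
sum(p.numel() for p in model.parameters()) / 1e9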
Grouped-Query Attention¶
Reduces the size (and memory-traffic cost) of the KV cache by sharing each cached $K$ and $V$ head across multiple query heads.
As a result, the $Q$ projection above has a larger output (4096 features) than the $K$ and $V$ projections (1024 features each).
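A minimal sketch of the idea (not the Hugging Face implementation; the causal mask and rotary embeddings are omitted), using head counts that match the projections printed above: 32 query heads and 8 key/value heads with head dimension 128, so each cached $K$/$V$ head serves a group of 4 query heads.
# hedged sketch of grouped-query attention with Llama-3.1-8B's head counts
B, T, n_q_heads, n_kv_heads, head_dim = 1, 30, 32, 8, 128
Q = torch.randn(B, n_q_heads, T, head_dim)
K = torch.randn(B, n_kv_heads, T, head_dim)   # only 8 KV heads need to be cached
V = torch.randn(B, n_kv_heads, T, head_dim)
group = n_q_heads // n_kv_heads               # 4 query heads share each KV head
K_rep = K.repeat_interleave(group, dim=1)     # expand to 32 heads to match Q
V_rep = V.repeat_interleave(group, dim=1)
scores = (Q @ K_rep.transpose(-1, -2)) / math.sqrt(head_dim)
attn_out = torch.softmax(scores, dim=-1) @ V_rep   # shape (B, 32, T, head_dim)
attn_out.shape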
prompt = 'Please tell me all the ways Sun Wukong became immortal in Journey to the West.'
formatted_prompt = f'<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
formatted_prompt += 'A great'
input = tokenizer(formatted_prompt, return_tensors='pt')
formatted_prompt
'<|start_header_id|>user<|end_header_id|>\n\nPlease tell me all the ways Sun Wukong became immortal in Journey to the West.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA great'
output = model(**input)
output.past_key_values
DynamicCache(layers=[<transformers.cache_utils.DynamicLayer object at 0x400e62ec0>, <transformers.cache_utils.DynamicLayer object at 0x405a6b9d0>, <transformers.cache_utils.DynamicLayer object at 0x405a6a5f0>, <transformers.cache_utils.DynamicLayer object at 0x405a692d0>, <transformers.cache_utils.DynamicLayer object at 0x405a6bc40>, <transformers.cache_utils.DynamicLayer object at 0x405a69540>, <transformers.cache_utils.DynamicLayer object at 0x405a6ba60>, <transformers.cache_utils.DynamicLayer object at 0x405a6beb0>, <transformers.cache_utils.DynamicLayer object at 0x405a693f0>, <transformers.cache_utils.DynamicLayer object at 0x405a6baf0>, <transformers.cache_utils.DynamicLayer object at 0x405a69a50>, <transformers.cache_utils.DynamicLayer object at 0x405a6b970>, <transformers.cache_utils.DynamicLayer object at 0x405a6b580>, <transformers.cache_utils.DynamicLayer object at 0x405a6bfa0>, <transformers.cache_utils.DynamicLayer object at 0x405a69cc0>, <transformers.cache_utils.DynamicLayer object at 0x405a69000>, <transformers.cache_utils.DynamicLayer object at 0x405a69030>, <transformers.cache_utils.DynamicLayer object at 0x405a82830>, <transformers.cache_utils.DynamicLayer object at 0x405a811e0>, <transformers.cache_utils.DynamicLayer object at 0x405a823b0>, <transformers.cache_utils.DynamicLayer object at 0x405a022c0>, <transformers.cache_utils.DynamicLayer object at 0x405a00af0>, <transformers.cache_utils.DynamicLayer object at 0x405a00b50>, <transformers.cache_utils.DynamicLayer object at 0x405a00940>, <transformers.cache_utils.DynamicLayer object at 0x405a02110>, <transformers.cache_utils.DynamicLayer object at 0x405a00a90>, <transformers.cache_utils.DynamicLayer object at 0x405a00fa0>, <transformers.cache_utils.DynamicLayer object at 0x405a00c40>, <transformers.cache_utils.DynamicLayer object at 0x405a00ca0>, <transformers.cache_utils.DynamicLayer object at 0x405a01360>, <transformers.cache_utils.DynamicLayer object at 0x405a01720>, <transformers.cache_utils.DynamicLayer object at 0x405a016c0>])
# one vector of logits per prompt token, over the 128256-token vocabulary
output.logits.shape
torch.Size([1, 30, 128256])
# id of the highest-probability next token after the prompt
output.logits[0,-1].argmax()
tensor(3488)
# decoding a token id back to text: 2294 is the ' great' we appended to the prompt above
tokenizer.decode([2294])
' great'
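Before calling the built-in generate below, here is a hedged sketch (not the lecture's code, and it skips attention-mask details) of what greedy decoding does by hand: after the first forward pass, feed only the newest token each step and reuse the cached $K$/$V$ tensors from earlier steps.
# hedged sketch: greedy decoding five tokens, one at a time, reusing the KV cache
ids = input['input_ids']
past = None
for _ in range(5):
    out = model(input_ids=ids if past is None else ids[:, -1:], past_key_values=past)
    past = out.past_key_values                       # cached K/V for every token seen so far
    next_id = out.logits[0, -1].argmax().view(1, 1)  # most likely next token
    ids = torch.cat([ids, next_id], dim=1)
print(tokenizer.decode(ids[0, -5:]))  # just the 5 newly generated tokens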
output = model.generate(**input, max_length=256)
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
output
tensor([[128000, 128006, 882, 128007, 271, 5618, 3371, 757, 682,
279, 5627, 8219, 468, 3178, 647, 6244, 60214, 304,
43680, 311, 279, 4410, 13, 128009, 128006, 78191, 128007,
271, 32, 2294, 3488, 922, 832, 315, 8620, 17649,
596, 1455, 28530, 5885, 2268, 11439, 311, 279, 11670,
11775, 43680, 311, 279, 4410, 11, 8219, 468, 3178,
647, 11, 1101, 3967, 439, 279, 58937, 6342, 11,
6244, 60214, 1555, 264, 10824, 315, 813, 1866, 13736,
11, 24632, 4455, 11, 323, 279, 56650, 315, 5370,
409, 1385, 13, 5810, 527, 279, 1401, 5627, 568,
17427, 4998, 76052, 1473, 16, 13, 3146, 59204, 505,
264, 9998, 96618, 8219, 468, 3178, 647, 574, 9405,
505, 264, 9998, 304, 264, 16700, 26457, 11, 902,
574, 1071, 311, 387, 279, 1121, 315, 264, 24632,
25885, 13, 1115, 19018, 7342, 13160, 1461, 439, 264,
1694, 449, 24674, 13736, 323, 18000, 627, 17, 13,
3146, 43066, 287, 279, 21594, 35257, 44187, 96618, 1666,
264, 3995, 39803, 11, 8219, 468, 3178, 647, 11352,
264, 24632, 14098, 430, 11938, 1461, 4998, 76052, 13,
578, 14098, 11, 3967, 439, 279, 21594, 35257, 44187,
320, 19171, 2663, 279, 330, 38120, 12, 3968, 3372,
44187, 1, 477, 330, 38120, 12, 3968, 3372, 44187,
315, 279, 18288, 61269, 4063, 574, 1071, 311, 6782,
279, 28591, 315, 279, 4330, 5540, 25, 7732, 11,
4027, 11, 9578, 11, 9501, 11, 323, 3090, 627,
18, 13, 3146, 38030, 449, 279, 23860, 1132, 3804,
71, 32973, 96618, 8219, 468, 3178, 647, 574, 3010,
11352, 555, 279, 23860, 1132, 3804, 71, 32973, 11,
264, 47841, 7491, 889]])
tokenizer.decode(output[0])
'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nPlease tell me all the ways Sun Wukong became immortal in Journey to the West.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA great question about one of Chinese literature\'s most beloved characters!\n\nAccording to the classic novel Journey to the West, Sun Wukong, also known as the Monkey King, became immortal through a combination of his own powers, magical events, and the blessings of various deities. Here are the key ways he achieved immortality:\n\n1. **Born from a stone**: Sun Wukong was born from a stone in a mountain cave, which was said to be the result of a magical phenomenon. This unusual birth marked him as a being with extraordinary powers and abilities.\n2. **Consuming the Five Elements Fruit**: As a young monkey, Sun Wukong discovered a magical fruit that granted him immortality. The fruit, known as the Five Elements Fruit (also called the "Five-Flavor Fruit" or "Five-Flavor Fruit of the Golden Lotus"), was said to contain the essence of the five elements: wood, fire, earth, metal, and water.\n3. **Training with the Patriarch Subhuti**: Sun Wukong was later discovered by the Patriarch Subhuti, a Buddhist master who'