ericjwang.com
I haven’t had as much time to keep working on this project as I’d have liked. If I had a bit more time, I’d have added sections for flash attention, speculative decoding, and more aggressive quantization methods. May revisit eventually…
I recently purchased an RTX 4090 GPU. Frankly, I’m not sure why I did this, because for most of my personal and professional computing needs I’m perfectly satisfied with my M1 Macbook Pro, and I don’t really enjoy playing video games. If all I wanted was to mess around with deep models, I’d have been better served running a VM on the Lambda cloud, where for the same amount of money one can run a 4090-equivalent datacenter GPU for almost a year — enough time, I expect, for something even better to come out.
Still, one must justify the unjustifiable. I’ve always been interested in ML inference at the systems level, and owning a GPU is a good way to act on that interest. So for the past week, on the advice of my friend Horace He, I’ve been spending my leisure time gradually accelerating the inference of a basic GPT model.
I chose to begin with Andrej Karpathy’s nanoGPT, which is a concise but complete implementation of the GPT model. Within nanoGPT, the model is decomposed into the following PyTorch modules:
- GPT, containing:
  - Embeddings \(W_{te} \in \mathbb{R}^{d_m \times (n_v = 50257)}\).
  - Embeddings \(W_{pe} \in \mathbb{R}^{d_m \times (n_b = 2048)}\).
  - Dropout layer, applied to the input embedding \(W_{te}U + W_{pe}\).
  - Blocks.
  - Linear layer with parameter \(W_{te}^T\).
- Block, containing:
  - LayerNorm.
  - CausalSelfAttention.
  - LayerNorm.
  - MLP.
- CausalSelfAttention, containing:
  - Linear layer with parameter \(W \in \mathbb{R}^{3d_n [+1]\times d_n}\). This layer is split into three, and should be considered \((\begin{smallmatrix}W_Q & W_K & W_V & [b]\end{smallmatrix})\)^{1}.
  - Dropout layer, which is applied to the attention matrix \(QK^T\).
  - Linear layer with parameter \(W \in \mathbb{R}^{d_n\times d_n [+1]}\).
  - Dropout layer, applied to the output.
- MLP, containing:
  - Linear layer with parameter \(W \in \mathbb{R}^{4d_n \times d_n [+1]}\).
  - Linear layer with parameter \(W \in \mathbb{R}^{d_n \times 4d_n [+1]}\).
  - Dropout layer applied to the output.
- LayerNorm, a LayerNorm implementation that allows bias to be switched [on] and off.

For more details on this architecture pray consult the previous post.
My plan is to rewrite these modules gradually, specifically to accelerate inference rather than training. To guide this project and measure its success, I need a fixed benchmark. I propose the following:
- gpt2-xl, which has 1.5 billion parameters.

A reasonable ansatz is that the distribution of the time this task takes is approximately normal, so we can defer the question about sample sizes to later.
In a file called harness.py
we write something like:
import random

import numpy as np
import torch
from tqdm import tqdm

# GPT, decode, and get_moby_dick_tokens are defined elsewhere in the project.


def get_rand_input(batch_size, seq_len):
    # Draw batch_size random windows of seq_len tokens from Moby Dick.
    # Defined before sample() because it is used there as a default argument.
    md_tokens = get_moby_dick_tokens()  # memoized
    md_len = len(md_tokens)
    return np.array(
        [
            md_tokens[start_idx : start_idx + seq_len]
            for start_idx in (
                random.randint(0, md_len - seq_len)
                for _ in range(batch_size)
            )
        ]
    )


def sample(
    gpt: GPT,
    batch_size=1,
    prompt_tokens=256,
    sampled_tokens=256,
    get_rand_input=get_rand_input,
) -> float:
    # CUDA events time the GPU work itself rather than the Python call.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    u = get_rand_input(batch_size, prompt_tokens)
    cuda_u = torch.from_numpy(u).cuda()

    # run inference
    torch.cuda.synchronize()
    start.record()
    idx = gpt.generate(cuda_u, sampled_tokens)
    end.record()
    torch.cuda.synchronize()

    # visualize outputs
    idx = idx.to("cpu").numpy()
    for i in range(batch_size):
        print(
            decode(u[i]),
            "🩹",
            decode(idx[i][len(u[i]) :])
        )

    # elapsed_time is in milliseconds; return seconds
    return start.elapsed_time(end) / 1e3


def benchmark(
    gpt: GPT,
    batch_size=1,
    sample_size=1
) -> list[float]:
    # One wall-clock time (in seconds) per sampled task.
    return [
        sample(
            gpt,
            batch_size,
            prompt_tokens=256,
            sampled_tokens=256
        )
        for _ in tqdm(range(sample_size))
    ]
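Since benchmark returns one time per sample, a small helper can summarize those times under the normality ansatz above. This is a minimal sketch and not part of harness.py: summarize_times is a hypothetical name, the timings in the usage line are made up, and the 1.96 multiplier assumes a normal approximation to a 95% confidence interval.

```python
import math
import statistics


def summarize_times(times):
    """Return (mean, half-width of an approximate 95% CI), both in seconds."""
    mean = statistics.fmean(times)
    if len(times) < 2:
        # A single sample gives no estimate of the spread.
        return mean, float("nan")
    stderr = statistics.stdev(times) / math.sqrt(len(times))
    return mean, 1.96 * stderr  # normal approximation (assumption)


# Made-up timings, only to show the call shape.
mean, half_width = summarize_times([11.0, 11.2, 11.1, 11.3])
print(f"{mean:.2f}s +/- {half_width:.2f}s")
```

With sample_size=1, as in the default above, only the mean is meaningful; the interval needs at least two samples.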
When I run the full version of this code on gpt2-xl
with a prompt length of just 128,
the GPU issues forth a long croak of coil whine,
and spits out:
loading weights from pretrained gpt: gpt2-xl
number of parameters: 1557.61M
Moby Dick has 305318 tokens
== SAMPLED TEXT FOLLOWS ==
on her way north-eastward towards the island of Java;
a gentle air impelling her keel, so that in the surro
unding serenity her three tall tapering masts mildly w
aved to that languid breeze, as three mild palms on a
plain. And still, at wide intervals in the silvery nig
ht, the lonely, alluring jet would be seen.
But one transparent blue morning, when a stillness alm
ost preternatural spread over the sea, however unatten
ded with any stagnant calm; when the long burnished su
n-glade on the waters seemed a golden finger laid acro
ss them, enjoining some secrecy 🩹 and secrecy it did
not admit of; when those spectacles of azure were star
ing through the panes, without the least bloodlust, ar
ound her passing from star to oar, which she could mee
t as those she ran, visible to him; her lewness of her
shiny of sea-silken've is never mentioned by any, sav
e in some tenth case; And, after thus maintaining a r
emote silence at anchor, on daybreak she appears to r
eason on the outside!
At an instant, as she adjusts her masts, the deserted
hulk of a powerful ship pierces her leeward. Huge colu
mns of
== END OF SAMPLED TEXT ==
time: 11.12s
Too late do I now realize that I chose the wrong text with which to test model correctness. It is, unfortunately, well within the realm of possibility for Melville to have written of the Pequod’s lew shiny of sea-silken’ve — but even if correctness is indistinguishable from gibberish, we can be assured that any subsequent implementations of the model are correct as long as their final activations match nanoGPT’s.
A glance at nvtop
gives us the story on the inference above:
First, memory usage climbs to 6 GiB. Once the model is loaded, the GPU’s utilization shoots up to 99%. Then, once inference is finished, 11.12 seconds later, it drops back down to zero.
It’s reasonable to assume that we’re compute-bottlenecked right now,
because nvidia-smi
is telling us that GPU utilization is at 99%
and we have plenty of memory to spare.
After all,
1557.61M parameters of four bytes each makes 5.8 GiB,
which explains the memory utilization figure.
Doubling the batch size more or less doubles the time that inference takes,
suggesting that a single example already requires too many FLOPs to vectorize appropriately.
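The memory figure can be checked with a couple of lines of arithmetic. The sketch below only restates the numbers quoted above (1557.61M parameters, four bytes each); param_gib is a hypothetical helper name.

```python
def param_gib(n_params, bytes_per_param=4):
    """Memory taken by the parameters alone, in GiB (2**30 bytes)."""
    return n_params * bytes_per_param / 2**30


fp32_gib = param_gib(1557.61e6)
print(f"{fp32_gib:.1f} GiB")  # parameters only; activations are extra
```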
To figure out where we can squeeze out more compute efficiency, then, let’s review the GPT architecture. In the diagram below, which I’ve lovingly rendered with a modded-out implementation of mermaid, inputs are marked in green, parameters are marked in blue, matmuls are marked in red, and the output of a module is its unique topmost value.
%%{init: { "flowchart": {"useMaxWidth": true, "rankSpacing": 20} } }%%
flowchart BT
subgraph GPT
direction BT
classDef weight fill:#88f,stroke:#00a
classDef input fill:#8f8,stroke:#0a0
classDef matmul fill:#f88,stroke:#a00
Wte["W<sub>te</sub>"] --> WteU["W<sub>te</sub>U<sub><t</sub>"]
class WteU matmul
Wte --> Logits
class Wte weight
U["U<sub><t</sub>"] --> WteU
class U input
WteU --> Embedding["W<sub>te</sub>U<sub><t</sub> + W<sub>pe</sub> "]
Wpe["W<sub>pe</sub>"] ---> Embedding
class Wpe weight;
Embedding --> Block
subgraph Block["Block (x <i>n<sub>l</sub></i>)"]
direction BT
xb --> LayerNorm1[LayerNorm] --> CausalSelfAttention
xb[X] --> resid1["X = X + CSA(X)"]
class xb input
CausalSelfAttention --> resid1
subgraph CausalSelfAttention
direction BT
Wqkv["W<sub>qkv</sub>"] --> QKV["W<sub>qkv</sub>X"]
class Wqkv weight
xcsa[X] --> QKV
class QKV matmul
class xcsa input
QKV --> Q["Q"]
QKV --> K["K"]
QKV --> V["V"]
Q --> QK["QK<sup>T</sup>"]
K --> QK
class QK matmul
QK --> mQK["Mask(QK<sup>T</sup>)/√d<sub>m</sub>"] --> sQK["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)"]
sQK --> attn
V -----> attn["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)V"]
class attn matmul
end
resid1 --> LayerNorm2[LayerNorm] --> MLP
resid1 --> resid2["X = X + MLP(X)"]
MLP --> resid2
subgraph MLP
direction BT
xm[X] --> layer1[W<sub>1</sub>x]
class xm input
W1[W<sub>1</sub>] --> layer1
layer1 --> hlayer1["gelu(W<sub>1</sub>x)"]
class layer1 matmul
W2[W<sub>2</sub>] ----> layer2
hlayer1 --> layer2["W<sub>2</sub>gelu(W<sub>1</sub>x)"]
class layer2 matmul
class W1 weight
class W2 weight
end
end
Block --> Logits[W<sub>te</sub><sup>T</sup>Z]
class Logits matmul
Logits --> Multinomial["U<sub>t</sub> = Multinomial(exp(W<sub>te</sub><sup>T</sup>Z))"]
end
We care about matmuls here because all of the other nodes in the computation graph take \(O(mn)\) FLOPs, while naïve matrix multiplication \(\mathbb{R}^{m\times n} \times \mathbb{R}^{n \times p} \rightarrow \mathbb{R}^{m \times p}\) takes \(2mnp = O(mnp)\) FLOPs^{3}, which heuristically suggests that they should account for the largest share.
Here are our five matmuls, with \(t\) being the sequence length and all other variables as they are in the previous post:
The FLOP requirement of a single forward pass due to matmuls is therefore approximately
\[\begin{align}F(t) &= 6n_\ell td_m^2 + 4n_\ell t^2 d_m + 16n_\ell td_m^2 + 2n_vd_m \\&=22n_\ell td_m^2 + 4n_\ell t^2d_m + 2n_vd_m \\&=(4n_\ell d_m) t^2 + (22n_\ell d_m^2)t + 2n_vd_m. \end{align}\]The XL model has \(n_\ell = 48, d_m = 1600, n_v=50257\), so our polynomial simplifies to
\[F(t) \approx (3.1\times 10^5) t^2 + (2.7 \times 10^9)t + 1.6 \times 10^8.\]Here are some values of this function:
\(t\) | QKV | Attn | MLP | LM head | \(F(t)\) |
---|---|---|---|---|---|
\(256\) | \(1.89 \times 10^{11}\) | \(2.01 \times 10^{10}\) | \(5.03 \times 10^{11}\) | \(1.61 \times 10^{8}\) | \(7.12 \times 10^{11}\) |
\(512\) | \(3.77 \times 10^{11}\) | \(8.05 \times 10^{10}\) | \(1.01 \times 10^{12}\) | \(1.61 \times 10^{8}\) | \(1.46 \times 10^{12}\) |
\(1024\) | \(7.55 \times 10^{11}\) | \(3.22 \times 10^{11}\) | \(2.01 \times 10^{12}\) | \(1.61 \times 10^{8}\) | \(3.09 \times 10^{12}\) |
\(2048\) | \(1.51 \times 10^{12}\) | \(1.29 \times 10^{12}\) | \(4.03 \times 10^{12}\) | \(1.61 \times 10^{8}\) | \(6.83 \times 10^{12}\) |
We see that a single forward pass costs on the order of a trillion FLOPs, from about 0.7 TFLOPs at \(t = 256\) to about 6.8 TFLOPs at \(t = 2048\).
Because the nanoGPT autoregressive sampling code evaluates the forward pass for \(256 \leq t < 512\), the entire task should take at least \(\sum_{t = 256}^{511} F(t) \approx 2.77 \times 10^{14}\) FLOPs, or 277 TFLOPs. (The RTX 4090 spec says it has a peak FP32 TFLOPS of 82.6 on the boost clock, but I won’t pretend I’m using the GPU anywhere near optimally yet.)
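The polynomial \(F(t)\) and the sum over \(256 \leq t < 512\) are easy to evaluate directly. The sketch below simply transcribes the four terms of \(F(t)\) with the gpt2-xl constants quoted above; the function and key names are my own, and like the formula it ignores biases and all non-matmul work.

```python
N_LAYER, D_MODEL, N_VOCAB = 48, 1600, 50257  # gpt2-xl, as quoted above


def matmul_flops(t):
    """Per-term matmul FLOPs of one forward pass at sequence length t."""
    return {
        "qkv": 6 * N_LAYER * t * D_MODEL**2,
        "attn": 4 * N_LAYER * t**2 * D_MODEL,
        "mlp": 16 * N_LAYER * t * D_MODEL**2,
        "lm_head": 2 * N_VOCAB * D_MODEL,
    }


def forward_flops(t):
    """F(t): total matmul FLOPs of one forward pass."""
    return sum(matmul_flops(t).values())


for t in (256, 512, 1024, 2048):
    print(t, f"{forward_flops(t):.2e}")

# Uncached sampling runs one full forward pass for each t in [256, 512).
total = sum(forward_flops(t) for t in range(256, 512))
print(f"{total:.2e}")
```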
This suggests our first optimization: an easy one at the algorithmic level, which requires little knowledge of how the hardware works but slashes the FLOP complexity of sampling from \(\Theta(n^3)\) to \(\Theta(n^2)\) in the sequence length \(n\).
It is often said of causal self-attention that “later tokens in a sequence
do not affect the embeddings of earlier ones.”
In more concrete terms: if we let the output of the CausalSelfAttention
block
be the matrix \(A = \{a_{ij}\} \in \mathbb{R}^{t \times d_m}\),
we can compute each entry \(a_{ij}\) using just the values in
\(X_{\leq i}\in \mathbb{R}^{i \times d_m}\).
It is enlightening to examine why this is the case. Below is the architecture of causal self-attention:
%%{init: { "flowchart": {"useMaxWidth": true, "rankSpacing": 30, "nodeSpacing": 100} } }%%
flowchart BT
classDef weight fill:#88f,stroke:#00a
classDef input fill:#8f8,stroke:#0a0
classDef matmul fill:#f88,stroke:#a00
Wqkv["W<sub>qkv</sub>"] --> QKV["W<sub>qkv</sub>X"]
class Wqkv weight
xcsa[X] --> QKV
class QKV matmul
class xcsa input
QKV --> Q["Q"]
QKV --> K["K"]
QKV --> V["V"]
Q --> QK["QK<sup>T</sup>"]
K --> QK
class QK matmul
QK --> mQK["Mask(QK<sup>T</sup>)/√d<sub>m</sub>"] --> sQK["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)"]
sQK --> attn
V -----> attn["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)V"]
class attn matmul
Imagine passing in the two matrices \(X_{<t} = \left(\begin{smallmatrix}x_1^T \\ \vdots \\ x_{t - 1}^T \end{smallmatrix}\right) \in \mathbb{R}^{(t - 1) \times d_m}\) and \(X_{\leq t} = \left(\begin{smallmatrix}X_{<t} \\ x_{t}^T \end{smallmatrix}\right)\in \mathbb{R}^{t \times d_m}\). How does each node on the graph differ?
That is, each incremental sequence entry \(x_t\) adds the entry \(\mathrm{Softmax}(d^{-1/2}_mq_t^T K_{\leq t}^T)V_{\leq t}\) to the output attention matrix. If we can memoize \(K_{<t}\) and \(V_{<t}\) from a previous forward pass of the model, self-attention will only require the matrix multiplications \(W_{qkv}x_t\), \(q_t^T K_{\leq t}^T\), and \(\mathrm{Softmax}(d^{-1/2}_mq_t^T K_{\leq t}^T)V_{\leq t}\).
Thus, we have gone from \(4t^2d_m+ 6td_m^2 + O(d_m + t)\) to \(4td_m + 6d_m^2 + O(d_m)\) FLOPs in the incremental block. This technique has been referred to elsewhere as KV caching.
flowchart BT
classDef weight fill:#88f,stroke:#00a
classDef input fill:#8f8,stroke:#0a0
classDef matmul fill:#f88,stroke:#a00
subgraph CausalSelfAttention
direction BT
Wqkv["W<sub>qkv</sub>"] --> QKV["W<sub>qkv</sub>X"]
class Wqkv weight
xcsa[X] --> QKV
class QKV matmul
class xcsa input
QKV --> Q["Q<sub><t</sub>"]
QKV --> K["K<sub><t</sub>"]
QKV --> V["V<sub><t</sub>"]
Q --> QK["Q<sub><t</sub>K<sub><t</sub><sup>T</sup>"]
K --> QK
class QK matmul
QK --> mQK["Mask(Q<sub><t</sub>K<sub><t</sub><sup>T</sup>)"] --> sQK["Softmax(Mask(Q<sub><t</sub>K<sub><t</sub><sup>T</sup>)/√d<sub>m</sub>)"]
sQK --> attn
V -----> attn["Softmax(Mask(Q<sub><t</sub>K<sub><t</sub><sup>T</sup>)/√d<sub>m</sub>)V<sub><t</sub>"]
class attn matmul
end
subgraph IncrementalCausalSelfAttention
direction BT
_Wqkv["W<sub>qkv</sub>"] --> _QKV["W<sub>qkv</sub>x<sub>t</sub><sup>T</sup>"]
class _Wqkv weight
_xcsa[x] --> _QKV
class _QKV matmul
class _xcsa input
_QKV --> _Q["q<sub>t</sub>"]
_QKV --> _K["k<sub>t</sub>"]
_QKV --> _V["v<sub>t</sub>"]
_Q ---> _QK["q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>"]
_K --> __K
K --> __K["K<sub>≤t</sub>"] -->_QK
class _QK matmul
_QK --> _mQK["Mask(q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>)/√d<sub>m</sub>"] --> _sQK["Softmax(Mask(q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>)/√d<sub>m</sub>)"]
_sQK --> _attn
_V --> __V
V --> __V["V<sub>≤t</sub>"] -----> _attn["Softmax(Mask(q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>)/√d<sub>m</sub>)V<sub>≤t</sub>"]
class _attn matmul
end
A model that implements KV caching would run a first pass with the entire prompt to populate the cache, then a series of single-token passes that incrementally build on the cached values. There is no need to modify the feed-forward layer because it doesn’t involve interactions between multiple \(t\), but we should make sure to offset the positional embeddings \(W_{pe}\) accordingly.
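To make the equivalence concrete, here is a toy, single-head sketch in plain Python (no PyTorch, no batching, no dropout or bias) that checks incremental attention with a cache against full causal attention. It is not the code in implementations/memoized.py; every name in it (CachedAttention, full_attention, and so on) is hypothetical, and the weights and inputs are random.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project(W, x):
    """Matrix-vector product W @ x, with W stored as a list of rows."""
    return [dot(row, x) for row in W]


def attend(q, keys, values):
    """Softmax(q K^T / sqrt(d)) V for a single query vector."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    norm = sum(weights)
    return [
        sum(w * v[j] for w, v in zip(weights, values)) / norm
        for j in range(len(values[0]))
    ]


def full_attention(Wq, Wk, Wv, xs):
    """Recompute everything: row i attends to positions 0..i (causal mask)."""
    qs = [project(Wq, x) for x in xs]
    ks = [project(Wk, x) for x in xs]
    vs = [project(Wv, x) for x in xs]
    return [attend(qs[i], ks[: i + 1], vs[: i + 1]) for i in range(len(xs))]


class CachedAttention:
    """Keeps K and V from earlier steps instead of recomputing them."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.keys, self.values = [], []

    def step(self, x):
        # Only the new token is projected; older k and v come from the cache.
        self.keys.append(project(self.Wk, x))
        self.values.append(project(self.Wv, x))
        return attend(project(self.Wq, x), self.keys, self.values)


random.seed(0)
d, t = 4, 6
Wq, Wk, Wv = (
    [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3)
)
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(t)]

cached = CachedAttention(Wq, Wk, Wv)
incremental = [cached.step(x) for x in xs]
reference = full_attention(Wq, Wk, Wv, xs)
max_err = max(
    abs(a - b) for ra, rb in zip(incremental, reference) for a, b in zip(ra, rb)
)
print(max_err)  # largest elementwise difference between the two versions
```

Feeding the prompt through step one token at a time is the slow way to fill the cache; a real implementation would populate it with one full pass, as described above.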
We implement KV caching in implementations/memoized.py.
We override the base implementation of CausalSelfAttention
with a MemoizedCausalSelfAttention
that stores \(K_{<t}\) and \(V_{<t}\)
as buffers, and augment it with the ability to function
both as a normal self-attention layer, overwriting the cache,
and as an incremental self-attention layer, reading from the cache
and writing \(k_t\) and \(v_t\) back to it.
Then we adapt the generate
code in the base class
to use our incremental self-attention functionality,
and add our new MemoizedGPT
to our test harness.
The code is here.
Let’s review where we are FLOP-wise. Our current sampling procedure is to run one “full” forward pass with \(t_{min} = 256\) tokens to populate the \(KV\) cache, then \(t_{max} - (t_{min} + 1) = 255\) “incremental” forward passes to sample autoregressively. Thus:
This works out to a total of 0.12 TFLOPs, broken down as follows:
From this calculation, we can draw the following observations:
To the best of my knowledge, the only obvious inference optimization remaining that reduces FLOPs is to fuse the mask to the matmul in the “full” attention step for an 8.4% speedup. (I could think of a few architectural changes that could get us further, like sparse attention or removing the bias on the \(QKV\) network.) We’ll get to it in due time, but it’s time to start thinking about making better use of the specific features of our hardware.
KV caching gives us a marked improvement in performance. First, the time taken for the execution of a single task falls from 11 seconds to 7 seconds. More importantly, however, we are able to run much larger batches. Whereas the base model was unable to run with a concurrent batch size greater than 1 (as it was already running with an “effective” batch size of up to 512), the KV-cached model can concurrently execute dozens of tasks.
In fact, the limiting factor turns out not to be FLOPS but memory;
my implementation instantiates the buffers with size \(B \times n_h \times L \times (d_m / n_h)\)
for fixed constants \((B, L)\) representing the maximum batch size supported
and the maximum sequence length supported.
Because we have two such buffers for each of the \(n_\ell=48\) layers,
and nanoGPT uses 4-byte fp32 for everything,
the total memory occupied is \(2 n_\ell BLd_m\cdot 4 = 2\cdot 48\cdot 1600 \cdot 4 \cdot BL = 614400BL\)
bytes.
And if we set \(L\) to 512 (the minimum value required to execute a task),
we come to the unfortunate realization that our buffers take up
\(0.29B\) gibibytes and that we can only fit a batch size of 32 on the GPU.
Right?
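The buffer arithmetic above can be checked in a few lines. This sketch just restates the formula \(2 n_\ell B L d_m \cdot 4\) with the gpt2-xl constants; kv_cache_bytes is a hypothetical helper name, and it counts only the K and V buffers, not parameters or activations.

```python
def kv_cache_bytes(batch, max_len, n_layer=48, d_model=1600, bytes_per_value=4):
    """Bytes held by the K and V buffers across all layers."""
    # Two buffers (K and V) per layer, each batch x max_len x d_model values.
    return 2 * n_layer * batch * max_len * d_model * bytes_per_value


per_example_gib = kv_cache_bytes(1, 512) / 2**30
print(f"{per_example_gib:.2f} GiB per batch element at L = 512")
```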
Let’s say we could make all the values on the GPU take up half as much space as they did before. What should the new batch size be? Well, if the device could previously support a batch size of 32, we could argue heuristically that the parameters and buffers now take up just 12 GiB, and we can fit another 12 GiB / (0.15 GiB) = 80 buffers on the GPU for a total of 112.
We’d be wrong, though, because I’m also using the GPU to drive my 5K monitor, so it’s already got 1.75 GiB permanently set aside. Moreover, a larger batch size also increases the size of other tensors in the graph, which grow with the buffers. It all nets out to being able to support a batch size of 81, and I know this because actually halving the memory usage of all floats on the GPU to test this is the easy part:
# implementations/fp16.py, full text
from implementations.base import GPTConfig
from implementations.memoized import MemoizedGPT


class FP16MemoizedGPT(MemoizedGPT):
    def __init__(self, config: GPTConfig):
        super().__init__(config)
        self.half()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        gpt = super().from_pretrained(*args, **kwargs)
        gpt.half()
        return gpt
Note that the FP16MemoizedGPT
is able to execute 81 tasks in the ballpark of six seconds.
That’s already about 150 times more throughput than the base implementation!
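The throughput comparison is a one-line calculation. The sketch below uses only the figures quoted in this post (one task in 11.12 seconds for the base model, 81 tasks in roughly six seconds for the fp16 model), so the result is as approximate as the "ballpark" timing.

```python
base_throughput = 1 / 11.12  # tasks per second, base nanoGPT
fp16_throughput = 81 / 6.0   # tasks per second, taking "six seconds" literally
speedup = fp16_throughput / base_throughput
print(f"about {speedup:.0f}x")
```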
A natural question is whether we can make the jump from fp16 to fp8. Theoretically, this would be amazing for two reasons:
PyTorch doesn’t support fp8 yet, but Nvidia provides a Transformer Engine library^{6} that
effectively acts as an fp8 extension to PyTorch.
Transformer Engine is theoretically compatible with the Hopper and Ada GPU architectures,
and it contains drop-in replacements for torch.nn.Linear and torch.nn.LayerNorm,
as well as higher-level operators like LayerNormMLP and TransformerLayer.
Installing it is a little tricky, as it needs to be built alongside CUDA 11.8,
which requires gcc and g++ version 11, which need to be installed separately and symlinked.
And when the whole thing is all set up and ready, we receive the final disappointment: despite having promised fp8 inference on Ada,
Nvidia is only delivering it in CUDA 12.1 in Q2.
How horrible! We’ll have to quantize to int8 instead.
1. Note that I follow the convention \(y = W(\begin{smallmatrix}x \\ [1]\end{smallmatrix})\) rather than \(y = xW + b\).
2. The UTF-8 edition from Project Gutenberg has 305k tokens if we strip out the metadata, the leading spaces, and the intra-paragraph newlines. The newline transformation is critical; we need to format the book to resemble OpenWebText data to get the best results.
3. Half are multiplications, the other half are additions.
4. I haven’t included the bias because I’m lazy.
5. Let’s say that rendering my desktop takes 1.75 GiB. Then 22.75 GiB remain for CUDA, consisting of parameters, buffers, and activations. The buffers and activations scale with the batch size and the weights do not, so if we let \(P\) be the total memory taken by the fp32 parameters and \(BA\) the memory taken by the fp32 buffers and activations, we have \(P + 32BA \approx P/2 + 81BA/2 \approx 22.5\), from which we may conclude that each batch element adds about 0.45 GiB of memory usage in the fp32 regime and about 0.11 GiB in the fp8 regime, for a theoretical batch size of around 136.
6. Some marketing executive at Nvidia decided to pitch the Transformer Engine as a hardware feature that comes “bundled” with the Hopper architecture (an odd sell, as the Transformer Engine now comes “bundled” with 4000-series consumer GPUs as well).