AI 模型的计算与存储

从矩阵乘法到 FLOPs、MFU、显存——以 GPT-2 small (124M) 为例

0引子:两个数字决定一切

训练一个 AI 大模型,本质上是反复做巨大的矩阵乘法。"巨大"到什么程度?又"反复"到什么程度?决定一台 GPU 能不能跑得动、跑得多快,几乎只看两个数字:

读完这份文档,你能回答:

我们用一个最经典的小模型做贯穿全文的算例:OpenAI 公开的 GPT-2 small(HuggingFace 上的 openai-community/gpt2)。它麻雀虽小,五脏俱全。

GPT-2 small · 名片卡
层数 L12隐藏维度 d768
注意力头数 n_head12每头维度 d_head64
FFN 中间维 d_ff3072 (=4d)最大上下文 s_max1024
词表大小 V50 257参数总量 N≈ 124 M

1矩阵乘法 ABC

所有的"AI 算力",归根结底都是矩阵乘法的算力。先把这件事拆开看。

1.1 形状变换

两个矩阵 $A$ 和 $B$ 相乘,规则只有一条:左边的列数要等于右边的行数。结果矩阵的形状由两边"外侧"维度决定:

$$A_{\,M\times K} \;\cdot\; B_{\,K\times N} \;=\; C_{\,M\times N}$$

中间那个 $K$ 维度"被消掉"了——可以想象成两块乐高拼合时被压扁的那条边

A M K × B K N = C M N K 维"被消掉"
图 ① 矩阵乘法的形状变换:$(M\times K)\cdot(K\times N)\to(M\times N)$。

1.2 运算次数

$C$ 里面有 $M\times N$ 个格子。每一个格子怎么来?它等于 $A$ 的一行($K$ 个数)和 $B$ 的一列($K$ 个数)做点积

所以填满整个 $C$ 矩阵,总运算次数是:

总 FLOPs $\;\approx\; 2\,M\,N\,K$

动手算一遍:GPT-2 的 FFN 升维层
GPT-2 的每个 Transformer 块里都有一个 "FFN 升维":把 768 维的向量映射成 3072 维。 仅仅"一个 token、一个子模块、一次运算",就已经接近 500 万次乘加。

2Attention 与 MoE

Transformer 之所以叫 Transformer,核心就在 Self-Attention。MoE(Mixture of Experts)则是近年大模型省算力的主流套路。这一节只给最小化的公式 + 一个直观图。

2.1 Self-Attention

给定输入序列 $X \in \mathbb{R}^{s\times d}$($s$ 个 token,每个 $d$ 维),先用三组权重把它投影成 Query、Key、Value:

$$Q = X W_Q,\quad K = X W_K,\quad V = X W_V$$

然后做一件神奇的事:

$$\text{Attention}(Q,K,V) \;=\; \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right)\,V$$

X s × d Q K V 三个 s × d QKᵀ / √d s × s ↑ 长文罪魁 softmax 注意力权重 乘以 V Output s × d
图 ② Self-Attention 数据流。$QK^\top$ 是一个 $s\times s$ 的方阵——它是后面"长文爆显存"的源头。

2.2 Mixture of Experts (MoE)

朴素 Transformer 里,每个 token 都要走完整个 FFN。MoE 的核心想法是:多准备几套 FFN,但每个 token 只挑其中 top-k 个用

$$y \;=\; \sum_{i \,\in\, \text{TopK}(G(x))} G(x)_i \cdot E_i(x)$$

$E_i$ 是第 $i$ 个"专家"(结构上就是一个独立 FFN),$G$ 是一个小型路由器(router),输出一个分数向量,挑分数最高的 $k$ 个专家来算。

token x Router G top-k 选择 E₁ ★ E₂ E₃ E₄ ★ E₅ E₆ E₇ E₈ 加权求和 output y ★ = 被激活的专家,其余 6 个专家本次不参与运算
图 ③ MoE 示意。8 个专家、top-2 激活:参数量 ×8,但每 token FLOPs 只 ×2。

3FLOPs:一次训练究竟做多少次运算

第 1 节已经给了核心结论——每次矩阵乘 $\approx 2MNK$ FLOPs。Transformer 不过是一堆矩阵乘叠起来,下面把它一笔一笔加起来。

3.1 一个 Transformer block,一次前向,每个 token 多少 FLOPs?

记 $d$ 是隐藏维度(GPT-2 small 里 $d=768$),$s$ 是当前序列长度。我们把一个 block 的子模块逐个列出来(按"一个 token 一次过"的视角):

子模块形状FLOPs / token
Q、K、V 三个投影$(1\times d)\cdot(d\times d)\;\times 3$$6\,d^2$
Attention 打分 $QK^\top$$(1\times d)\cdot(d\times s)$$2\,s\,d$
Attention 加权 $\cdot V$$(1\times s)\cdot(s\times d)$$2\,s\,d$
输出投影$(1\times d)\cdot(d\times d)$$2\,d^2$
FFN 升维 ($d\to 4d$)$(1\times d)\cdot(d\times 4d)$$8\,d^2$
FFN 降维 ($4d\to d$)$(1\times 4d)\cdot(4d\times d)$$8\,d^2$
合计$24\,d^2 + 4\,s\,d$
一个 token · 一个 block · forward 的 FLOPs 构成(d=768, s=1024) QKV proj3.5M QKᵀ1.6M ·V1.6M out1.2M FFN ↑ (d→4d)4.7M FFN ↓ (4d→d)4.7M FFN 占了 55%
图 ④ 一个 block 一个 token 的 forward FLOPs 拆解。蓝=投影类,橙=attention 类,绿=FFN 类。FFN 几乎占了一半多。

3.2 全模型口诀:训练一步 $\approx 6\,N\,T$

把所有 block 加起来,再算上 embedding 等小头,工业界用一个非常顺口的近似式:

训练一步 FLOPs $\;\approx\; 6\,N\,T$

其中 $N$ 是模型参数量,$T$ 是这一步消化的 token 数(= batch × seq)。

动手算一遍:GPT-2 small 训练一步多少 FLOPs?
设 batch = 1,seq = 1024,所以 $T = 1024$: 这意味着,一张 H100(BF16 峰值 989 TFLOPs/s)理论上 不到 1 毫秒就能跑完一步——但实际远做不到,理由见下一节。

4MFU:你的 GPU 用得有多满

厂商在 spec sheet 上写的 "H100 989 TFLOPs" 只是理论峰值。真实训练里,由于 kernel 启动开销、内存搬运、通信、softmax 这种非 GEMM 算子等等,你能达到的实际 FLOPs/s 永远低于这个数。两者之比,就叫做 MFU(Model FLOPs Utilization)

$$\text{MFU} \;=\; \frac{\text{实际达成的 FLOPs/s}}{\text{硬件理论峰值 FLOPs/s}}$$

动手算一遍:GPT-2 small × H100,MFU 是多少?

5存储:Model States 与 Activation

前面四节都在算"算力",但真正经常逼得人换卡、改架构的,是显存。训练时显存分两大块:

  1. Model States:模型权重、梯度、优化器状态。大小不随序列长度变化
  2. Activation:前向时为了反向传播留下的"草稿纸"。随 batch × seq 线性甚至平方增长

5.1 Model States:训练为什么比推理贵 8 倍

现代主流配方是混合精度训练 + Adam 优化器。逐项算每个参数要占多少字节:

项目精度字节 / 参数
权重副本(前向用)FP16/BF162
权重主副本(更新用)FP324
梯度FP16/BF162
Adam 一阶动量 $m$FP324
Adam 二阶动量 $v$FP324
合计16 字节
每个参数训练时占 16 字节 = 推理(2 字节)的 8 倍 W(fp16) 2 B W(fp32 主) 4 B grad 2 B Adam m 4 B Adam v 4 B
图 ⑤ 混合精度 + Adam 下,每个参数训练时需要存 5 份东西,共 16 字节。
GPT-2 small 的 model states
这 2 GB 是"地板"——还没算任何中间结果,序列长度也还没出场。

5.2 Activation:那张随长文平方膨胀的草稿纸

反向传播要用到前向时的中间结果。比如 softmax 反传要知道"当时输出是多少"。这些中间值都得在显存里留着,直到对应的反向阶段把它们消耗掉。

回到第 2 节图②里那个 $s\times s$ 的 attention 矩阵——它就是 activation 显存的主角。每一层、每个 head 都要存一份用于反向传播,所以总字节数大约是:

$$\text{Attention activation} \;\approx\; L \cdot n_\text{head} \cdot b \cdot s^2 \cdot c \cdot \text{bytes}$$

其中 $c\approx 10$ 是工程经验系数(softmax 前的分数、softmax 后的权重、dropout mask 等都得各存一份)。注意里面那个 $s^2$——它是关键

带入 GPT-2 small($L=12,\,n_\text{head}=12$,batch $b=1$,FP16):

序列 $s$Attention activation相对于上一行对照 model states (2 GB)
512 ≈ 0.4 GB 1/5
1024 ≈ 1.5 GB × 4 3/4
2048 ≈ 6 GB × 4
4096 ≈ 24 GB × 4 12×
8192 ≈ 97 GB × 4 48× ⚠ 撑爆 H100

把刻度拉到现代大模型的真实工作区间——从 64K 一路推到 1M token——光是 attention 矩阵自己的格子数(单层单头)就长成这样:

Attention 矩阵的格子数 vs 序列长度(对数纵轴 · 单层 · 单头) 1 B 10 B 100 B 1 T 格子数(log) 4.3 B 64 K 17 B 128 K 69 B 256 K 275 B 512 K 1.1 T 1 M ×4 ×4 ×4 ×4 序列长度 s(每翻一倍 →) Attention memory ∝ s² · 翻一倍 token,矩阵翻 4 倍格子
图 ⑥ 长上下文场景下 attention 矩阵的格子数。从 64K 到 1M 翻 16 倍 token,对应 256 倍格子(4.3 B → 1.1 T)。乘上层数 L、头数 n_head、以及每格 ~10 字节,就是它在显存里真实占用的体量。

为什么"翻一倍 token = 4 倍格子"?换个最简单的玩具尺度(4 vs 32 tokens)直观体会一下:

Attention memory 正比于 N²:token 数 ×8,attention 矩阵格子数 ×64
图 ⑦ 把 token 数从 4 翻到 32(×8 倍),attention 矩阵的格子就从 16 个暴涨到 1024 个(×64 倍)。这就是图⑥里"×4"反复出现的直观来源。

A附录:GPT-2 small 速查表

项目数值
参数总量 $N$≈ 124 M
Embedding (token + position)≈ 39 M
每层 Attention (QKV + 输出投影)≈ 2.36 M × 12 = 28 M
每层 FFN (升维 + 降维)≈ 4.72 M × 12 = 57 M
训练一步 FLOPs (batch 1, seq 1024)≈ 0.76 TFLOPs (6NT)
Model states (fp16+fp32 master+Adam)≈ 2.0 GB
推理显存 (fp16 only)≈ 248 MB
Activation @ seq=1024, batch=1≈ 2 GB

本文做了哪些简化