AI 模型的计算与存储
0引子:两个数字决定一切
训练一个 AI 大模型,本质上是反复做巨大的矩阵乘法。"巨大"到什么程度?又"反复"到什么程度?决定一台 GPU 能不能跑得动、跑得多快,几乎只看两个数字:
- 计算量(FLOPs):一共要做多少次乘加?硬件能跑多快?两者之比就是 MFU。
- 存储量:模型权重、优化器状态、梯度、还有那个会随着句子变长爆炸的 activation,加起来要多少 GB?
读完这份文档,你能回答:
我们用一个最经典的小模型做贯穿全文的算例:OpenAI 公开的 GPT-2 small(HuggingFace 上的 openai-community/gpt2)。它麻雀虽小,五脏俱全。
| 层数 L | 12 | 隐藏维度 d | 768 |
| 注意力头数 n_head | 12 | 每头维度 d_head | 64 |
| FFN 中间维 d_ff | 3072 (=4d) | 最大上下文 s_max | 1024 |
| 词表大小 V | 50 257 | 参数总量 N | ≈ 124 M |
1矩阵乘法 ABC
所有的"AI 算力",归根结底都是矩阵乘法的算力。先把这件事拆开看。
1.1 形状变换
两个矩阵 $A$ 和 $B$ 相乘,规则只有一条:左边的列数要等于右边的行数。结果矩阵的形状由两边"外侧"维度决定:
$$A_{\,M\times K} \;\cdot\; B_{\,K\times N} \;=\; C_{\,M\times N}$$
中间那个 $K$ 维度"被消掉"了——可以想象成两块乐高拼合时被压扁的那条边。
1.2 运算次数
$C$ 里面有 $M\times N$ 个格子。每一个格子怎么来?它等于 $A$ 的一行($K$ 个数)和 $B$ 的一列($K$ 个数)做点积:
- $K$ 次乘法
- $K-1$ 次加法(把 $K$ 个数加起来要加 $K-1$ 次)
- 合计 $\approx 2K$ 次运算
所以填满整个 $C$ 矩阵,总运算次数是:
总 FLOPs $\;\approx\; 2\,M\,N\,K$
- 输入一个 token:$x \in \mathbb{R}^{1\times 768}$
- 权重:$W \in \mathbb{R}^{768\times 3072}$
- 输出:$1\times 3072$
- FLOPs $= 2 \cdot 1 \cdot 3072 \cdot 768 \approx$ 4.7 M
2Attention 与 MoE
Transformer 之所以叫 Transformer,核心就在 Self-Attention。MoE(Mixture of Experts)则是近年大模型省算力的主流套路。这一节只给最小化的公式 + 一个直观图。
2.1 Self-Attention
给定输入序列 $X \in \mathbb{R}^{s\times d}$($s$ 个 token,每个 $d$ 维),先用三组权重把它投影成 Query、Key、Value:
$$Q = X W_Q,\quad K = X W_K,\quad V = X W_V$$
然后做一件神奇的事:
$$\text{Attention}(Q,K,V) \;=\; \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right)\,V$$
2.2 Mixture of Experts (MoE)
朴素 Transformer 里,每个 token 都要走完整个 FFN。MoE 的核心想法是:多准备几套 FFN,但每个 token 只挑其中 top-k 个用。
$$y \;=\; \sum_{i \,\in\, \text{TopK}(G(x))} G(x)_i \cdot E_i(x)$$
$E_i$ 是第 $i$ 个"专家"(结构上就是一个独立 FFN),$G$ 是一个小型路由器(router),输出一个分数向量,挑分数最高的 $k$ 个专家来算。
3FLOPs:一次训练究竟做多少次运算
第 1 节已经给了核心结论——每次矩阵乘 $\approx 2MNK$ FLOPs。Transformer 不过是一堆矩阵乘叠起来,下面把它一笔一笔加起来。
3.1 一个 Transformer block,一次前向,每个 token 多少 FLOPs?
记 $d$ 是隐藏维度(GPT-2 small 里 $d=768$),$s$ 是当前序列长度。我们把一个 block 的子模块逐个列出来(按"一个 token 一次过"的视角):
| 子模块 | 形状 | FLOPs / token |
|---|---|---|
| Q、K、V 三个投影 | $(1\times d)\cdot(d\times d)\;\times 3$ | $6\,d^2$ |
| Attention 打分 $QK^\top$ | $(1\times d)\cdot(d\times s)$ | $2\,s\,d$ |
| Attention 加权 $\cdot V$ | $(1\times s)\cdot(s\times d)$ | $2\,s\,d$ |
| 输出投影 | $(1\times d)\cdot(d\times d)$ | $2\,d^2$ |
| FFN 升维 ($d\to 4d$) | $(1\times d)\cdot(d\times 4d)$ | $8\,d^2$ |
| FFN 降维 ($4d\to d$) | $(1\times 4d)\cdot(4d\times d)$ | $8\,d^2$ |
| 合计 | $24\,d^2 + 4\,s\,d$ | |
3.2 全模型口诀:训练一步 $\approx 6\,N\,T$
把所有 block 加起来,再算上 embedding 等小头,工业界用一个非常顺口的近似式:
训练一步 FLOPs $\;\approx\; 6\,N\,T$
其中 $N$ 是模型参数量,$T$ 是这一步消化的 token 数(= batch × seq)。
- Forward:$2 N T \;=\; 2 \times 124\text{M} \times 1024 \;\approx$ 254 GFLOPs
- Backward:$4 N T \;\approx$ 508 GFLOPs
- 合计一步:762 GFLOPs ≈ 0.76 TFLOPs
4MFU:你的 GPU 用得有多满
厂商在 spec sheet 上写的 "H100 989 TFLOPs" 只是理论峰值。真实训练里,由于 kernel 启动开销、内存搬运、通信、softmax 这种非 GEMM 算子等等,你能达到的实际 FLOPs/s 永远低于这个数。两者之比,就叫做 MFU(Model FLOPs Utilization):
$$\text{MFU} \;=\; \frac{\text{实际达成的 FLOPs/s}}{\text{硬件理论峰值 FLOPs/s}}$$
- 配置:batch = 8,seq = 1024 → 每步 $T = 8\,192$ token
- 每步 FLOPs $\approx 6 \times 124\text{M} \times 8\,192 \approx 6.1 \times 10^{12} =$ 6.1 TFLOPs
- 假设实测一步耗时 60 ms,那么实际吞吐 $= 6.1 / 0.06 \approx$ 102 TFLOPs/s
- H100 SXM BF16 峰值 = 989 TFLOPs/s
- $\text{MFU} = 102 / 989 \approx$ 10.3 %
5存储:Model States 与 Activation
前面四节都在算"算力",但真正经常逼得人换卡、改架构的,是显存。训练时显存分两大块:
- Model States:模型权重、梯度、优化器状态。大小不随序列长度变化。
- Activation:前向时为了反向传播留下的"草稿纸"。随 batch × seq 线性甚至平方增长。
5.1 Model States:训练为什么比推理贵 8 倍
现代主流配方是混合精度训练 + Adam 优化器。逐项算每个参数要占多少字节:
| 项目 | 精度 | 字节 / 参数 |
|---|---|---|
| 权重副本(前向用) | FP16/BF16 | 2 |
| 权重主副本(更新用) | FP32 | 4 |
| 梯度 | FP16/BF16 | 2 |
| Adam 一阶动量 $m$ | FP32 | 4 |
| Adam 二阶动量 $v$ | FP32 | 4 |
| 合计 | 16 字节 | |
- 训练(mixed-precision + Adam):$124\text{M} \times 16\text{ B} \approx$ 2.0 GB
- 推理(纯 FP16):$124\text{M} \times 2\text{ B} \approx$ 248 MB
5.2 Activation:那张随长文平方膨胀的草稿纸
反向传播要用到前向时的中间结果。比如 softmax 反传要知道"当时输出是多少"。这些中间值都得在显存里留着,直到对应的反向阶段把它们消耗掉。
回到第 2 节图②里那个 $s\times s$ 的 attention 矩阵——它就是 activation 显存的主角。每一层、每个 head 都要存一份用于反向传播,所以总字节数大约是:
$$\text{Attention activation} \;\approx\; L \cdot n_\text{head} \cdot b \cdot s^2 \cdot c \cdot \text{bytes}$$
其中 $c\approx 10$ 是工程经验系数(softmax 前的分数、softmax 后的权重、dropout mask 等都得各存一份)。注意里面那个 $s^2$——它是关键。
带入 GPT-2 small($L=12,\,n_\text{head}=12$,batch $b=1$,FP16):
| 序列 $s$ | Attention activation | 相对于上一行 | 对照 model states (2 GB) |
|---|---|---|---|
| 512 | ≈ 0.4 GB | — | 1/5 |
| 1024 | ≈ 1.5 GB | × 4 | 3/4 |
| 2048 | ≈ 6 GB | × 4 | 3× |
| 4096 | ≈ 24 GB | × 4 | 12× |
| 8192 | ≈ 97 GB | × 4 | 48× ⚠ 撑爆 H100 |
把刻度拉到现代大模型的真实工作区间——从 64K 一路推到 1M token——光是 attention 矩阵自己的格子数(单层单头)就长成这样:
为什么"翻一倍 token = 4 倍格子"?换个最简单的玩具尺度(4 vs 32 tokens)直观体会一下:
A附录:GPT-2 small 速查表
| 项目 | 数值 |
|---|---|
| 参数总量 $N$ | ≈ 124 M |
| Embedding (token + position) | ≈ 39 M |
| 每层 Attention (QKV + 输出投影) | ≈ 2.36 M × 12 = 28 M |
| 每层 FFN (升维 + 降维) | ≈ 4.72 M × 12 = 57 M |
| 训练一步 FLOPs (batch 1, seq 1024) | ≈ 0.76 TFLOPs (6NT) |
| Model states (fp16+fp32 master+Adam) | ≈ 2.0 GB |
| 推理显存 (fp16 only) | ≈ 248 MB |
| Activation @ seq=1024, batch=1 | ≈ 2 GB |
本文做了哪些简化
- 忽略 LayerNorm、bias、dropout 等 $O(d)$ 级别的小项。
- "$6N$ 经验式"假设 attention FLOPs 远小于线性层——上下文较短($s\!\ll\!6d$)时成立。
- Activation 估算用的是 Megatron 论文中"无 activation checkpointing、无 Flash Attention"的标准实现。引入 Flash Attention 后,attention 的 $s^2$ 中间矩阵不再实例化,平方项几乎消失。
- 未讨论 ZeRO 分片、张量并行、序列并行——它们改变 model states / activation 在多卡之间的分布,但不改变本文给出的"单卡视角"的总量公式。