上一篇我们学会了 Self-Attention,这一篇我们来学习如何把它玩坏。
前言:Attention 很好,但是...
上一篇文章,我们搞懂了 Transformer 的核心——Self-Attention。
它很强大,但有一个致命的问题:太慢了,太费显存了。
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$
看到那个 $QK^T$ 了吗?这是一个 $n \times n$ 的矩阵,$n$ 是序列长度。
这意味着:
- 时间复杂度:$O(n^2 \cdot d)$
- 空间复杂度:$O(n^2)$
当 $n = 1000$ 时,需要存储 100 万个注意力分数。
当 $n = 100000$ 时,需要存储 100 亿 个注意力分数。
这就是为什么早期的 GPT 只能处理 2048 个 token,而你想让它读完一本书?做梦。
过去几年,研究者们想尽办法优化 Attention,催生了一大堆变体。
今天,我们来盘点这些"注意力的七十二变"。
一、问题诊断:Attention 到底慢在哪?
在开始优化之前,我们需要先搞清楚瓶颈在哪。
1.1 计算瓶颈 vs 内存瓶颈
很多人以为 Attention 慢是因为计算量大。错!
现代 GPU 的计算能力很强,真正的瓶颈是 内存带宽。
容量大,带宽低
~2TB/s"] SRAM["SRAM (片上缓存)
容量小,带宽高
~19TB/s"] Compute["计算单元
312 TFLOPS (A100)"] end HBM <-->|"瓶颈!"| SRAM SRAM <--> Compute style HBM fill:#ff6b6b
A100 GPU 的数据:
- 计算能力:312 TFLOPS (FP16)
- HBM 带宽:2 TB/s
- SRAM 容量:20 MB(很小!)
这意味着什么?
计算 1 TFLOP 需要的时间:$\frac{1}{312} \approx 3.2$ ms
从 HBM 读取 1 TB 需要的时间:$\frac{1}{2} = 500$ ms
差了 150 倍!
所以,如果你的算法需要频繁地在 HBM 和 SRAM 之间搬运数据,计算单元就会闲着等数据——这就是 内存带宽瓶颈(Memory-bound)。
1.2 标准 Attention 的内存访问
让我们看看标准 Attention 是怎么执行的:
# 标准 Attention 的执行流程(伪代码)
def standard_attention(Q, K, V):
# Step 1: 计算 QK^T,写回 HBM
S = Q @ K.T # 从 HBM 读 Q, K,写 S 到 HBM
# Step 2: 计算 softmax,写回 HBM
P = softmax(S) # 从 HBM 读 S,写 P 到 HBM
# Step 3: 计算 PV,写回 HBM
O = P @ V # 从 HBM 读 P, V,写 O 到 HBM
return O问题在于:中间结果 S 和 P 都是 $n \times n$ 的大矩阵,必须存到 HBM 里。
这导致了大量的内存读写,而内存带宽是瓶颈。
1.3 优化方向
基于以上分析,Attention 的优化有三个方向:
接下来,我们逐一介绍。
二、稀疏注意力:不是所有位置都重要
2.1 核心思想
标准 Attention 让每个 token 都关注所有其他 token。但真的有必要吗?
想想你读文章的时候:
- 读到代词"它"时,你会回看最近的名词
- 读到总结段落时,你会回看开头
- 大部分时候,你主要关注附近的内容
稀疏注意力的思路就是:只计算"重要"位置的注意力,忽略其他位置。
2.2 Local Attention(滑动窗口)
最简单的稀疏模式:只关注附近的 token。
标准 Attention(每个位置关注所有位置):
位置 0: [1, 1, 1, 1, 1, 1, 1, 1]
位置 1: [1, 1, 1, 1, 1, 1, 1, 1]
位置 2: [1, 1, 1, 1, 1, 1, 1, 1]
...
Local Attention(窗口大小=3):
位置 0: [1, 1, 0, 0, 0, 0, 0, 0]
位置 1: [1, 1, 1, 0, 0, 0, 0, 0]
位置 2: [0, 1, 1, 1, 0, 0, 0, 0]
位置 3: [0, 0, 1, 1, 1, 0, 0, 0]
...复杂度:从 $O(n^2)$ 降到 $O(n \cdot w)$,其中 $w$ 是窗口大小。
代表模型:Mistral 使用 Sliding Window Attention,窗口大小 4096。
def sliding_window_attention(Q, K, V, window_size):
"""滑动窗口注意力"""
seq_len = Q.shape[1]
# 创建滑动窗口 mask
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
mask[i, start:end] = 1
# 计算注意力(只在 mask=1 的位置)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
return torch.matmul(attn_weights, V)2.3 Longformer:混合稀疏模式
纯 Local Attention 的问题:无法捕捉长距离依赖。
Longformer 的解决方案:Local + Global 混合。
每个 token 关注窗口内"] Global["Global Attention
特殊 token 关注所有"] end Local --> Mix["混合使用"] Global --> Mix
具体来说:
- 大部分 token:使用滑动窗口(Local)
- 特殊 token(如 [CLS]、段落开头):关注所有位置(Global)
[CLS] token: [1, 1, 1, 1, 1, 1, 1, 1] ← Global
普通 token: [1, 1, 1, 0, 0, 0, 0, 0] ← Local
普通 token: [0, 1, 1, 1, 0, 0, 0, 0] ← Local
段落开头: [1, 1, 1, 1, 1, 1, 1, 1] ← Global2.4 BigBird:加入随机连接
BigBird 在 Longformer 基础上又加了一个:随机注意力。
随机关注几个位置"] L["Local
关注邻近位置"] G["Global
特殊 token"] end R --> Final["完整注意力"] L --> Final G --> Final
随机连接的好处:即使两个 token 很远,也有概率"直接对话",信息传播更快。
2.5 稀疏注意力的效果
| 模型 | 最大长度 | 复杂度 | 方法 |
|---|---|---|---|
| 标准 Transformer | 512-2K | $O(n^2)$ | Full |
| Longformer | 4K-16K | $O(n)$ | Local + Global |
| BigBird | 4K-16K | $O(n)$ | Local + Global + Random |
| Mistral | 32K | $O(n \cdot w)$ | Sliding Window |
问题:稀疏注意力虽然快,但会损失一些信息。现代大模型更倾向于用 Flash Attention 来加速,而不是牺牲注意力的完整性。
三、线性注意力:从 $O(n^2)$ 到 $O(n)$
3.1 核心思想
标准 Attention 的复杂度来自于 $QK^T$ 这个矩阵乘法:
$$ \text{Attention} = \text{softmax}(QK^T) V $$
线性注意力的想法:能不能避免显式计算 $QK^T$?
关键观察:如果去掉 softmax,可以改变计算顺序!
$$ \text{Linear Attention} = Q(K^T V) $$
矩阵乘法是结合的:
- $(\text{softmax}(QK^T)) V$:先算 $n \times n$,再乘 $V$,复杂度 $O(n^2 d)$
- $Q(K^T V)$:先算 $K^T V$($d \times d$),再乘 $Q$,复杂度 $O(nd^2)$
当 $n >> d$ 时,后者快得多!
3.2 用核函数替代 Softmax
但问题是,softmax 保证了注意力权重是正的、和为 1。去掉 softmax 会破坏这些性质。
解决方案:用核函数(Kernel)近似 softmax。
$$ \text{softmax}(q_i^T k_j) \approx \phi(q_i)^T \phi(k_j) $$
其中 $\phi$ 是一个特征映射函数。
def linear_attention(Q, K, V, feature_map):
"""线性注意力"""
# 应用特征映射
Q = feature_map(Q) # (batch, seq, feature_dim)
K = feature_map(K)
# 先算 K^T V (d × d 矩阵)
KV = torch.einsum('bnd,bnm->bdm', K, V)
# 再用 Q 乘
output = torch.einsum('bnd,bdm->bnm', Q, KV)
# 归一化
normalizer = torch.einsum('bnd,bd->bn', Q, K.sum(dim=1))
output = output / normalizer.unsqueeze(-1)
return output
def elu_feature_map(x):
"""ELU 特征映射"""
return F.elu(x) + 13.3 RWKV:RNN 的复仇
RWKV 是一个很有意思的模型:它用 RNN 的形式实现了类似 Transformer 的效果。
核心思想:把 Attention 重写成可以递归计算的形式。
# RWKV 的 Time Mixing(简化版)
def rwkv_time_mixing(x, state, w, u, k, v):
"""
x: 当前输入
state: 上一步的状态
w, u, k, v: 可学习参数
"""
# 计算 k, v
k = x @ W_k
v = x @ W_v
# 更新状态(类似 RNN)
wkv = state * w + k * v
state = state * w + k
# 输出
output = wkv / state
return output, state优点:
- 推理时复杂度 $O(1)$(相对于序列长度)
- 可以处理无限长序列
- 训练可以并行
缺点:
- 效果略逊于标准 Transformer
- 社区支持较少
3.4 线性注意力的现状
现实情况:线性注意力在学术上很有意思,但实际应用中,大家更多用 Flash Attention 来加速标准 Attention,而不是换成线性注意力。
原因:
- 效果差距:线性注意力在长距离依赖上仍有差距
- 工程成熟度:Flash Attention 已经很成熟
- 硬件优化:GPU 对矩阵乘法优化很好
四、Flash Attention:内存优化的艺术
这是目前最实用的 Attention 优化技术,几乎所有现代大模型都在用。
4.1 核心思想
Flash Attention 的核心洞察:标准 Attention 的瓶颈不是计算,而是内存访问。
解决方案:Tiling(分块)+ 重计算。
不要把整个 $n \times n$ 的注意力矩阵存到 HBM,而是:
- 把 Q, K, V 分成小块
- 每次只在 SRAM 里计算一小块
- 用 Online Softmax 逐块更新结果
- 最终结果直接写回 HBM,不存中间矩阵
4.2 Online Softmax
Flash Attention 的关键技术是 Online Softmax:不需要看完整个序列就能计算 softmax。
标准 softmax 需要两次遍历:
- 第一遍:找最大值 $m = \max(x)$
- 第二遍:计算 $\text{softmax}(x) = \frac{e^{x-m}}{\sum e^{x-m}}$
Online Softmax 可以一遍完成,通过增量更新:
def online_softmax(x_blocks):
"""Online Softmax: 逐块更新"""
m = float('-inf') # 当前最大值
l = 0.0 # 当前归一化因子
for x_block in x_blocks:
# 更新最大值
m_new = max(m, x_block.max())
# 更新归一化因子
l = l * exp(m - m_new) + exp(x_block - m_new).sum()
m = m_new
# 最终的 softmax 可以从 m 和 l 计算出来
return m, l4.3 Flash Attention 算法
def flash_attention(Q, K, V, block_size=64):
"""
Flash Attention 简化实现
实际实现是 CUDA kernel,这里只展示逻辑
"""
batch, seq_len, d = Q.shape
O = torch.zeros_like(Q)
# 分块
num_blocks = (seq_len + block_size - 1) // block_size
for i in range(num_blocks):
# 当前 Q 块
q_block = Q[:, i*block_size:(i+1)*block_size, :]
# 初始化 Online Softmax 状态
m_i = torch.full((batch, block_size), float('-inf'))
l_i = torch.zeros(batch, block_size)
o_i = torch.zeros(batch, block_size, d)
for j in range(num_blocks):
# 当前 K, V 块
k_block = K[:, j*block_size:(j+1)*block_size, :]
v_block = V[:, j*block_size:(j+1)*block_size, :]
# 计算注意力分数(在 SRAM 中)
s_ij = q_block @ k_block.transpose(-2, -1) / math.sqrt(d)
# Online Softmax 更新
m_new = torch.maximum(m_i, s_ij.max(dim=-1).values)
# 更新归一化因子和输出
exp_old = torch.exp(m_i - m_new)
exp_new = torch.exp(s_ij - m_new.unsqueeze(-1))
l_new = l_i * exp_old + exp_new.sum(dim=-1)
o_i = (o_i * l_i.unsqueeze(-1) * exp_old.unsqueeze(-1) +
exp_new @ v_block) / l_new.unsqueeze(-1)
m_i = m_new
l_i = l_new
# 写回结果
O[:, i*block_size:(i+1)*block_size, :] = o_i
return O4.4 Flash Attention 的效果
| 序列长度 | 标准 Attention | Flash Attention | 加速比 |
|---|---|---|---|
| 1K | 1x | 2-3x | 2-3x |
| 4K | 1x | 3-4x | 3-4x |
| 16K | OOM | ✓ | ∞ |
| 64K | OOM | ✓ | ∞ |
关键优势:
- 显存节省:不存储 $n \times n$ 的注意力矩阵
- 速度提升:减少 HBM 访问次数
- 支持长序列:显存占用从 $O(n^2)$ 降到 $O(n)$
4.5 Flash Attention 2 & 3
Flash Attention 2 的改进:
- 更好的并行策略
- 减少非矩阵乘法运算
- 速度再提升 2x
Flash Attention 3(2024)的改进:
- 利用 Hopper GPU(H100)的新特性
- 异步执行、硬件加速
- 接近理论峰值性能
4.6 如何使用 Flash Attention
# 方法 1:使用 PyTorch 内置(2.0+)
import torch.nn.functional as F
# 自动使用 Flash Attention(如果可用)
output = F.scaled_dot_product_attention(query, key, value)
# 方法 2:使用 flash-attn 库
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
# 方法 3:使用 transformers 库
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b",
attn_implementation="flash_attention_2" # 指定使用 Flash Attention
)五、GQA/MQA:KV Cache 的优化
5.1 KV Cache 的显存问题
在自回归生成时,我们需要缓存之前所有 token 的 K 和 V(KV Cache)。
KV Cache 大小计算:
$$ \text{KV Cache} = 2 \times n_{layers} \times n_{heads} \times d_{head} \times \text{seq\_len} \times \text{batch\_size} $$
以 LLaMA-2-70B 为例:
- 80 层,64 个 head,head 维度 128
- 序列长度 4096,batch size 1
$$ \text{KV Cache} = 2 \times 80 \times 64 \times 128 \times 4096 \times 2\text{ bytes} = \textbf{10.7 GB} $$
光 KV Cache 就要 10GB 显存!如果 batch size 大一点,或者序列更长,显存直接爆炸。
5.2 Multi-Query Attention (MQA)
核心思想:所有 Query head 共享同一组 K 和 V。
效果:KV Cache 减少到原来的 $\frac{1}{n_{heads}}$。
代价:效果会略有下降。
5.3 Grouped-Query Attention (GQA)
GQA 是 MHA 和 MQA 的折中:把 Query head 分组,每组共享一个 K/V head。
LLaMA 2 的选择:
- 70B 模型:8 个 KV head(64 个 Q head)
- KV Cache 减少 8 倍,效果几乎无损
5.4 代码实现
class GroupedQueryAttention(nn.Module):
"""Grouped-Query Attention"""
def __init__(self, embed_dim, num_q_heads, num_kv_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = embed_dim // num_q_heads
# Q 有 num_q_heads 个 head
self.W_q = nn.Linear(embed_dim, num_q_heads * self.head_dim)
# K, V 只有 num_kv_heads 个 head
self.W_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
self.W_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
self.W_o = nn.Linear(embed_dim, embed_dim)
# 每组有多少个 Q head
self.num_groups = num_q_heads // num_kv_heads
def forward(self, x, kv_cache=None):
batch, seq_len, _ = x.shape
# 投影
Q = self.W_q(x).view(batch, seq_len, self.num_q_heads, self.head_dim)
K = self.W_k(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
V = self.W_v(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
# 处理 KV Cache
if kv_cache is not None:
K = torch.cat([kv_cache['k'], K], dim=1)
V = torch.cat([kv_cache['v'], V], dim=1)
# 扩展 K, V 以匹配 Q 的 head 数
# (batch, seq, num_kv_heads, head_dim) -> (batch, seq, num_q_heads, head_dim)
K = K.repeat_interleave(self.num_groups, dim=2)
V = V.repeat_interleave(self.num_groups, dim=2)
# 计算注意力(后续和标准 MHA 相同)
Q = Q.transpose(1, 2) # (batch, num_q_heads, seq, head_dim)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
# 合并 head
output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.W_o(output)5.5 各方案对比
| 方案 | Q heads | KV heads | KV Cache | 效果 |
|---|---|---|---|---|
| MHA | 32 | 32 | 100% | 最好 |
| GQA-8 | 32 | 8 | 25% | 接近 MHA |
| GQA-4 | 32 | 4 | 12.5% | 略有下降 |
| MQA | 32 | 1 | 3.1% | 下降明显 |
实践建议:GQA 是目前的最佳实践,LLaMA 2/3、Mistral、Qwen 2 等都在用。
六、其他优化技术
6.1 投机采样(Speculative Decoding)
自回归生成的问题:每次只能生成一个 token。
投机采样的思路:用一个小模型先"猜"多个 token,然后让大模型一次性验证。
效果:在不损失质量的前提下,加速 2-3 倍。
6.2 Prefix Caching
如果多个请求有相同的前缀(如相同的 system prompt),可以共享 KV Cache。
# 传统方式:每个请求都计算完整的 KV Cache
request_1 = "You are a helpful assistant. What is 2+2?"
request_2 = "You are a helpful assistant. What is 3+3?"
# 两次都要计算 "You are a helpful assistant." 的 KV
# Prefix Caching:相同前缀只计算一次
prefix_cache = compute_kv("You are a helpful assistant.")
answer_1 = generate_with_cache(prefix_cache, "What is 2+2?")
answer_2 = generate_with_cache(prefix_cache, "What is 3+3?")vLLM 原生支持这个功能。
6.3 Chunked Prefill
处理长 prompt 时,一次性计算所有 token 的 KV 可能导致延迟尖峰。
Chunked Prefill 把 prefill 阶段分成多个小块,交错执行,让延迟更平滑。
七、实战:使用不同的 Attention 实现
7.1 PyTorch 原生(推荐)
import torch
import torch.nn.functional as F
def attention_comparison(Q, K, V):
"""比较不同 Attention 实现"""
# 1. 标准实现
def standard_attention(Q, K, V):
d_k = Q.shape[-1]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn_weights = F.softmax(scores, dim=-1)
return torch.matmul(attn_weights, V)
# 2. PyTorch 2.0+ 的 scaled_dot_product_attention
# 会自动选择最优实现(Flash Attention 如果可用)
def sdpa_attention(Q, K, V):
return F.scaled_dot_product_attention(Q, K, V)
# 比较结果
out_standard = standard_attention(Q, K, V)
out_sdpa = sdpa_attention(Q, K, V)
print(f"结果一致: {torch.allclose(out_standard, out_sdpa, atol=1e-5)}")
return out_standard, out_sdpa
# 测试
batch, heads, seq_len, d_head = 2, 8, 1024, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda')
K = torch.randn(batch, heads, seq_len, d_head, device='cuda')
V = torch.randn(batch, heads, seq_len, d_head, device='cuda')
out1, out2 = attention_comparison(Q, K, V)7.2 使用 Flash Attention 库
# 安装: pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
def use_flash_attention():
batch, seq_len, heads, d_head = 2, 4096, 32, 128
# 方式 1:分开的 Q, K, V
q = torch.randn(batch, seq_len, heads, d_head, device='cuda', dtype=torch.float16)
k = torch.randn(batch, seq_len, heads, d_head, device='cuda', dtype=torch.float16)
v = torch.randn(batch, seq_len, heads, d_head, device='cuda', dtype=torch.float16)
output = flash_attn_func(q, k, v, causal=True)
print(f"Output shape: {output.shape}")
# 方式 2:打包的 QKV
qkv = torch.randn(batch, seq_len, 3, heads, d_head, device='cuda', dtype=torch.float16)
output = flash_attn_qkvpacked_func(qkv, causal=True)
return output
# 性能测试
import time
def benchmark_attention(seq_lengths=[1024, 4096, 16384]):
for seq_len in seq_lengths:
q = torch.randn(1, seq_len, 32, 128, device='cuda', dtype=torch.float16)
k = torch.randn(1, seq_len, 32, 128, device='cuda', dtype=torch.float16)
v = torch.randn(1, seq_len, 32, 128, device='cuda', dtype=torch.float16)
# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v, causal=True)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = flash_attn_func(q, k, v, causal=True)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 100 * 1000
print(f"Seq length {seq_len}: {elapsed:.2f} ms")7.3 在 Transformers 中使用
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型时指定 attention 实现
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="flash_attention_2" # 使用 Flash Attention 2
)
# 或者使用 SDPA(PyTorch 原生)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa" # 使用 PyTorch SDPA
)
# 检查当前使用的 attention 实现
print(f"Attention implementation: {model.config._attn_implementation}")八、总结:如何选择?
+ GQA"] Q2 -->|否| A2["Flash Attention"] Q3 -->|推理| A3["Flash Attention
+ KV Cache 优化"] Q3 -->|训练| A4["Flash Attention
+ 混合精度"] A1 --> Final["搞定!"] A2 --> Final A3 --> Final A4 --> Final
实践建议
| 场景 | 推荐方案 |
|---|---|
| 一般使用 | PyTorch 2.0+ 的 SDPA(自动优化) |
| 需要最佳性能 | flash-attn 库 |
| 长序列(>8K) | Flash Attention + GQA |
| 高并发推理 | vLLM(内置各种优化) |
| 边缘部署 | 量化 + GQA |
关键 Takeaway
- Attention 的瓶颈是内存带宽,不是计算量
- Flash Attention 是目前最实用的优化,几乎无损
- GQA 可以大幅减少 KV Cache,LLaMA 2+ 都在用
- 稀疏注意力和线性注意力有学术价值,但工程应用有限
- PyTorch 2.0+ 的 SDPA 会自动选择最优实现,对新手友好