前言:一切的起点
如果要评选"21 世纪最重要的 AI 论文",《Attention Is All You Need》绝对是头部选手。
GPT、BERT、LLaMA、Claude、DeepSeek……这些你听过的大模型,全部基于 Transformer 架构。
没有 Transformer,就没有今天的 AI 浪潮。
这篇文章,我们来彻底搞懂它。
一、在 Transformer 之前:RNN 的痛苦
1.1 序列建模的老大哥:RNN
在 Transformer 出现之前,处理序列数据(文本、语音、时间序列)的标准答案是 RNN(循环神经网络) 及其变体 LSTM、GRU。
RNN 的核心思想很直觉:按顺序处理,把前面的信息传递给后面。
看起来很美好对吧?但实际上,RNN 有三个致命问题:
1.2 RNN 的三宗罪
罪一:无法并行计算
RNN 必须一个词一个词地处理,$h_t$ 依赖 $h_{t-1}$。
处理 "我爱学习Transformer" 需要:
第1步:处理 "我" → h1
第2步:处理 "爱" → h2(必须等 h1 算完)
第3步:处理 "学" → h3(必须等 h2 算完)
...这意味着:100 个词的句子,需要 100 个时间步。GPU 的并行能力完全浪费了。
罪二:长距离依赖困难
理论上,RNN 可以记住任意长的历史。实际上?
"我出生在中国,在这里生活了三十年,所以我的母语是______"当模型处理到"母语是"的时候,"中国"这个关键信息已经经过了十几个时间步的传递,早就被"稀释"得差不多了。
这就是著名的 梯度消失/爆炸 问题。LSTM 和 GRU 缓解了这个问题,但没有根治。
罪三:信息瓶颈
所有历史信息都要压缩到一个固定大小的隐状态向量里。想象一下把一本书的内容压缩到一个 512 维的向量——信息损失是必然的。
1.3 Attention 的曙光
2014-2015 年,研究者们开始给 RNN 加上 Attention 机制:与其让模型自己记住所有信息,不如让它在需要的时候回头看。
这个思路很成功,机器翻译的效果大幅提升。
但 Google 的研究者想得更远:既然 Attention 这么好用,我们为什么还需要 RNN?
于是,Transformer 诞生了。
二、Transformer 核心:Attention 机制
2.1 Attention 的直觉
想象你在读这句话:
"The animal didn't cross the street because it was too tired."
当你读到 "it" 的时候,你的大脑会自动"回看"前文,判断 "it" 指的是 "animal" 还是 "street"。
这就是 Attention(注意力):让模型在处理每个位置时,能够"关注"到输入的其他位置。
2.2 Self-Attention:数学定义
Attention 的核心是三个向量:Query(查询)、Key(键)、Value(值)。
对于输入序列中的每个词,我们都生成这三个向量:
$$ Q = X W^Q, \quad K = X W^K, \quad V = X W^V $$
其中 $X$ 是输入的词向量矩阵,$W^Q, W^K, W^V$ 是可学习的参数矩阵。
然后,Attention 的计算公式是:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$
别被公式吓到,我们一步步拆解:
Step 1: 计算相似度 $QK^T$
$Q$ 和 $K$ 的点积衡量的是"这个 Query 和那个 Key 有多相关"。
# 假设序列长度为 4,维度为 64
Q = [q1, q2, q3, q4] # 4 个 Query 向量
K = [k1, k2, k3, k4] # 4 个 Key 向量
# QK^T 得到 4x4 的相似度矩阵
scores = [
[q1·k1, q1·k2, q1·k3, q1·k4], # q1 和所有 k 的相似度
[q2·k1, q2·k2, q2·k3, q2·k4], # q2 和所有 k 的相似度
[q3·k1, q3·k2, q3·k3, q3·k4],
[q4·k1, q4·k2, q4·k3, q4·k4],
]Step 2: Scale 缩放 $\frac{1}{\sqrt{d_k}}$
为什么要除以 $\sqrt{d_k}$?
当维度 $d_k$ 很大时,点积的结果会很大,导致 softmax 的梯度非常小(接近 one-hot)。缩放可以让训练更稳定。
Step 3: Softmax 归一化
把相似度分数转换成概率分布(和为 1):
# 假设某行的 scores 是 [2.0, 1.0, 0.5, 0.1]
# softmax 后变成 [0.48, 0.18, 0.11, 0.07] # 和为 1Step 4: 加权求和 Value
用 softmax 得到的权重,对 Value 向量加权求和:
output = 0.48 * v1 + 0.18 * v2 + 0.11 * v3 + 0.07 * v4这样,输出就融合了"被关注"的信息。
2.3 一图胜千言
2.4 手撕 Self-Attention 代码
Talk is cheap, show me the code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.embed_dim = embed_dim
# Q, K, V 的线性变换
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
# x: (batch_size, seq_len, embed_dim)
Q = self.W_q(x) # (batch, seq_len, embed_dim)
K = self.W_k(x)
V = self.W_v(x)
# 计算注意力分数
d_k = self.embed_dim
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# scores: (batch, seq_len, seq_len)
# Softmax 归一化
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attn_weights, V)
# output: (batch, seq_len, embed_dim)
return output, attn_weights
# 测试
batch_size, seq_len, embed_dim = 2, 4, 64
x = torch.randn(batch_size, seq_len, embed_dim)
attention = SelfAttention(embed_dim)
output, weights = attention(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
# 输入形状: torch.Size([2, 4, 64])
# 输出形状: torch.Size([2, 4, 64])
# 注意力权重形状: torch.Size([2, 4, 4])三、Multi-Head Attention:多角度观察
3.1 为什么需要多头?
单个 Attention 只能学到一种"关注模式"。但语言是复杂的:
- 语法关系:"it" 指代 "animal"
- 语义关系:"tired" 描述 "animal" 的状态
- 位置关系:相邻的词往往相关
Multi-Head Attention 的思路:让模型同时学习多种不同的关注模式。
3.2 数学公式
$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O $$
其中每个 head:
$$ \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V) $$
假设模型维度是 512,使用 8 个 head,那么每个 head 的维度就是 512/8 = 64。
3.3 Multi-Head Attention 代码
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Q, K, V 的线性变换(一次性生成所有 head)
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
# 输出投影
self.W_o = nn.Linear(embed_dim, embed_dim)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 线性变换
Q = self.W_q(x) # (batch, seq_len, embed_dim)
K = self.W_k(x)
V = self.W_v(x)
# 拆分成多个 head: (batch, seq_len, num_heads, head_dim)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
# 转置: (batch, num_heads, seq_len, head_dim)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 合并多个 head: (batch, seq_len, embed_dim)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
# 输出投影
output = self.W_o(attn_output)
return output, attn_weights四、Position Encoding:告诉模型"顺序"
4.1 Attention 的致命缺陷
Self-Attention 有一个问题:它是"顺序无关"的。
输入 "我爱你" 和 "你爱我"
如果没有位置信息,Attention 认为它们是一样的!这在 RNN 中不是问题(天然按顺序处理),但 Transformer 需要额外的机制来编码位置信息。
4.2 正弦位置编码
原论文使用的是正弦/余弦函数:
$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) $$
$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) $$
其中 $pos$ 是位置,$i$ 是维度索引。
为什么用正弦函数?
- 值域有界:始终在 [-1, 1] 之间
- 相对位置可学习:$PE_{pos+k}$ 可以表示为 $PE_{pos}$ 的线性函数
- 可扩展:理论上可以处理任意长度的序列
4.3 Position Encoding 代码
class PositionalEncoding(nn.Module):
def __init__(self, embed_dim, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算 div_term: 10000^(2i/d_model)
div_term = torch.exp(
torch.arange(0, embed_dim, 2).float() *
(-math.log(10000.0) / embed_dim)
)
# 偶数维度用 sin,奇数维度用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 增加 batch 维度: (1, max_len, embed_dim)
pe = pe.unsqueeze(0)
# 注册为 buffer(不参与训练)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch_size, seq_len, embed_dim)
seq_len = x.size(1)
return x + self.pe[:, :seq_len, :]4.4 RoPE:旋转位置编码(现代主流)
现在的大模型(LLaMA、Qwen、DeepSeek)普遍使用 RoPE(Rotary Position Embedding)。
RoPE 的核心思想:通过旋转变换来编码相对位置。
$$ f_q(x_m, m) = (W_q x_m) e^{im\theta} $$
优点:
- 更好的相对位置建模
- 可以外推到更长的序列
- 与 Attention 的点积天然兼容
五、完整的 Transformer 架构
5.1 整体结构
原版 Transformer 是 Encoder-Decoder 结构:
5.2 关键组件详解
Feed Forward Network (FFN)
两层全连接,中间用 ReLU(或 GELU)激活:
$$ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$
通常中间层维度是输入的 4 倍(512 → 2048 → 512)。
class FeedForward(nn.Module):
def __init__(self, embed_dim, ff_dim):
super().__init__()
self.linear1 = nn.Linear(embed_dim, ff_dim)
self.linear2 = nn.Linear(ff_dim, embed_dim)
self.activation = nn.GELU()
def forward(self, x):
return self.linear2(self.activation(self.linear1(x)))Add & Norm(残差连接 + 层归一化)
# 残差连接:让梯度更容易流动
output = LayerNorm(x + Sublayer(x))Masked Attention
在 Decoder 中,为了保证自回归(只能看到之前的 token),需要用 mask 遮住未来的位置:
# 创建因果 mask
def create_causal_mask(seq_len):
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return mask == 0 # True 表示可以 attend
# 例如 seq_len = 4:
# [[True, False, False, False],
# [True, True, False, False],
# [True, True, True, False],
# [True, True, True, True]]5.3 三种 Transformer 变体
现代大模型几乎都是 Decoder-Only 架构,因为:
- 更适合生成任务
- 架构更简单,scaling 效果更好
- 可以统一处理各种任务(通过 prompt)
六、为什么 Transformer 如此成功?
6.1 并行计算能力
RNN 必须顺序计算,Transformer 可以一次性处理整个序列:
| 模型 | 处理 1000 词 | 计算复杂度 |
|---|---|---|
| RNN | 1000 步串行 | O(n) 时间步 |
| Transformer | 1 步并行 | O(1) 时间步 |
这让 Transformer 可以充分利用 GPU 的并行能力,训练速度大幅提升。
6.2 长距离依赖建模
RNN 需要信息"逐步传递",Transformer 可以"直接连接":
任意两个位置之间的"距离"都是 1,长距离依赖不再是问题。
6.3 Scaling Law:大力出奇迹
Transformer 有一个神奇的特性:模型越大、数据越多、效果越好,而且是可预测的。
这给了研究者一个清晰的路线图:只要有足够的算力和数据,就能持续提升性能。
6.4 通用性:一统 NLP 和 CV
Transformer 最初为 NLP 设计,但很快被证明在其他领域也有效:
| 领域 | 代表模型 |
|---|---|
| 自然语言 | GPT、BERT、LLaMA |
| 计算机视觉 | ViT、CLIP、DALL-E |
| 语音 | Whisper、wav2vec |
| 多模态 | GPT-4V、Gemini |
| 蛋白质结构 | AlphaFold |
一个架构统一所有领域,这是前所未有的。
6.5 涌现能力
当模型足够大时,会出现一些"涌现能力"——小模型完全做不到,大模型突然就会了:
- 上下文学习(In-Context Learning)
- 思维链推理(Chain-of-Thought)
- 指令遵循(Instruction Following)
这些能力不是"训练"出来的,而是"涌现"出来的。
七、Transformer 的计算复杂度与优化
7.1 计算瓶颈:$O(n^2)$ 的注意力
Self-Attention 的计算复杂度是 $O(n^2 \cdot d)$,其中 $n$ 是序列长度。
这意味着:
- 序列长度翻倍,计算量变成 4 倍
- 处理 100K token 的长文本非常昂贵
7.2 优化方案
| 方法 | 原理 | 代表工作 |
|---|---|---|
| 稀疏注意力 | 只计算部分位置的 attention | Sparse Transformer, Longformer |
| 线性注意力 | 用核函数近似,降到 O(n) | Linear Attention, RWKV |
| Flash Attention | 优化内存访问模式 | FlashAttention 1/2/3 |
| 分组查询注意力 | 多个 Query 共享 K/V | GQA (LLaMA 2) |
| 滑动窗口 | 只关注局部窗口 | Mistral |
7.3 KV Cache:推理加速
在自回归生成时,每一步都需要计算之前所有 token 的 K、V。KV Cache 把它们缓存起来,避免重复计算:
# 没有 KV Cache:每步都重新计算
step 1: 计算 K_1, V_1
step 2: 计算 K_1, K_2, V_1, V_2 # K_1, V_1 重复了!
step 3: 计算 K_1, K_2, K_3, V_1, V_2, V_3 # 又重复了!
# 有 KV Cache:缓存之前的结果
step 1: 计算并缓存 K_1, V_1
step 2: 只计算 K_2, V_2,拼接缓存的 K_1, V_1
step 3: 只计算 K_3, V_3,拼接缓存的 K_1, K_2, V_1, V_2KV Cache 是推理优化的关键,但也带来显存压力(缓存随序列长度线性增长)。
八、完整代码:从零实现 Transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
self.W_o = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 线性变换
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
# 拆分成多头
Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, V)
# 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.embed_dim)
return self.W_o(attn_output)
class FeedForward(nn.Module):
"""前馈神经网络"""
def __init__(self, embed_dim, ff_dim, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(embed_dim, ff_dim)
self.linear2 = nn.Linear(ff_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
self.activation = nn.GELU()
def forward(self, x):
x = self.activation(self.linear1(x))
x = self.dropout(x)
return self.linear2(x)
class TransformerEncoderLayer(nn.Module):
"""Transformer 编码器层"""
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.ffn = FeedForward(embed_dim, ff_dim, dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-Attention + 残差
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# FFN + 残差
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
class TransformerDecoderLayer(nn.Module):
"""Transformer 解码器层"""
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.cross_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.ffn = FeedForward(embed_dim, ff_dim, dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.norm3 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
# Masked Self-Attention
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
# Cross-Attention
attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
# FFN
ffn_output = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_output))
return x
class PositionalEncoding(nn.Module):
"""正弦位置编码"""
def __init__(self, embed_dim, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class Transformer(nn.Module):
"""完整的 Transformer 模型"""
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
embed_dim=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
ff_dim=2048,
dropout=0.1,
max_len=5000
):
super().__init__()
# Embedding
self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
self.pos_encoding = PositionalEncoding(embed_dim, max_len, dropout)
# Encoder
self.encoder_layers = nn.ModuleList([
TransformerEncoderLayer(embed_dim, num_heads, ff_dim, dropout)
for _ in range(num_encoder_layers)
])
# Decoder
self.decoder_layers = nn.ModuleList([
TransformerDecoderLayer(embed_dim, num_heads, ff_dim, dropout)
for _ in range(num_decoder_layers)
])
# Output
self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)
self.embed_dim = embed_dim
def encode(self, src, src_mask=None):
x = self.src_embedding(src) * math.sqrt(self.embed_dim)
x = self.pos_encoding(x)
for layer in self.encoder_layers:
x = layer(x, src_mask)
return x
def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
x = self.tgt_embedding(tgt) * math.sqrt(self.embed_dim)
x = self.pos_encoding(x)
for layer in self.decoder_layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return x
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
encoder_output = self.encode(src, src_mask)
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
output = self.fc_out(decoder_output)
return output
# 使用示例
if __name__ == "__main__":
# 超参数
src_vocab_size = 10000
tgt_vocab_size = 10000
embed_dim = 512
num_heads = 8
num_layers = 6
# 创建模型
model = Transformer(
src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
embed_dim=embed_dim,
num_heads=num_heads,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers
)
# 模拟输入
batch_size = 2
src_len = 10
tgt_len = 8
src = torch.randint(0, src_vocab_size, (batch_size, src_len))
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))
# 前向传播
output = model(src, tgt)
print(f"输入源序列: {src.shape}")
print(f"输入目标序列: {tgt.shape}")
print(f"输出形状: {output.shape}")
# 计算参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"模型参数量: {total_params:,}")九、总结:Transformer 的核心洞见
关键 Takeaway
- Attention Is All You Need:用 Attention 替代 RNN,获得并行能力和长距离建模能力
- Multi-Head:让模型学习多种不同的关注模式
- Position Encoding:补充 Attention 缺失的位置信息
- 残差 + LayerNorm:让深层网络可以训练
Scaling Law:大力出奇迹是可预测的
参考资料
- Attention Is All You Need - 原论文
- The Illustrated Transformer - 可视化讲解
- The Annotated Transformer - 带注释的代码
- Let's build GPT - Karpathy 的视频教程