搜 索

Transformer架构:从入门到放弃

  • 15阅读
  • 2025年02月08日
  • 0评论
首页 / AI/大数据 / 正文

前言:一切的起点

如果要评选"21 世纪最重要的 AI 论文",《Attention Is All You Need》绝对是头部选手。

GPT、BERT、LLaMA、Claude、DeepSeek……这些你听过的大模型,全部基于 Transformer 架构。

没有 Transformer,就没有今天的 AI 浪潮。

这篇文章,我们来彻底搞懂它。

timeline title Transformer 的统治之路 2017 : Transformer 诞生 : "Attention Is All You Need" 2018 : BERT 发布 : NLP 全面拥抱 Transformer 2019 : GPT-2 震惊世界 : "太危险了不敢开源" 2020 : GPT-3 涌现能力 : 1750 亿参数 2021 : ViT 证明视觉也行 : Transformer 统一 CV 2022 : ChatGPT 出圈 : 全民 AI 时代开启 2023-2025 : GPT-4/Claude/DeepSeek : 多模态、推理、Agent

一、在 Transformer 之前:RNN 的痛苦

1.1 序列建模的老大哥:RNN

在 Transformer 出现之前,处理序列数据(文本、语音、时间序列)的标准答案是 RNN(循环神经网络) 及其变体 LSTM、GRU。

RNN 的核心思想很直觉:按顺序处理,把前面的信息传递给后面

graph LR subgraph RNN 处理流程 X1[我] --> H1[隐状态1] X2[爱] --> H2[隐状态2] X3[学] --> H3[隐状态3] X4[习] --> H4[隐状态4] H1 --> H2 --> H3 --> H4 end H4 --> Y[输出]

看起来很美好对吧?但实际上,RNN 有三个致命问题:

1.2 RNN 的三宗罪

罪一:无法并行计算

RNN 必须一个词一个词地处理,$h_t$ 依赖 $h_{t-1}$。

处理 "我爱学习Transformer" 需要:
第1步:处理 "我" → h1
第2步:处理 "爱" → h2(必须等 h1 算完)
第3步:处理 "学" → h3(必须等 h2 算完)
...

这意味着:100 个词的句子,需要 100 个时间步。GPU 的并行能力完全浪费了。

罪二:长距离依赖困难

理论上,RNN 可以记住任意长的历史。实际上?

"我出生在中国,在这里生活了三十年,所以我的母语是______"

当模型处理到"母语是"的时候,"中国"这个关键信息已经经过了十几个时间步的传递,早就被"稀释"得差不多了。

这就是著名的 梯度消失/爆炸 问题。LSTM 和 GRU 缓解了这个问题,但没有根治。

罪三:信息瓶颈

所有历史信息都要压缩到一个固定大小的隐状态向量里。想象一下把一本书的内容压缩到一个 512 维的向量——信息损失是必然的。

graph LR subgraph 信息瓶颈 A[100个词的输入] --> B[512维隐状态] B --> C[期望记住所有信息?] end style B fill:#ff6b6b

1.3 Attention 的曙光

2014-2015 年,研究者们开始给 RNN 加上 Attention 机制:与其让模型自己记住所有信息,不如让它在需要的时候回头看

这个思路很成功,机器翻译的效果大幅提升。

但 Google 的研究者想得更远:既然 Attention 这么好用,我们为什么还需要 RNN?

于是,Transformer 诞生了。

二、Transformer 核心:Attention 机制

2.1 Attention 的直觉

想象你在读这句话:

"The animal didn't cross the street because it was too tired."

当你读到 "it" 的时候,你的大脑会自动"回看"前文,判断 "it" 指的是 "animal" 还是 "street"。

这就是 Attention(注意力):让模型在处理每个位置时,能够"关注"到输入的其他位置。

graph TB subgraph 注意力可视化 W1[The] W2[animal] W3[didn't] W4[cross] W5[the] W6[street] W7[because] W8[it] W9[was] W10[tired] end W8 -->|强关注| W2 W8 -.->|弱关注| W6 W8 -.->|弱关注| W1 style W2 fill:#4ecdc4 style W8 fill:#ff6b6b

2.2 Self-Attention:数学定义

Attention 的核心是三个向量:Query(查询)Key(键)Value(值)

对于输入序列中的每个词,我们都生成这三个向量:

$$ Q = X W^Q, \quad K = X W^K, \quad V = X W^V $$

其中 $X$ 是输入的词向量矩阵,$W^Q, W^K, W^V$ 是可学习的参数矩阵。

然后,Attention 的计算公式是:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$

别被公式吓到,我们一步步拆解:

Step 1: 计算相似度 $QK^T$

$Q$ 和 $K$ 的点积衡量的是"这个 Query 和那个 Key 有多相关"。

# 假设序列长度为 4,维度为 64
Q = [q1, q2, q3, q4]  # 4 个 Query 向量
K = [k1, k2, k3, k4]  # 4 个 Key 向量

# QK^T 得到 4x4 的相似度矩阵
scores = [
    [q1·k1, q1·k2, q1·k3, q1·k4],  # q1 和所有 k 的相似度
    [q2·k1, q2·k2, q2·k3, q2·k4],  # q2 和所有 k 的相似度
    [q3·k1, q3·k2, q3·k3, q3·k4],
    [q4·k1, q4·k2, q4·k3, q4·k4],
]

Step 2: Scale 缩放 $\frac{1}{\sqrt{d_k}}$

为什么要除以 $\sqrt{d_k}$?

当维度 $d_k$ 很大时,点积的结果会很大,导致 softmax 的梯度非常小(接近 one-hot)。缩放可以让训练更稳定。

Step 3: Softmax 归一化

把相似度分数转换成概率分布(和为 1):

# 假设某行的 scores 是 [2.0, 1.0, 0.5, 0.1]
# softmax 后变成    [0.48, 0.18, 0.11, 0.07]  # 和为 1

Step 4: 加权求和 Value

用 softmax 得到的权重,对 Value 向量加权求和:

output = 0.48 * v1 + 0.18 * v2 + 0.11 * v3 + 0.07 * v4

这样,输出就融合了"被关注"的信息。

2.3 一图胜千言

flowchart TB subgraph 输入 X[输入 X] end subgraph 线性变换 X --> WQ[W_Q] X --> WK[W_K] X --> WV[W_V] WQ --> Q[Query] WK --> K[Key] WV --> V[Value] end subgraph Attention计算 Q --> MatMul1[矩阵乘法] K --> MatMul1 MatMul1 --> Scale[缩放 ÷√dk] Scale --> Softmax[Softmax] Softmax --> MatMul2[矩阵乘法] V --> MatMul2 end MatMul2 --> Output[输出]

2.4 手撕 Self-Attention 代码

Talk is cheap, show me the code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        
        # Q, K, V 的线性变换
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x):
        # x: (batch_size, seq_len, embed_dim)
        
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 计算注意力分数
        d_k = self.embed_dim
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # scores: (batch, seq_len, seq_len)
        
        # Softmax 归一化
        attn_weights = F.softmax(scores, dim=-1)
        
        # 加权求和
        output = torch.matmul(attn_weights, V)
        # output: (batch, seq_len, embed_dim)
        
        return output, attn_weights

# 测试
batch_size, seq_len, embed_dim = 2, 4, 64
x = torch.randn(batch_size, seq_len, embed_dim)

attention = SelfAttention(embed_dim)
output, weights = attention(x)

print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
# 输入形状: torch.Size([2, 4, 64])
# 输出形状: torch.Size([2, 4, 64])
# 注意力权重形状: torch.Size([2, 4, 4])

三、Multi-Head Attention:多角度观察

3.1 为什么需要多头?

单个 Attention 只能学到一种"关注模式"。但语言是复杂的:

  • 语法关系:"it" 指代 "animal"
  • 语义关系:"tired" 描述 "animal" 的状态
  • 位置关系:相邻的词往往相关

Multi-Head Attention 的思路:让模型同时学习多种不同的关注模式。

graph TB subgraph Multi-Head Attention Input[输入 X] Input --> H1[Head 1: 学习语法关系] Input --> H2[Head 2: 学习语义关系] Input --> H3[Head 3: 学习位置关系] Input --> H4[Head 4: ...] H1 --> Concat[拼接] H2 --> Concat H3 --> Concat H4 --> Concat Concat --> Linear[线性变换] Linear --> Output[输出] end

3.2 数学公式

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O $$

其中每个 head:

$$ \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V) $$

假设模型维度是 512,使用 8 个 head,那么每个 head 的维度就是 512/8 = 64。

3.3 Multi-Head Attention 代码

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Q, K, V 的线性变换(一次性生成所有 head)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        
        # 输出投影
        self.W_o = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 线性变换
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 拆分成多个 head: (batch, seq_len, num_heads, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
        
        # 转置: (batch, num_heads, seq_len, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        
        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # 合并多个 head: (batch, seq_len, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
        
        # 输出投影
        output = self.W_o(attn_output)
        
        return output, attn_weights

四、Position Encoding:告诉模型"顺序"

4.1 Attention 的致命缺陷

Self-Attention 有一个问题:它是"顺序无关"的

输入 "我爱你" 和 "你爱我"
如果没有位置信息,Attention 认为它们是一样的!

这在 RNN 中不是问题(天然按顺序处理),但 Transformer 需要额外的机制来编码位置信息。

4.2 正弦位置编码

原论文使用的是正弦/余弦函数:

$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) $$

$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) $$

其中 $pos$ 是位置,$i$ 是维度索引。

为什么用正弦函数?

  1. 值域有界:始终在 [-1, 1] 之间
  2. 相对位置可学习:$PE_{pos+k}$ 可以表示为 $PE_{pos}$ 的线性函数
  3. 可扩展:理论上可以处理任意长度的序列
graph LR subgraph 位置编码 P0[位置0] --> E0[PE_0] P1[位置1] --> E1[PE_1] P2[位置2] --> E2[PE_2] P3[位置3] --> E3[PE_3] end subgraph 词嵌入 W0[我] --> WE0[词向量_0] W1[爱] --> WE1[词向量_1] W2[学] --> WE2[词向量_2] W3[习] --> WE3[词向量_3] end E0 --> Add0((+)) WE0 --> Add0 E1 --> Add1((+)) WE1 --> Add1 E2 --> Add2((+)) WE2 --> Add2 E3 --> Add3((+)) WE3 --> Add3 Add0 --> Input[输入到 Transformer] Add1 --> Input Add2 --> Input Add3 --> Input

4.3 Position Encoding 代码

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # 计算 div_term: 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * 
            (-math.log(10000.0) / embed_dim)
        )
        
        # 偶数维度用 sin,奇数维度用 cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 增加 batch 维度: (1, max_len, embed_dim)
        pe = pe.unsqueeze(0)
        
        # 注册为 buffer(不参与训练)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: (batch_size, seq_len, embed_dim)
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

4.4 RoPE:旋转位置编码(现代主流)

现在的大模型(LLaMA、Qwen、DeepSeek)普遍使用 RoPE(Rotary Position Embedding)

RoPE 的核心思想:通过旋转变换来编码相对位置

$$ f_q(x_m, m) = (W_q x_m) e^{im\theta} $$

优点:

  • 更好的相对位置建模
  • 可以外推到更长的序列
  • 与 Attention 的点积天然兼容

五、完整的 Transformer 架构

5.1 整体结构

原版 Transformer 是 Encoder-Decoder 结构:

graph TB subgraph Encoder["Encoder (N层)"] E_Input[输入嵌入 + 位置编码] E_MHA[Multi-Head Self-Attention] E_Add1[Add & Norm] E_FFN[Feed Forward Network] E_Add2[Add & Norm] E_Input --> E_MHA E_MHA --> E_Add1 E_Input -.-> E_Add1 E_Add1 --> E_FFN E_FFN --> E_Add2 E_Add1 -.-> E_Add2 end subgraph Decoder["Decoder (N层)"] D_Input[输出嵌入 + 位置编码] D_MHA1[Masked Multi-Head Self-Attention] D_Add1[Add & Norm] D_MHA2[Multi-Head Cross-Attention] D_Add2[Add & Norm] D_FFN[Feed Forward Network] D_Add3[Add & Norm] D_Input --> D_MHA1 D_MHA1 --> D_Add1 D_Input -.-> D_Add1 D_Add1 --> D_MHA2 E_Add2 --> D_MHA2 D_MHA2 --> D_Add2 D_Add1 -.-> D_Add2 D_Add2 --> D_FFN D_FFN --> D_Add3 D_Add2 -.-> D_Add3 end D_Add3 --> Linear[线性层] Linear --> Softmax[Softmax] Softmax --> Output[输出概率]

5.2 关键组件详解

Feed Forward Network (FFN)

两层全连接,中间用 ReLU(或 GELU)激活:

$$ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$

通常中间层维度是输入的 4 倍(512 → 2048 → 512)。

class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.activation = nn.GELU()
    
    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))

Add & Norm(残差连接 + 层归一化)

# 残差连接:让梯度更容易流动
output = LayerNorm(x + Sublayer(x))

Masked Attention

在 Decoder 中,为了保证自回归(只能看到之前的 token),需要用 mask 遮住未来的位置:

# 创建因果 mask
def create_causal_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # True 表示可以 attend

# 例如 seq_len = 4:
# [[True, False, False, False],
#  [True, True,  False, False],
#  [True, True,  True,  False],
#  [True, True,  True,  True]]

5.3 三种 Transformer 变体

graph TB subgraph 原版["原版 Transformer (Encoder-Decoder)"] A1[Encoder] --> A2[Decoder] A3[用途: 机器翻译、T5] end subgraph EncoderOnly["Encoder-Only"] B1[Encoder × N] B2[用途: BERT、文本分类、NER] B3[特点: 双向注意力] end subgraph DecoderOnly["Decoder-Only"] C1[Decoder × N] C2[用途: GPT、LLaMA、Claude] C3[特点: 因果注意力、自回归生成] end

现代大模型几乎都是 Decoder-Only 架构,因为:

  1. 更适合生成任务
  2. 架构更简单,scaling 效果更好
  3. 可以统一处理各种任务(通过 prompt)

六、为什么 Transformer 如此成功?

6.1 并行计算能力

RNN 必须顺序计算,Transformer 可以一次性处理整个序列:

模型处理 1000 词计算复杂度
RNN1000 步串行O(n) 时间步
Transformer1 步并行O(1) 时间步

这让 Transformer 可以充分利用 GPU 的并行能力,训练速度大幅提升。

6.2 长距离依赖建模

RNN 需要信息"逐步传递",Transformer 可以"直接连接":

graph LR subgraph RNN A1[词1] --> A2[词2] --> A3[词3] --> A4[...] --> A100[词100] A1 -.->|信息衰减| A100 end subgraph Transformer B1[词1] B100[词100] B1 <-->|直接 Attention| B100 end

任意两个位置之间的"距离"都是 1,长距离依赖不再是问题。

6.3 Scaling Law:大力出奇迹

Transformer 有一个神奇的特性:模型越大、数据越多、效果越好,而且是可预测的。

graph LR subgraph Scaling Law A[模型参数 ↑] --> D[性能 ↑] B[训练数据 ↑] --> D C[计算量 ↑] --> D end

这给了研究者一个清晰的路线图:只要有足够的算力和数据,就能持续提升性能。

6.4 通用性:一统 NLP 和 CV

Transformer 最初为 NLP 设计,但很快被证明在其他领域也有效:

领域代表模型
自然语言GPT、BERT、LLaMA
计算机视觉ViT、CLIP、DALL-E
语音Whisper、wav2vec
多模态GPT-4V、Gemini
蛋白质结构AlphaFold

一个架构统一所有领域,这是前所未有的。

6.5 涌现能力

当模型足够大时,会出现一些"涌现能力"——小模型完全做不到,大模型突然就会了:

  • 上下文学习(In-Context Learning)
  • 思维链推理(Chain-of-Thought)
  • 指令遵循(Instruction Following)

这些能力不是"训练"出来的,而是"涌现"出来的。

七、Transformer 的计算复杂度与优化

7.1 计算瓶颈:$O(n^2)$ 的注意力

Self-Attention 的计算复杂度是 $O(n^2 \cdot d)$,其中 $n$ 是序列长度。

这意味着:

  • 序列长度翻倍,计算量变成 4 倍
  • 处理 100K token 的长文本非常昂贵
graph LR subgraph 复杂度对比 A[序列长度 1K] --> A1[Attention: 1M 次计算] B[序列长度 4K] --> B1[Attention: 16M 次计算] C[序列长度 100K] --> C1[Attention: 10B 次计算] end

7.2 优化方案

方法原理代表工作
稀疏注意力只计算部分位置的 attentionSparse Transformer, Longformer
线性注意力用核函数近似,降到 O(n)Linear Attention, RWKV
Flash Attention优化内存访问模式FlashAttention 1/2/3
分组查询注意力多个 Query 共享 K/VGQA (LLaMA 2)
滑动窗口只关注局部窗口Mistral

7.3 KV Cache:推理加速

在自回归生成时,每一步都需要计算之前所有 token 的 K、V。KV Cache 把它们缓存起来,避免重复计算:

# 没有 KV Cache:每步都重新计算
step 1: 计算 K_1, V_1
step 2: 计算 K_1, K_2, V_1, V_2  # K_1, V_1 重复了!
step 3: 计算 K_1, K_2, K_3, V_1, V_2, V_3  # 又重复了!

# 有 KV Cache:缓存之前的结果
step 1: 计算并缓存 K_1, V_1
step 2: 只计算 K_2, V_2,拼接缓存的 K_1, V_1
step 3: 只计算 K_3, V_3,拼接缓存的 K_1, K_2, V_1, V_2

KV Cache 是推理优化的关键,但也带来显存压力(缓存随序列长度线性增长)。

八、完整代码:从零实现 Transformer

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 线性变换
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # 拆分成多头
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        attn_output = torch.matmul(attn_weights, V)
        
        # 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.embed_dim)
        
        return self.W_o(attn_output)


class FeedForward(nn.Module):
    """前馈神经网络"""
    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()
    
    def forward(self, x):
        x = self.activation(self.linear1(x))
        x = self.dropout(x)
        return self.linear2(x)


class TransformerEncoderLayer(nn.Module):
    """Transformer 编码器层"""
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.ffn = FeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-Attention + 残差
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # FFN + 残差
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        
        return x


class TransformerDecoderLayer(nn.Module):
    """Transformer 解码器层"""
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.ffn = FeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked Self-Attention
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Cross-Attention
        attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # FFN
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))
        
        return x


class PositionalEncoding(nn.Module):
    """正弦位置编码"""
    def __init__(self, embed_dim, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class Transformer(nn.Module):
    """完整的 Transformer 模型"""
    def __init__(
        self, 
        src_vocab_size,
        tgt_vocab_size,
        embed_dim=512, 
        num_heads=8, 
        num_encoder_layers=6,
        num_decoder_layers=6,
        ff_dim=2048, 
        dropout=0.1,
        max_len=5000
    ):
        super().__init__()
        
        # Embedding
        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len, dropout)
        
        # Encoder
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # Output
        self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)
        
        self.embed_dim = embed_dim
    
    def encode(self, src, src_mask=None):
        x = self.src_embedding(src) * math.sqrt(self.embed_dim)
        x = self.pos_encoding(x)
        
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        x = self.tgt_embedding(tgt) * math.sqrt(self.embed_dim)
        x = self.pos_encoding(x)
        
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        return x
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        output = self.fc_out(decoder_output)
        return output


# 使用示例
if __name__ == "__main__":
    # 超参数
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    embed_dim = 512
    num_heads = 8
    num_layers = 6
    
    # 创建模型
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers
    )
    
    # 模拟输入
    batch_size = 2
    src_len = 10
    tgt_len = 8
    
    src = torch.randint(0, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))
    
    # 前向传播
    output = model(src, tgt)
    print(f"输入源序列: {src.shape}")
    print(f"输入目标序列: {tgt.shape}")
    print(f"输出形状: {output.shape}")
    
    # 计算参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型参数量: {total_params:,}")

九、总结:Transformer 的核心洞见

mindmap root((Transformer 核心)) 注意力机制 Self-Attention Multi-Head 可并行计算 位置编码 正弦编码 RoPE 相对位置 架构设计 残差连接 层归一化 FFN 扩展 成功因素 并行训练 长距离建模 Scaling Law 通用性

关键 Takeaway

  1. Attention Is All You Need:用 Attention 替代 RNN,获得并行能力和长距离建模能力
  2. Multi-Head:让模型学习多种不同的关注模式
  3. Position Encoding:补充 Attention 缺失的位置信息
  4. 残差 + LayerNorm:让深层网络可以训练
  5. Scaling Law:大力出奇迹是可预测的

参考资料

  1. Attention Is All You Need - 原论文
  2. The Illustrated Transformer - 可视化讲解
  3. The Annotated Transformer - 带注释的代码
  4. Let's build GPT - Karpathy 的视频教程
评论区
暂无评论
avatar