搜 索

🐒Attention变体:注意力的七十二变

  • 2阅读
  • 2025年02月15日
  • 0评论
首页 / AI/大数据 / 正文

上一篇我们学会了 Self-Attention,这一篇我们来学习如何把它玩坏。

前言:Attention 很好,但是...

上一篇文章,我们搞懂了 Transformer 的核心——Self-Attention。

它很强大,但有一个致命的问题:太慢了,太费显存了

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$

看到那个 $QK^T$ 了吗?这是一个 $n \times n$ 的矩阵,$n$ 是序列长度。

这意味着:

  • 时间复杂度:$O(n^2 \cdot d)$
  • 空间复杂度:$O(n^2)$

当 $n = 1000$ 时,需要存储 100 万个注意力分数。
当 $n = 100000$ 时,需要存储 100 亿 个注意力分数。

这就是为什么早期的 GPT 只能处理 2048 个 token,而你想让它读完一本书?做梦。

graph LR subgraph "序列长度 vs 计算量" A["1K tokens"] --> A1["1M 次计算"] B["4K tokens"] --> B1["16M 次计算"] C["32K tokens"] --> C1["1B 次计算"] D["128K tokens"] --> D1["16B 次计算"] end style D1 fill:#ff6b6b

过去几年,研究者们想尽办法优化 Attention,催生了一大堆变体。

今天,我们来盘点这些"注意力的七十二变"。


一、问题诊断:Attention 到底慢在哪?

在开始优化之前,我们需要先搞清楚瓶颈在哪。

1.1 计算瓶颈 vs 内存瓶颈

很多人以为 Attention 慢是因为计算量大。错!

现代 GPU 的计算能力很强,真正的瓶颈是 内存带宽

graph TB subgraph GPU架构 HBM["HBM (显存)
容量大,带宽低
~2TB/s"] SRAM["SRAM (片上缓存)
容量小,带宽高
~19TB/s"] Compute["计算单元
312 TFLOPS (A100)"] end HBM <-->|"瓶颈!"| SRAM SRAM <--> Compute style HBM fill:#ff6b6b

A100 GPU 的数据:

  • 计算能力:312 TFLOPS (FP16)
  • HBM 带宽:2 TB/s
  • SRAM 容量:20 MB(很小!)

这意味着什么?

计算 1 TFLOP 需要的时间:$\frac{1}{312} \approx 3.2$ ms

从 HBM 读取 1 TB 需要的时间:$\frac{1}{2} = 500$ ms

差了 150 倍

所以,如果你的算法需要频繁地在 HBM 和 SRAM 之间搬运数据,计算单元就会闲着等数据——这就是 内存带宽瓶颈(Memory-bound)

1.2 标准 Attention 的内存访问

让我们看看标准 Attention 是怎么执行的:

# 标准 Attention 的执行流程(伪代码)
def standard_attention(Q, K, V):
    # Step 1: 计算 QK^T,写回 HBM
    S = Q @ K.T                    # 从 HBM 读 Q, K,写 S 到 HBM
    
    # Step 2: 计算 softmax,写回 HBM  
    P = softmax(S)                 # 从 HBM 读 S,写 P 到 HBM
    
    # Step 3: 计算 PV,写回 HBM
    O = P @ V                      # 从 HBM 读 P, V,写 O 到 HBM
    
    return O

问题在于:中间结果 S 和 P 都是 $n \times n$ 的大矩阵,必须存到 HBM 里。

这导致了大量的内存读写,而内存带宽是瓶颈。

1.3 优化方向

基于以上分析,Attention 的优化有三个方向:

mindmap root((Attention 优化)) 减少计算量 稀疏注意力 线性注意力 减少内存访问 Flash Attention 融合算子 减少 KV Cache MQA/GQA KV Cache 量化

接下来,我们逐一介绍。


二、稀疏注意力:不是所有位置都重要

2.1 核心思想

标准 Attention 让每个 token 都关注所有其他 token。但真的有必要吗?

想想你读文章的时候:

  • 读到代词"它"时,你会回看最近的名词
  • 读到总结段落时,你会回看开头
  • 大部分时候,你主要关注附近的内容

稀疏注意力的思路就是:只计算"重要"位置的注意力,忽略其他位置

2.2 Local Attention(滑动窗口)

最简单的稀疏模式:只关注附近的 token

标准 Attention(每个位置关注所有位置):
位置 0: [1, 1, 1, 1, 1, 1, 1, 1]
位置 1: [1, 1, 1, 1, 1, 1, 1, 1]
位置 2: [1, 1, 1, 1, 1, 1, 1, 1]
...

Local Attention(窗口大小=3):
位置 0: [1, 1, 0, 0, 0, 0, 0, 0]
位置 1: [1, 1, 1, 0, 0, 0, 0, 0]
位置 2: [0, 1, 1, 1, 0, 0, 0, 0]
位置 3: [0, 0, 1, 1, 1, 0, 0, 0]
...

复杂度:从 $O(n^2)$ 降到 $O(n \cdot w)$,其中 $w$ 是窗口大小。

代表模型:Mistral 使用 Sliding Window Attention,窗口大小 4096。

def sliding_window_attention(Q, K, V, window_size):
    """滑动窗口注意力"""
    seq_len = Q.shape[1]
    
    # 创建滑动窗口 mask
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = 1
    
    # 计算注意力(只在 mask=1 的位置)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
    scores = scores.masked_fill(mask == 0, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    
    return torch.matmul(attn_weights, V)

2.3 Longformer:混合稀疏模式

纯 Local Attention 的问题:无法捕捉长距离依赖

Longformer 的解决方案:Local + Global 混合

graph TB subgraph Longformer注意力模式 Local["Local Attention
每个 token 关注窗口内"] Global["Global Attention
特殊 token 关注所有"] end Local --> Mix["混合使用"] Global --> Mix

具体来说:

  • 大部分 token:使用滑动窗口(Local)
  • 特殊 token(如 [CLS]、段落开头):关注所有位置(Global)
[CLS] token:  [1, 1, 1, 1, 1, 1, 1, 1]  ← Global
普通 token:   [1, 1, 1, 0, 0, 0, 0, 0]  ← Local
普通 token:   [0, 1, 1, 1, 0, 0, 0, 0]  ← Local
段落开头:     [1, 1, 1, 1, 1, 1, 1, 1]  ← Global

2.4 BigBird:加入随机连接

BigBird 在 Longformer 基础上又加了一个:随机注意力

graph LR subgraph BigBird R["Random
随机关注几个位置"] L["Local
关注邻近位置"] G["Global
特殊 token"] end R --> Final["完整注意力"] L --> Final G --> Final

随机连接的好处:即使两个 token 很远,也有概率"直接对话",信息传播更快。

2.5 稀疏注意力的效果

模型最大长度复杂度方法
标准 Transformer512-2K$O(n^2)$Full
Longformer4K-16K$O(n)$Local + Global
BigBird4K-16K$O(n)$Local + Global + Random
Mistral32K$O(n \cdot w)$Sliding Window

问题:稀疏注意力虽然快,但会损失一些信息。现代大模型更倾向于用 Flash Attention 来加速,而不是牺牲注意力的完整性。


三、线性注意力:从 $O(n^2)$ 到 $O(n)$

3.1 核心思想

标准 Attention 的复杂度来自于 $QK^T$ 这个矩阵乘法:

$$ \text{Attention} = \text{softmax}(QK^T) V $$

线性注意力的想法:能不能避免显式计算 $QK^T$?

关键观察:如果去掉 softmax,可以改变计算顺序!

$$ \text{Linear Attention} = Q(K^T V) $$

矩阵乘法是结合的:

  • $(\text{softmax}(QK^T)) V$:先算 $n \times n$,再乘 $V$,复杂度 $O(n^2 d)$
  • $Q(K^T V)$:先算 $K^T V$($d \times d$),再乘 $Q$,复杂度 $O(nd^2)$

当 $n >> d$ 时,后者快得多!

3.2 用核函数替代 Softmax

但问题是,softmax 保证了注意力权重是正的、和为 1。去掉 softmax 会破坏这些性质。

解决方案:用核函数(Kernel)近似 softmax

$$ \text{softmax}(q_i^T k_j) \approx \phi(q_i)^T \phi(k_j) $$

其中 $\phi$ 是一个特征映射函数。

def linear_attention(Q, K, V, feature_map):
    """线性注意力"""
    # 应用特征映射
    Q = feature_map(Q)  # (batch, seq, feature_dim)
    K = feature_map(K)
    
    # 先算 K^T V (d × d 矩阵)
    KV = torch.einsum('bnd,bnm->bdm', K, V)
    
    # 再用 Q 乘
    output = torch.einsum('bnd,bdm->bnm', Q, KV)
    
    # 归一化
    normalizer = torch.einsum('bnd,bd->bn', Q, K.sum(dim=1))
    output = output / normalizer.unsqueeze(-1)
    
    return output

def elu_feature_map(x):
    """ELU 特征映射"""
    return F.elu(x) + 1

3.3 RWKV:RNN 的复仇

RWKV 是一个很有意思的模型:它用 RNN 的形式实现了类似 Transformer 的效果。

核心思想:把 Attention 重写成可以递归计算的形式。

# RWKV 的 Time Mixing(简化版)
def rwkv_time_mixing(x, state, w, u, k, v):
    """
    x: 当前输入
    state: 上一步的状态
    w, u, k, v: 可学习参数
    """
    # 计算 k, v
    k = x @ W_k
    v = x @ W_v
    
    # 更新状态(类似 RNN)
    wkv = state * w + k * v
    state = state * w + k
    
    # 输出
    output = wkv / state
    
    return output, state

优点

  • 推理时复杂度 $O(1)$(相对于序列长度)
  • 可以处理无限长序列
  • 训练可以并行

缺点

  • 效果略逊于标准 Transformer
  • 社区支持较少

3.4 线性注意力的现状

timeline title 线性注意力发展 2020 : Linear Transformer : 核函数近似 2021 : Performer : FAVOR+ 随机特征 2022 : RWKV : RNN 复仇 2023 : Mamba : 状态空间模型 2024 : 仍在探索中 : 尚未成为主流

现实情况:线性注意力在学术上很有意思,但实际应用中,大家更多用 Flash Attention 来加速标准 Attention,而不是换成线性注意力。

原因:

  1. 效果差距:线性注意力在长距离依赖上仍有差距
  2. 工程成熟度:Flash Attention 已经很成熟
  3. 硬件优化:GPU 对矩阵乘法优化很好

四、Flash Attention:内存优化的艺术

这是目前最实用的 Attention 优化技术,几乎所有现代大模型都在用。

4.1 核心思想

Flash Attention 的核心洞察:标准 Attention 的瓶颈不是计算,而是内存访问

解决方案:Tiling(分块)+ 重计算

不要把整个 $n \times n$ 的注意力矩阵存到 HBM,而是:

  1. 把 Q, K, V 分成小块
  2. 每次只在 SRAM 里计算一小块
  3. 用 Online Softmax 逐块更新结果
  4. 最终结果直接写回 HBM,不存中间矩阵
graph TB subgraph 标准Attention A1["1. Q,K,V 从 HBM 读入"] A2["2. 计算 S = QK^T"] A3["3. S 写回 HBM"] A4["4. S 读入,计算 P = softmax(S)"] A5["5. P 写回 HBM"] A6["6. P,V 读入,计算 O = PV"] A7["7. O 写回 HBM"] A1 --> A2 --> A3 --> A4 --> A5 --> A6 --> A7 end subgraph FlashAttention B1["1. 分块加载 Q,K,V"] B2["2. 在 SRAM 中计算"] B3["3. Online Softmax 更新"] B4["4. 直接写出最终 O"] B1 --> B2 --> B3 --> B4 B3 -.-> B2 end style A3 fill:#ff6b6b style A5 fill:#ff6b6b

4.2 Online Softmax

Flash Attention 的关键技术是 Online Softmax:不需要看完整个序列就能计算 softmax。

标准 softmax 需要两次遍历:

  1. 第一遍:找最大值 $m = \max(x)$
  2. 第二遍:计算 $\text{softmax}(x) = \frac{e^{x-m}}{\sum e^{x-m}}$

Online Softmax 可以一遍完成,通过增量更新:

def online_softmax(x_blocks):
    """Online Softmax: 逐块更新"""
    m = float('-inf')  # 当前最大值
    l = 0.0            # 当前归一化因子
    
    for x_block in x_blocks:
        # 更新最大值
        m_new = max(m, x_block.max())
        
        # 更新归一化因子
        l = l * exp(m - m_new) + exp(x_block - m_new).sum()
        m = m_new
    
    # 最终的 softmax 可以从 m 和 l 计算出来
    return m, l

4.3 Flash Attention 算法

def flash_attention(Q, K, V, block_size=64):
    """
    Flash Attention 简化实现
    实际实现是 CUDA kernel,这里只展示逻辑
    """
    batch, seq_len, d = Q.shape
    O = torch.zeros_like(Q)
    
    # 分块
    num_blocks = (seq_len + block_size - 1) // block_size
    
    for i in range(num_blocks):
        # 当前 Q 块
        q_block = Q[:, i*block_size:(i+1)*block_size, :]
        
        # 初始化 Online Softmax 状态
        m_i = torch.full((batch, block_size), float('-inf'))
        l_i = torch.zeros(batch, block_size)
        o_i = torch.zeros(batch, block_size, d)
        
        for j in range(num_blocks):
            # 当前 K, V 块
            k_block = K[:, j*block_size:(j+1)*block_size, :]
            v_block = V[:, j*block_size:(j+1)*block_size, :]
            
            # 计算注意力分数(在 SRAM 中)
            s_ij = q_block @ k_block.transpose(-2, -1) / math.sqrt(d)
            
            # Online Softmax 更新
            m_new = torch.maximum(m_i, s_ij.max(dim=-1).values)
            
            # 更新归一化因子和输出
            exp_old = torch.exp(m_i - m_new)
            exp_new = torch.exp(s_ij - m_new.unsqueeze(-1))
            
            l_new = l_i * exp_old + exp_new.sum(dim=-1)
            
            o_i = (o_i * l_i.unsqueeze(-1) * exp_old.unsqueeze(-1) + 
                   exp_new @ v_block) / l_new.unsqueeze(-1)
            
            m_i = m_new
            l_i = l_new
        
        # 写回结果
        O[:, i*block_size:(i+1)*block_size, :] = o_i
    
    return O

4.4 Flash Attention 的效果

序列长度标准 AttentionFlash Attention加速比
1K1x2-3x2-3x
4K1x3-4x3-4x
16KOOM
64KOOM

关键优势

  1. 显存节省:不存储 $n \times n$ 的注意力矩阵
  2. 速度提升:减少 HBM 访问次数
  3. 支持长序列:显存占用从 $O(n^2)$ 降到 $O(n)$

4.5 Flash Attention 2 & 3

Flash Attention 2 的改进:

  • 更好的并行策略
  • 减少非矩阵乘法运算
  • 速度再提升 2x

Flash Attention 3(2024)的改进:

  • 利用 Hopper GPU(H100)的新特性
  • 异步执行、硬件加速
  • 接近理论峰值性能

4.6 如何使用 Flash Attention

# 方法 1:使用 PyTorch 内置(2.0+)
import torch.nn.functional as F

# 自动使用 Flash Attention(如果可用)
output = F.scaled_dot_product_attention(query, key, value)

# 方法 2:使用 flash-attn 库
from flash_attn import flash_attn_func

output = flash_attn_func(q, k, v, causal=True)

# 方法 3:使用 transformers 库
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b",
    attn_implementation="flash_attention_2"  # 指定使用 Flash Attention
)

五、GQA/MQA:KV Cache 的优化

5.1 KV Cache 的显存问题

在自回归生成时,我们需要缓存之前所有 token 的 K 和 V(KV Cache)。

KV Cache 大小计算

$$ \text{KV Cache} = 2 \times n_{layers} \times n_{heads} \times d_{head} \times \text{seq\_len} \times \text{batch\_size} $$

以 LLaMA-2-70B 为例:

  • 80 层,64 个 head,head 维度 128
  • 序列长度 4096,batch size 1

$$ \text{KV Cache} = 2 \times 80 \times 64 \times 128 \times 4096 \times 2\text{ bytes} = \textbf{10.7 GB} $$

光 KV Cache 就要 10GB 显存!如果 batch size 大一点,或者序列更长,显存直接爆炸。

5.2 Multi-Query Attention (MQA)

核心思想:所有 Query head 共享同一组 K 和 V。

graph TB subgraph MHA["Multi-Head Attention (标准)"] Q1["Q head 1"] --> K1["K head 1"] Q2["Q head 2"] --> K2["K head 2"] Q3["Q head 3"] --> K3["K head 3"] Q4["Q head 4"] --> K4["K head 4"] end subgraph MQA["Multi-Query Attention"] Q1_["Q head 1"] --> K_["K (共享)"] Q2_["Q head 2"] --> K_ Q3_["Q head 3"] --> K_ Q4_["Q head 4"] --> K_ end

效果:KV Cache 减少到原来的 $\frac{1}{n_{heads}}$。

代价:效果会略有下降。

5.3 Grouped-Query Attention (GQA)

GQA 是 MHA 和 MQA 的折中:把 Query head 分组,每组共享一个 K/V head

graph TB subgraph GQA["Grouped-Query Attention (4 groups)"] Q1["Q head 1"] --> K1["K group 1"] Q2["Q head 2"] --> K1 Q3["Q head 3"] --> K2["K group 2"] Q4["Q head 4"] --> K2 Q5["Q head 5"] --> K3["K group 3"] Q6["Q head 6"] --> K3 Q7["Q head 7"] --> K4["K group 4"] Q8["Q head 8"] --> K4 end

LLaMA 2 的选择

  • 70B 模型:8 个 KV head(64 个 Q head)
  • KV Cache 减少 8 倍,效果几乎无损

5.4 代码实现

class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention"""
    
    def __init__(self, embed_dim, num_q_heads, num_kv_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_q_heads
        
        # Q 有 num_q_heads 个 head
        self.W_q = nn.Linear(embed_dim, num_q_heads * self.head_dim)
        # K, V 只有 num_kv_heads 个 head
        self.W_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        
        # 每组有多少个 Q head
        self.num_groups = num_q_heads // num_kv_heads
    
    def forward(self, x, kv_cache=None):
        batch, seq_len, _ = x.shape
        
        # 投影
        Q = self.W_q(x).view(batch, seq_len, self.num_q_heads, self.head_dim)
        K = self.W_k(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        V = self.W_v(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        
        # 处理 KV Cache
        if kv_cache is not None:
            K = torch.cat([kv_cache['k'], K], dim=1)
            V = torch.cat([kv_cache['v'], V], dim=1)
        
        # 扩展 K, V 以匹配 Q 的 head 数
        # (batch, seq, num_kv_heads, head_dim) -> (batch, seq, num_q_heads, head_dim)
        K = K.repeat_interleave(self.num_groups, dim=2)
        V = V.repeat_interleave(self.num_groups, dim=2)
        
        # 计算注意力(后续和标准 MHA 相同)
        Q = Q.transpose(1, 2)  # (batch, num_q_heads, seq, head_dim)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        
        # 合并 head
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        
        return self.W_o(output)

5.5 各方案对比

方案Q headsKV headsKV Cache效果
MHA3232100%最好
GQA-832825%接近 MHA
GQA-432412.5%略有下降
MQA3213.1%下降明显

实践建议:GQA 是目前的最佳实践,LLaMA 2/3、Mistral、Qwen 2 等都在用。


六、其他优化技术

6.1 投机采样(Speculative Decoding)

自回归生成的问题:每次只能生成一个 token。

投机采样的思路:用一个小模型先"猜"多个 token,然后让大模型一次性验证。

sequenceDiagram participant Draft as 小模型 (Draft) participant Target as 大模型 (Target) Draft->>Draft: 生成 token 1, 2, 3, 4 Draft->>Target: 发送候选序列 Target->>Target: 并行验证所有 token Target->>Target: 接受 token 1, 2, 3,拒绝 4 Target->>Draft: 返回正确的 token 4

效果:在不损失质量的前提下,加速 2-3 倍。

6.2 Prefix Caching

如果多个请求有相同的前缀(如相同的 system prompt),可以共享 KV Cache。

# 传统方式:每个请求都计算完整的 KV Cache
request_1 = "You are a helpful assistant. What is 2+2?"
request_2 = "You are a helpful assistant. What is 3+3?"
# 两次都要计算 "You are a helpful assistant." 的 KV

# Prefix Caching:相同前缀只计算一次
prefix_cache = compute_kv("You are a helpful assistant.")
answer_1 = generate_with_cache(prefix_cache, "What is 2+2?")
answer_2 = generate_with_cache(prefix_cache, "What is 3+3?")

vLLM 原生支持这个功能。

6.3 Chunked Prefill

处理长 prompt 时,一次性计算所有 token 的 KV 可能导致延迟尖峰。

Chunked Prefill 把 prefill 阶段分成多个小块,交错执行,让延迟更平滑。


七、实战:使用不同的 Attention 实现

7.1 PyTorch 原生(推荐)

import torch
import torch.nn.functional as F

def attention_comparison(Q, K, V):
    """比较不同 Attention 实现"""
    
    # 1. 标准实现
    def standard_attention(Q, K, V):
        d_k = Q.shape[-1]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V)
    
    # 2. PyTorch 2.0+ 的 scaled_dot_product_attention
    # 会自动选择最优实现(Flash Attention 如果可用)
    def sdpa_attention(Q, K, V):
        return F.scaled_dot_product_attention(Q, K, V)
    
    # 比较结果
    out_standard = standard_attention(Q, K, V)
    out_sdpa = sdpa_attention(Q, K, V)
    
    print(f"结果一致: {torch.allclose(out_standard, out_sdpa, atol=1e-5)}")
    
    return out_standard, out_sdpa

# 测试
batch, heads, seq_len, d_head = 2, 8, 1024, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda')
K = torch.randn(batch, heads, seq_len, d_head, device='cuda')
V = torch.randn(batch, heads, seq_len, d_head, device='cuda')

out1, out2 = attention_comparison(Q, K, V)

7.2 使用 Flash Attention 库

# 安装: pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

def use_flash_attention():
    batch, seq_len, heads, d_head = 2, 4096, 32, 128
    
    # 方式 1:分开的 Q, K, V
    q = torch.randn(batch, seq_len, heads, d_head, device='cuda', dtype=torch.float16)
    k = torch.randn(batch, seq_len, heads, d_head, device='cuda', dtype=torch.float16)
    v = torch.randn(batch, seq_len, heads, d_head, device='cuda', dtype=torch.float16)
    
    output = flash_attn_func(q, k, v, causal=True)
    print(f"Output shape: {output.shape}")
    
    # 方式 2:打包的 QKV
    qkv = torch.randn(batch, seq_len, 3, heads, d_head, device='cuda', dtype=torch.float16)
    output = flash_attn_qkvpacked_func(qkv, causal=True)
    
    return output

# 性能测试
import time

def benchmark_attention(seq_lengths=[1024, 4096, 16384]):
    for seq_len in seq_lengths:
        q = torch.randn(1, seq_len, 32, 128, device='cuda', dtype=torch.float16)
        k = torch.randn(1, seq_len, 32, 128, device='cuda', dtype=torch.float16)
        v = torch.randn(1, seq_len, 32, 128, device='cuda', dtype=torch.float16)
        
        # Warmup
        for _ in range(10):
            _ = flash_attn_func(q, k, v, causal=True)
        
        torch.cuda.synchronize()
        start = time.time()
        
        for _ in range(100):
            _ = flash_attn_func(q, k, v, causal=True)
        
        torch.cuda.synchronize()
        elapsed = (time.time() - start) / 100 * 1000
        
        print(f"Seq length {seq_len}: {elapsed:.2f} ms")

7.3 在 Transformers 中使用

from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载模型时指定 attention 实现
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2"  # 使用 Flash Attention 2
)

# 或者使用 SDPA(PyTorch 原生)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"  # 使用 PyTorch SDPA
)

# 检查当前使用的 attention 实现
print(f"Attention implementation: {model.config._attn_implementation}")

八、总结:如何选择?

flowchart TB Start["需要优化 Attention"] --> Q1{"显存是瓶颈?"} Q1 -->|是| Q2{"需要长序列?"} Q1 -->|否| Q3{"推理还是训练?"} Q2 -->|是| A1["Flash Attention
+ GQA"] Q2 -->|否| A2["Flash Attention"] Q3 -->|推理| A3["Flash Attention
+ KV Cache 优化"] Q3 -->|训练| A4["Flash Attention
+ 混合精度"] A1 --> Final["搞定!"] A2 --> Final A3 --> Final A4 --> Final

实践建议

场景推荐方案
一般使用PyTorch 2.0+ 的 SDPA(自动优化)
需要最佳性能flash-attn 库
长序列(>8K)Flash Attention + GQA
高并发推理vLLM(内置各种优化)
边缘部署量化 + GQA

关键 Takeaway

  1. Attention 的瓶颈是内存带宽,不是计算量
  2. Flash Attention 是目前最实用的优化,几乎无损
  3. GQA 可以大幅减少 KV Cache,LLaMA 2+ 都在用
  4. 稀疏注意力和线性注意力有学术价值,但工程应用有限
  5. PyTorch 2.0+ 的 SDPA 会自动选择最优实现,对新手友好
评论区
暂无评论
avatar