前言:上下文长度的进化
大模型上下文长度的演进:
| 模型 | 上下文长度 | 相当于 |
|---|---|---|
| GPT-3 | 4K | ~3000 字 |
| GPT-4 | 8K/32K | ~2.4万字 |
| Claude 2 | 100K | ~7.5万字 |
| Claude 3 | 200K | 一本小说 |
| Gemini 1.5 | 1M | 多本书 |
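表中"相当于"一列只是粗略换算,实际 token 数量取决于所用分词器。可以用下面几行代码自己估算一段文本大约占多少上下文(示意写法,这里假设借用 Qwen2 的分词器,文件名仅作占位):
# 估算一段中文文本大约占多少 token(示意)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # 假设使用该分词器
text = open("long_document.txt", encoding="utf-8").read()  # 换成你自己的长文档
num_tokens = len(tokenizer.encode(text))
print(f"字符数: {len(text)}, token 数: {num_tokens}, 平均每个 token ≈ {len(text) / num_tokens:.2f} 个字符")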
为什么长上下文重要?
graph LR
subgraph 短上下文
A1[长文档] --> B1[切分]
B1 --> C1[分别处理]
C1 --> D1[信息丢失]
end
subgraph 长上下文
A2[长文档] --> B2[完整输入]
B2 --> C2[全局理解]
C2 --> D2[精准回答]
end
style D1 fill:#ff6b6b
style D2 fill:#4ecdc4
应用场景:
- 📚 阅读整本书并回答问题
- 💻 分析整个代码仓库
- 📊 处理长篇报告和论文
- 🎬 理解长视频内容
一、为什么长上下文这么难?
1.1 Self-Attention 的复杂度
标准 Self-Attention 的计算复杂度是 O(n²):
# Self-Attention 计算
import math
import torch.nn.functional as F

def self_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d)
    d = Q.shape[-1]
    # 计算注意力分数:O(n²)
    scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
    # Softmax
    attn = F.softmax(scores / math.sqrt(d), dim=-1)
    # 加权求和:O(n²)
    output = attn @ V
    return output
问题:
- 序列长度 4K → 注意力矩阵 4K × 4K = 1600 万
- 序列长度 100K → 注意力矩阵 100K × 100K = 100 亿!
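这一平方级增长用几行代码就能直观感受(示意计算,只统计注意力分数矩阵本身,不含 KV 等其他开销):
# 注意力分数矩阵的规模随序列长度按 n² 增长(仅示意)
for n in [4_096, 32_768, 100_000, 1_000_000]:
    elements = n * n  # 分数矩阵元素个数
    fp16_gb = elements * 2 / 1024 ** 3  # fp16 下单 batch、单 head 所需显存
    print(f"{n:>9} tokens -> {elements:.2e} 个元素, 约 {fp16_gb:,.1f} GB(fp16, 单头)")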
graph LR
subgraph 复杂度增长
A["4K tokens<br/>16M 计算"] --> B["32K tokens<br/>1B 计算"]
B --> C["100K tokens<br/>10B 计算"]
C --> D["1M tokens<br/>1T 计算"]
end
style D fill:#ff6b6b
1.2 内存瓶颈
def calculate_attention_memory(seq_len, batch_size, num_heads, head_dim, dtype_bytes=2):
"""计算注意力矩阵的内存需求"""
# 注意力矩阵: (batch, num_heads, seq_len, seq_len)
attn_matrix = batch_size * num_heads * seq_len * seq_len * dtype_bytes
# Q, K, V: (batch, num_heads, seq_len, head_dim)
qkv = 3 * batch_size * num_heads * seq_len * head_dim * dtype_bytes
total_bytes = attn_matrix + qkv
total_gb = total_bytes / (1024 ** 3)
return total_gb
# 示例计算
print(f"4K context: {calculate_attention_memory(4096, 1, 32, 128):.2f} GB")
print(f"32K context: {calculate_attention_memory(32768, 1, 32, 128):.2f} GB")
print(f"100K context: {calculate_attention_memory(100000, 1, 32, 128):.2f} GB")
# 输出:
# 4K context: 1.09 GB
# 32K context: 64.75 GB
# 100K context: 598.34 GB ← 单张 GPU 装不下!
1.3 位置编码的外推问题
传统位置编码在超出训练长度时会失效:
# 绝对位置编码:训练时最长 4K
position_embeddings = nn.Embedding(4096, hidden_size)
# 推理时遇到 5000 的位置
pos = 5000 # 超出范围,会报错或效果很差
二、位置编码改进
2.1 位置编码类型
mindmap
root((位置编码))
绝对位置
Sinusoidal
Learned
相对位置
ALiBi
RoPE
XPos
无位置编码
某些架构不需要
2.2 RoPE:旋转位置编码
RoPE(Rotary Position Embedding) 是目前最流行的位置编码:
import torch
import torch.nn as nn
class RotaryEmbedding(nn.Module):
"""RoPE 旋转位置编码"""
def __init__(self, dim, max_seq_len=4096, base=10000):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# 计算频率
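# inv_freq[i] = base^(-2i/dim):靠前的维度旋转得快(高频),靠后的维度旋转得慢(低频)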
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
# 预计算
self._set_cos_sin_cache(max_seq_len)
def _set_cos_sin_cache(self, seq_len):
"""预计算 cos 和 sin"""
t = torch.arange(seq_len, device=self.inv_freq.device)
freqs = torch.outer(t, self.inv_freq) # (seq_len, dim/2)
# 复制一份拼接
emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, dim)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
def forward(self, x, seq_len=None):
"""应用 RoPE"""
if seq_len is None:  # 未显式给出时,按约定从倒数第二维推断序列长度
seq_len = x.shape[-2]
if seq_len > self.max_seq_len:
self._set_cos_sin_cache(seq_len)
self.max_seq_len = seq_len  # 记录已缓存的最大长度,避免重复重建
return (
self.cos_cached[:seq_len],
self.sin_cached[:seq_len],
)
def rotate_half(x):
"""旋转一半"""
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
"""应用旋转位置编码到 Q 和 K"""
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
2.3 位置插值(Position Interpolation)
让短上下文训练的模型支持长上下文:
class ScaledRotaryEmbedding(RotaryEmbedding):
"""位置插值的 RoPE"""
def __init__(self, dim, max_seq_len=4096, base=10000, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_seq_len, base)
def _set_cos_sin_cache(self, seq_len):
# 关键:对位置进行缩放
t = torch.arange(seq_len, device=self.inv_freq.device)
t = t / self.scaling_factor # 位置插值
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
# 使用:将 4K 模型扩展到 16K
# scaling_factor = 16K / 4K = 4
rope = ScaledRotaryEmbedding(dim=64, max_seq_len=16384, scaling_factor=4.0)
2.4 NTK-aware 插值
更好的插值方法,调整 base 频率:
class NTKScaledRotaryEmbedding(RotaryEmbedding):
"""NTK-aware 位置插值"""
def __init__(self, dim, max_seq_len=4096, base=10000, scaling_factor=1.0):
# 调整 base 而不是位置
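# 放大 base 等价于拉长各维度的波长:高频维度几乎不受影响,低频维度被显著拉伸以容纳更长的位置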
new_base = base * (scaling_factor ** (dim / (dim - 2)))
super().__init__(dim, max_seq_len, new_base)
# 或者使用动态 NTK
class DynamicNTKRotaryEmbedding(RotaryEmbedding):
"""动态 NTK 缩放"""
def __init__(self, dim, max_seq_len=4096, base=10000, original_max_seq_len=4096):
self.original_max_seq_len = original_max_seq_len
super().__init__(dim, max_seq_len, base)
def forward(self, x, seq_len=None):
if seq_len > self.original_max_seq_len:
# 动态调整 base
scaling_factor = seq_len / self.original_max_seq_len
new_base = self.base * (scaling_factor ** (self.dim / (self.dim - 2)))
inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2).float() / self.dim))
# 重新计算...
return super().forward(x, seq_len)
2.5 YaRN:Yet another RoPE extensioN
结合位置插值和 NTK 的优点:
class YaRNRotaryEmbedding(RotaryEmbedding):
"""YaRN: 更好的 RoPE 扩展"""
def __init__(
self,
dim,
max_seq_len=4096,
base=10000,
original_max_seq_len=4096,
beta_fast=32,
beta_slow=1,
):
self.original_max_seq_len = original_max_seq_len
self.beta_fast = beta_fast
self.beta_slow = beta_slow
super().__init__(dim, max_seq_len, base)
def _compute_yarn_scaling(self, seq_len):
"""计算 YaRN 缩放因子"""
scale = seq_len / self.original_max_seq_len
# 低频外推,高频插值
dim_range = torch.arange(0, self.dim, 2).float()
# 计算每个维度的缩放
low_freq_factor = 1.0
high_freq_factor = scale
# 平滑过渡
# ...
return scaling_factors
2.6 ALiBi:无需训练的长度外推
ALiBi(Attention with Linear Biases):直接在注意力分数上加线性偏置
class ALiBiAttention(nn.Module):
"""ALiBi 注意力"""
def __init__(self, num_heads):
super().__init__()
# 计算每个头的斜率
# 斜率呈几何级数:2^(-8/n), 2^(-8*2/n), ...
slopes = self._get_slopes(num_heads)
self.register_buffer("slopes", slopes)
def _get_slopes(self, num_heads):
"""计算 ALiBi 斜率"""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
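# start = 2^(-8/n):第 i 个 head 的斜率为 start^(i+1),按几何级数递减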
ratio = start
return [start * (ratio ** i) for i in range(n)]
if math.log2(num_heads).is_integer():
return torch.tensor(get_slopes_power_of_2(num_heads))
else:
# 处理非 2 的幂次:先取最接近的 2 的幂,再从 2 倍 head 数对应的斜率中隔一个取一个补齐
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
slopes = get_slopes_power_of_2(closest_power_of_2)
slopes += get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: num_heads - closest_power_of_2]
return torch.tensor(slopes)
def forward(self, q, k, v):
batch, num_heads, seq_len, head_dim = q.shape
# 标准注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
# ALiBi 偏置
positions = torch.arange(seq_len, device=q.device)
relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)
alibi_bias = self.slopes.view(1, num_heads, 1, 1) * relative_positions.abs().unsqueeze(0).unsqueeze(0)
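# 偏置与相对距离 |i-j| 成正比:距离越远,分数被减得越多,相当于天然偏向邻近 token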
# 应用偏置
scores = scores - alibi_bias
# Softmax 和输出
attn = F.softmax(scores, dim=-1)
output = torch.matmul(attn, v)
return output
三、高效注意力机制
3.1 稀疏注意力
只计算部分位置的注意力:
graph TB
subgraph 全注意力
A1[每个位置] --> B1[看所有位置]
B1 --> C1["O(n²)"]
end
subgraph 稀疏注意力
A2[每个位置] --> B2[只看部分位置]
B2 --> C2["O(n√n) 或 O(n)"]
end
常见稀疏模式:
class SparseAttentionPatterns:
"""稀疏注意力模式"""
@staticmethod
def local_attention(seq_len, window_size):
"""局部注意力:只看周围的 token"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
start = max(0, i - window_size)
end = min(seq_len, i + window_size + 1)
mask[i, start:end] = True
return mask
@staticmethod
def strided_attention(seq_len, stride):
"""跨步注意力:每隔几个 token 看一个"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
# 局部
mask[i, max(0, i-stride):i+1] = True
# 全局(每 stride 个位置)
mask[i, ::stride] = True
return mask
@staticmethod
def global_local_attention(seq_len, num_global_tokens, local_window):
"""全局 + 局部注意力"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
# 全局 token 可以看所有位置
mask[:num_global_tokens, :] = True
mask[:, :num_global_tokens] = True
# 其他 token 只看局部
for i in range(num_global_tokens, seq_len):
start = max(num_global_tokens, i - local_window)
end = min(seq_len, i + local_window + 1)
mask[i, start:end] = True
return mask
3.2 Longformer 和 BigBird
class LongformerAttention(nn.Module):
"""Longformer 风格的注意力"""
def __init__(self, hidden_size, num_heads, window_size=512, num_global_tokens=2):
super().__init__()
self.window_size = window_size
self.num_global_tokens = num_global_tokens
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
def forward(self, hidden_states, global_attention_mask=None):
batch_size, seq_len, _ = hidden_states.shape
q = self.query(hidden_states)
k = self.key(hidden_states)
v = self.value(hidden_states)
# 局部注意力(滑动窗口)
local_output = self.sliding_window_attention(q, k, v)
# 全局注意力(对于特殊 token)
if global_attention_mask is not None:
global_output = self.global_attention(q, k, v, global_attention_mask)
# 合并
output = local_output + global_output
else:
output = local_output
return output
def sliding_window_attention(self, q, k, v):
"""滑动窗口局部注意力"""
# 使用高效实现(如 xformers)
from xformers.ops import memory_efficient_attention
# 创建局部注意力 mask
# ...
return output
3.3 Flash Attention
利用 GPU 内存层级优化注意力计算:
# Flash Attention 使用(已在第2篇详细介绍)
from flash_attn import flash_attn_func
def efficient_attention(q, k, v, causal=True):
"""使用 Flash Attention"""
# q, k, v: (batch, seq_len, num_heads, head_dim)
output = flash_attn_func(
q, k, v,
causal=causal,
softmax_scale=1.0 / math.sqrt(q.shape[-1]),
)
return output
# Flash Attention 2 支持变长序列
from flash_attn import flash_attn_varlen_func
def variable_length_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
"""变长序列的 Flash Attention"""
output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
causal=True,
)
return output
3.4 Ring Attention
分布式处理超长序列:
graph LR
subgraph Ring Attention
GPU1["GPU 1<br/>Block 1"] --> GPU2["GPU 2<br/>Block 2"]
GPU2 --> GPU3["GPU 3<br/>Block 3"]
GPU3 --> GPU4["GPU 4<br/>Block 4"]
GPU4 --> GPU1
end
class RingAttention:
"""Ring Attention:分布式长序列处理"""
def __init__(self, world_size, rank):
self.world_size = world_size
self.rank = rank
def forward(self, q, k, v):
"""
每个 GPU 持有序列的一部分
通过环形通信完成完整注意力计算
"""
local_seq_len = q.shape[1]
# 初始化输出和 softmax 归一化项
output = torch.zeros_like(q)
lse = torch.full((q.shape[0], q.shape[1]), float('-inf')) # log-sum-exp
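# lse 记录每个 query 位置已累计的 log-sum-exp,用于以 online softmax 的方式合并各块的部分结果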
# 环形传递 K, V
k_recv, v_recv = k, v
for step in range(self.world_size):
# 计算当前块的注意力
block_output, block_lse = self.compute_block_attention(q, k_recv, v_recv)
# 更新输出(online softmax)
output, lse = self.update_output(output, lse, block_output, block_lse)
# 环形传递:发送 K, V 给下一个 GPU
k_recv = self.ring_send_recv(k_recv)
v_recv = self.ring_send_recv(v_recv)
return output
def ring_send_recv(self, tensor):
"""环形通信"""
send_to = (self.rank + 1) % self.world_size
recv_from = (self.rank - 1) % self.world_size
recv_tensor = torch.empty_like(tensor)
# 异步发送和接收
send_op = dist.isend(tensor, send_to)
recv_op = dist.irecv(recv_tensor, recv_from)
send_op.wait()
recv_op.wait()
return recv_tensor
四、长上下文训练技术
4.1 渐进式长度训练
def progressive_length_training(model, dataset, target_length=32768):
"""渐进式增加序列长度"""
length_stages = [4096, 8192, 16384, 32768]
for stage_length in length_stages:
print(f"Training with context length: {stage_length}")
# 调整数据
stage_dataset = dataset.filter(lambda x: len(x['text']) <= stage_length)
# 调整位置编码
if hasattr(model, 'rope'):
model.rope.max_seq_len = stage_length
# 训练
trainer = Trainer(
model=model,
train_dataset=stage_dataset,
args=TrainingArguments(
max_seq_length=stage_length,
# 可能需要减小 batch size
per_device_train_batch_size=max(1, 32 // (stage_length // 4096)),
gradient_accumulation_steps=stage_length // 4096,
),
)
trainer.train()
4.2 长度外推微调
def length_extrapolation_finetune(
model,
original_length=4096,
target_length=32768,
num_steps=1000,
):
"""长度外推微调"""
# 1. 调整位置编码
scaling_factor = target_length / original_length
# 使用 YaRN 或 NTK 插值
for layer in model.layers:
layer.self_attn.rotary_emb = YaRNRotaryEmbedding(
dim=layer.self_attn.head_dim,
max_seq_len=target_length,
original_max_seq_len=original_length,
)
# 2. 准备长文本数据
long_dataset = prepare_long_context_data(target_length)
# 3. 微调(通常只需要少量步骤)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
data_iter = iter(long_dataset)
for step in range(num_steps):
batch = next(data_iter)
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % 100 == 0:
print(f"Step {step}, Loss: {loss.item():.4f}")4.3 数据准备
def prepare_long_context_data(max_length=32768):
"""准备长上下文训练数据"""
from datasets import load_dataset
# 使用长文本数据集
datasets_to_use = [
"emozilla/pg19", # 书籍
"togethercomputer/Long-Data-Collections", # 长文本集合
"THUDM/LongBench", # 长文本基准
]
all_texts = []
for ds_name in datasets_to_use:
ds = load_dataset(ds_name, split="train")
for item in ds:
text = item.get('text', item.get('content', ''))
if len(text) >= max_length // 2: # 保留足够长的文本
all_texts.append(text[:max_length])
return all_texts
def create_long_context_examples(texts, tokenizer, max_length):
"""创建长上下文训练样本"""
examples = []
for text in texts:
tokens = tokenizer.encode(text, truncation=True, max_length=max_length)
if len(tokens) >= max_length * 0.8: # 至少 80% 长度
examples.append({
'input_ids': tokens,
'attention_mask': [1] * len(tokens),
})
return examples
五、长上下文推理优化
5.1 KV Cache 压缩
class CompressedKVCache:
"""压缩 KV Cache"""
def __init__(self, compression_ratio=4):
self.compression_ratio = compression_ratio
self.cache = {}
def compress(self, k, v, layer_idx):
"""压缩 KV Cache"""
batch, num_heads, seq_len, head_dim = k.shape
if seq_len <= 1024:
# 短序列不压缩
self.cache[layer_idx] = (k, v)
return k, v
# 保留最近的 token(full resolution)
recent_len = seq_len // self.compression_ratio
# 压缩旧的 token(average pooling)
old_len = seq_len - recent_len
compressed_old_len = old_len // self.compression_ratio
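# 注意:下面的 view 假设 old_len 恰好能被 compression_ratio 整除(示意实现,未处理余数)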
k_old = k[:, :, :old_len, :].view(
batch, num_heads, compressed_old_len, self.compression_ratio, head_dim
).mean(dim=3)
v_old = v[:, :, :old_len, :].view(
batch, num_heads, compressed_old_len, self.compression_ratio, head_dim
).mean(dim=3)
# 拼接
k_compressed = torch.cat([k_old, k[:, :, old_len:, :]], dim=2)
v_compressed = torch.cat([v_old, v[:, :, old_len:, :]], dim=2)
self.cache[layer_idx] = (k_compressed, v_compressed)
return k_compressed, v_compressed
5.2 Streaming LLM
保持固定长度的 KV Cache:
class StreamingLLM:
"""Streaming LLM:无限长度生成"""
def __init__(self, model, window_size=4096, sink_size=4):
self.model = model
self.window_size = window_size
self.sink_size = sink_size # 始终保留的开头 token
def generate(self, input_ids, max_new_tokens=1000):
"""流式生成"""
kv_cache = None
generated_ids = input_ids.clone()
for _ in range(max_new_tokens):
# 获取当前输入
if kv_cache is not None:
# 只输入最后一个 token
current_input = generated_ids[:, -1:]
else:
current_input = generated_ids
# 前向传播
outputs = self.model(
input_ids=current_input,
past_key_values=kv_cache,
use_cache=True,
)
kv_cache = outputs.past_key_values
# 压缩 KV Cache(保持固定长度)
kv_cache = self.compress_kv_cache(kv_cache)
# 采样下一个 token
next_token = self.sample(outputs.logits[:, -1, :])
generated_ids = torch.cat([generated_ids, next_token], dim=1)
if next_token.item() == self.model.config.eos_token_id:
break
return generated_ids
def compress_kv_cache(self, kv_cache):
"""压缩 KV Cache 到固定长度"""
compressed_cache = []
for layer_kv in kv_cache:
k, v = layer_kv
seq_len = k.shape[2]
if seq_len <= self.window_size:
compressed_cache.append((k, v))
continue
# 保留 sink tokens + 最近的 tokens
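# StreamingLLM 的关键观察:开头几个 token 充当 attention sink,丢弃它们会使生成质量急剧下降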
k_sink = k[:, :, :self.sink_size, :]
v_sink = v[:, :, :self.sink_size, :]
k_recent = k[:, :, -(self.window_size - self.sink_size):, :]
v_recent = v[:, :, -(self.window_size - self.sink_size):, :]
k_compressed = torch.cat([k_sink, k_recent], dim=2)
v_compressed = torch.cat([v_sink, v_recent], dim=2)
compressed_cache.append((k_compressed, v_compressed))
return compressed_cache
5.3 分块处理
def chunked_generation(model, tokenizer, long_prompt, chunk_size=4096):
"""分块处理长 prompt"""
tokens = tokenizer.encode(long_prompt)
total_len = len(tokens)
if total_len <= chunk_size:
# 短文本直接处理
return model.generate(torch.tensor([tokens]))
# 分块处理
kv_cache = None
for i in range(0, total_len, chunk_size):
chunk = tokens[i:i+chunk_size]
with torch.no_grad():
outputs = model(
input_ids=torch.tensor([chunk]),
past_key_values=kv_cache,
use_cache=True,
)
kv_cache = outputs.past_key_values
# 生成
generated = model.generate(
input_ids=torch.tensor([[tokens[-1]]]), # 最后一个 token
past_key_values=kv_cache,
max_new_tokens=500,
)
return generated
六、实战:使用长上下文模型
6.1 使用 Claude 3 (200K)
import anthropic
def use_claude_long_context():
"""使用 Claude 3 的长上下文能力"""
client = anthropic.Anthropic()
# 读取长文档
with open("long_document.txt", "r") as f:
long_text = f.read() # 假设是一本小说
# 构造 prompt
prompt = f"""以下是一本完整的小说:
{long_text}
请回答以下问题:
1. 这本书的主要人物有哪些?
2. 故事的主线是什么?
3. 结局如何?
"""
response = client.messages.create(
model="claude-3-opus-20240229",
max_tokens=4096,
messages=[
{"role": "user", "content": prompt}
],
)
print(response.content[0].text)
6.2 使用 Gemini 1.5 (1M)
import google.generativeai as genai
def use_gemini_long_context():
"""使用 Gemini 1.5 的超长上下文"""
genai.configure(api_key="your-api-key")
model = genai.GenerativeModel('gemini-1.5-pro')
# 可以处理多本书、多个视频
long_content = load_multiple_documents() # 加载大量内容
response = model.generate_content([
long_content,
"基于以上所有内容,进行综合分析..."
])
print(response.text)
6.3 本地部署长上下文模型
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def load_long_context_model():
"""加载支持长上下文的开源模型"""
# 使用支持长上下文的模型
model_name = "gradientai/Llama-3-8B-Instruct-262k" # 262K 上下文
# 或者 "Qwen/Qwen2-7B-Instruct" # 支持 32K
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
# 可能需要启用 Flash Attention
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
def process_long_document(model, tokenizer, document, question):
"""处理长文档"""
prompt = f"""Document:
{document}
Question: {question}
Answer:"""
inputs = tokenizer(prompt, return_tensors="pt", truncation=False)
print(f"Input length: {inputs.input_ids.shape[1]} tokens")
with torch.no_grad():
outputs = model.generate(
**inputs.to(model.device),
max_new_tokens=500,
temperature=0.7,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
七、总结
长上下文核心要点
mindmap
root((长上下文))
挑战
O(n²)复杂度
内存爆炸
位置编码外推
位置编码改进
RoPE
位置插值
YaRN
ALiBi
高效注意力
稀疏注意力
Flash Attention
Ring Attention
推理优化
KV Cache 压缩
Streaming LLM
分块处理
关键 Takeaway
- 长上下文的核心挑战是 O(n²) 复杂度:需要高效注意力机制
- 位置编码很关键:RoPE + 插值/YaRN 可以外推到更长
- Flash Attention 是基础:大幅降低内存使用
- 推理时需要 KV Cache 优化:Streaming LLM、压缩等
- 商业模型已支持超长上下文:Claude 200K、Gemini 1M
- 开源模型也在追赶:很多模型已支持 32K-128K
长上下文能力对比
| 模型 | 上下文长度 | 技术特点 |
|---|---|---|
| Claude 3 | 200K | 未公开 |
| Gemini 1.5 | 1M | Ring Attention? |
| GPT-4 Turbo | 128K | 未公开 |
| Llama 3 | 8K (原生) | RoPE |
| Qwen2 | 32K/128K | YaRN |
| Mistral | 32K | Sliding Window |
参考资料
- RoPE 论文 - Rotary Position Embedding
- ALiBi 论文 - Attention with Linear Biases
- YaRN 论文 - Efficient Context Window Extension
- Flash Attention 2
- Streaming LLM