
Long Context: Getting a Large Model to Read an Entire Book


Preface: The Evolution of Context Length

How the context length of large language models has grown:

| Model | Context Length | Roughly Equivalent To |
| --- | --- | --- |
| GPT-3 | 4K | ~3,000 Chinese characters |
| GPT-4 | 8K / 32K | ~24,000 characters |
| Claude 2 | 100K | ~75,000 characters |
| Claude 3 | 200K | an entire novel |
| Gemini 1.5 | 1M | several books |

Why does long context matter?

graph LR
    subgraph "Short context"
        A1[Long document] --> B1[Split into chunks]
        B1 --> C1[Process separately]
        C1 --> D1[Information loss]
    end
    subgraph "Long context"
        A2[Long document] --> B2[Feed in whole]
        B2 --> C2[Global understanding]
        C2 --> D2[Accurate answers]
    end
    style D1 fill:#ff6b6b
    style D2 fill:#4ecdc4

Use cases

  • 📚 Read an entire book and answer questions about it
  • 💻 Analyze a whole code repository
  • 📊 Work through long reports and papers
  • 🎬 Understand long videos

1. Why Is Long Context So Hard?

1.1 The Complexity of Self-Attention

Standard self-attention has O(n²) computational complexity:

# Self-attention computation
import math
import torch.nn.functional as F


def self_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d)
    d = Q.shape[-1]
    
    # Attention scores: O(n²)
    scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
    
    # Softmax over the key dimension
    attn = F.softmax(scores / math.sqrt(d), dim=-1)
    
    # Weighted sum of the values: O(n²)
    output = attn @ V
    
    return output

The problem

  • Sequence length 4K → a 4K × 4K attention matrix, ~16 million entries
  • Sequence length 100K → a 100K × 100K attention matrix, 10 billion entries!
graph LR
    subgraph "Complexity growth"
        A["4K tokens<br/>16M ops"] --> B["32K tokens<br/>1B ops"]
        B --> C["100K tokens<br/>10B ops"]
        C --> D["1M tokens<br/>1T ops"]
    end
    style D fill:#ff6b6b

1.2 The Memory Bottleneck

def calculate_attention_memory(seq_len, batch_size, num_heads, head_dim, dtype_bytes=2):
    """Estimate the memory needed to materialize the attention matrix."""
    
    # Attention matrix: (batch, num_heads, seq_len, seq_len)
    attn_matrix = batch_size * num_heads * seq_len * seq_len * dtype_bytes
    
    # Q, K, V: (batch, num_heads, seq_len, head_dim)
    qkv = 3 * batch_size * num_heads * seq_len * head_dim * dtype_bytes
    
    total_bytes = attn_matrix + qkv
    total_gb = total_bytes / (1024 ** 3)
    
    return total_gb


# Example
print(f"4K context: {calculate_attention_memory(4096, 1, 32, 128):.2f} GB")
print(f"32K context: {calculate_attention_memory(32768, 1, 32, 128):.2f} GB")
print(f"100K context: {calculate_attention_memory(100000, 1, 32, 128):.2f} GB")

# Output:
# 4K context: 1.09 GB
# 32K context: 64.75 GB
# 100K context: 598.34 GB  <- won't fit on a single GPU!

1.3 The Position-Encoding Extrapolation Problem

Traditional position encodings break down beyond the training length:

# Learned absolute position embedding: trained with a maximum length of 4K
position_embeddings = nn.Embedding(4096, hidden_size)

# At inference time we hit position 5000
pos = 5000  # out of range: raises an error or degrades badly
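
To make the failure mode concrete, here is a tiny sketch (the 768 hidden size is just an illustrative choice): a learned absolute embedding table simply has no row for a position it never saw in training.

import torch
import torch.nn as nn

# A table trained for positions 0..4095 only
position_embeddings = nn.Embedding(4096, 768)

print(position_embeddings(torch.arange(4096)).shape)  # torch.Size([4096, 768]) -> fine

try:
    position_embeddings(torch.tensor([5000]))          # position beyond the table
except IndexError as e:
    print("out of range:", e)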

2. Improving the Position Encoding

2.1 Types of Position Encoding

mindmap
  root((Position encodings))
    Absolute
      Sinusoidal
      Learned
    Relative
      ALiBi
      RoPE
      XPos
    No position encoding
      Some architectures need none
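
For reference, the classic sinusoidal encoding from the "Absolute" branch above can be written in a few lines; this is a minimal sketch, not tied to any particular model:

import torch

def sinusoidal_positions(seq_len, dim, base=10000.0):
    """Classic sinusoidal absolute position encoding."""
    pos = torch.arange(seq_len).float().unsqueeze(1)                     # (seq_len, 1)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # (dim/2,)
    angles = pos * inv_freq                                              # (seq_len, dim/2)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(angles)   # even dimensions get sine
    pe[:, 1::2] = torch.cos(angles)   # odd dimensions get cosine
    return pe

print(sinusoidal_positions(8, 16).shape)  # torch.Size([8, 16])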

2.2 RoPE: Rotary Position Embedding

RoPE (Rotary Position Embedding) is currently the most widely used position encoding:

import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """RoPE rotary position embedding"""
    
    def __init__(self, dim, max_seq_len=4096, base=10000):
        super().__init__()
        
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        
        # Per-dimension rotation frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        
        # Precompute the cos/sin cache
        self._set_cos_sin_cache(max_seq_len)
    
    def _set_cos_sin_cache(self, seq_len):
        """Precompute cos and sin for all positions up to seq_len"""
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, dim/2)
        
        # Duplicate and concatenate to cover the full head dimension
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())
    
    def forward(self, x, seq_len=None):
        """Return the cos/sin tables for the first seq_len positions"""
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len:
            self._set_cos_sin_cache(seq_len)
            self.max_seq_len = seq_len
        
        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len],
        )


def rotate_half(x):
    """Rotate the second half of the last dimension into the first half"""
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary position embedding to Q and K"""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
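
A minimal usage sketch tying the pieces above together (the shapes are illustrative):

import torch

rope = RotaryEmbedding(dim=64, max_seq_len=4096)

q = torch.randn(1, 8, 128, 64)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)

cos, sin = rope(q, seq_len=128)                      # each (128, 64), broadcast over batch and heads
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)                      # torch.Size([1, 8, 128, 64]) twice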

2.3 Position Interpolation

Make a model trained on short contexts work at longer ones:

class ScaledRotaryEmbedding(RotaryEmbedding):
    """RoPE with position interpolation"""
    
    def __init__(self, dim, max_seq_len=4096, base=10000, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_seq_len, base)
    
    def _set_cos_sin_cache(self, seq_len):
        # Key step: rescale the positions
        t = torch.arange(seq_len, device=self.inv_freq.device)
        t = t / self.scaling_factor  # position interpolation
        
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())


# Usage: extend a 4K model to 16K
# scaling_factor = 16K / 4K = 4
rope = ScaledRotaryEmbedding(dim=64, max_seq_len=16384, scaling_factor=4.0)
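
The effect is easy to check numerically: with a scaling factor of 4, every position a 16K-long input can produce is mapped back inside the 0..4095 range the model was trained on.

import torch

t = torch.arange(16384) / 4.0           # scaled positions for a 16K sequence
print(t.min().item(), t.max().item())   # 0.0 4095.75, never outside the trained 4K range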

2.4 NTK-Aware Interpolation

A better interpolation scheme: adjust the base frequency instead of the positions:

class NTKScaledRotaryEmbedding(RotaryEmbedding):
    """NTK-aware position interpolation"""
    
    def __init__(self, dim, max_seq_len=4096, base=10000, scaling_factor=1.0):
        # Adjust the base instead of the positions
        new_base = base * (scaling_factor ** (dim / (dim - 2)))
        super().__init__(dim, max_seq_len, new_base)


# Or use dynamic NTK scaling
class DynamicNTKRotaryEmbedding(RotaryEmbedding):
    """Dynamic NTK scaling"""
    
    def __init__(self, dim, max_seq_len=4096, base=10000, original_max_seq_len=4096):
        self.original_max_seq_len = original_max_seq_len
        super().__init__(dim, max_seq_len, base)
    
    def forward(self, x, seq_len=None):
        if seq_len is not None and seq_len > self.original_max_seq_len:
            # Dynamically adjust the base according to the current sequence length
            scaling_factor = seq_len / self.original_max_seq_len
            new_base = self.base * (scaling_factor ** (self.dim / (self.dim - 2)))
            
            # Recompute the frequencies and the cos/sin cache with the new base
            inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            self.inv_freq = inv_freq.to(self.inv_freq.device)
            self._set_cos_sin_cache(seq_len)
        
        return super().forward(x, seq_len)
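
A quick numeric check of what adjusting the base does (dim=128 and scaling_factor=4 are illustrative values): the highest-frequency dimension is left almost untouched while the lowest-frequency dimension is slowed down by exactly the scaling factor, which is the intuition behind NTK-aware interpolation.

import torch

dim, base, s = 128, 10000.0, 4.0
new_base = base * (s ** (dim / (dim - 2)))

idx = torch.arange(0, dim, 2).float()
inv_freq_old = 1.0 / (base ** (idx / dim))
inv_freq_new = 1.0 / (new_base ** (idx / dim))

ratio = inv_freq_old / inv_freq_new
print(ratio[0].item())    # 1.0  -> highest frequency unchanged (extrapolated)
print(ratio[-1].item())   # ~4.0 -> lowest frequency stretched by the full factor (interpolated)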

2.5 YaRN: Yet another RoPE extensioN

YaRN combines the strengths of position interpolation and NTK scaling:

class YaRNRotaryEmbedding(RotaryEmbedding):
    """YaRN: a better RoPE extension scheme"""
    
    def __init__(
        self,
        dim,
        max_seq_len=4096,
        base=10000,
        original_max_seq_len=4096,
        beta_fast=32,
        beta_slow=1,
    ):
        self.original_max_seq_len = original_max_seq_len
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        super().__init__(dim, max_seq_len, base)
    
    def _compute_yarn_scaling(self, seq_len):
        """Compute the per-dimension YaRN scaling factors"""
        scale = seq_len / self.original_max_seq_len
        
        # Interpolate the low-frequency dimensions, leave the high-frequency
        # dimensions untouched (extrapolation)
        dim_range = torch.arange(0, self.dim, 2).float()
        
        # Scaling at the two ends of the spectrum
        low_freq_factor = scale      # fully interpolated
        high_freq_factor = 1.0       # left as-is
        
        # Smooth transition between the two regimes (controlled by beta_fast / beta_slow)
        # ...
        
        return scaling_factors

2.6 ALiBi: Length Extrapolation Without Retraining

ALiBi (Attention with Linear Biases) adds a linear bias directly to the attention scores:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ALiBiAttention(nn.Module):
    """ALiBi attention"""
    
    def __init__(self, num_heads):
        super().__init__()
        
        # Per-head slopes
        # Slopes form a geometric sequence: 2^(-8/n), 2^(-8*2/n), ...
        slopes = self._get_slopes(num_heads)
        self.register_buffer("slopes", slopes)
    
    def _get_slopes(self, num_heads):
        """Compute the ALiBi slopes"""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio ** i) for i in range(n)]
        
        if math.log2(num_heads).is_integer():
            return torch.tensor(get_slopes_power_of_2(num_heads))
        else:
            # Handle head counts that are not a power of two
            closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            # ...
            return torch.tensor(slopes)
    
    def forward(self, q, k, v):
        batch, num_heads, seq_len, head_dim = q.shape
        
        # Standard attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        
        # ALiBi bias: penalize attention in proportion to the distance between positions
        positions = torch.arange(seq_len, device=q.device)
        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)
        alibi_bias = self.slopes.view(1, num_heads, 1, 1) * relative_positions.abs().unsqueeze(0).unsqueeze(0)
        
        # Apply the bias
        scores = scores - alibi_bias
        
        # Softmax and output
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        
        return output
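
A small usage sketch of the class above (8 heads hits the power-of-two branch, giving slopes 1/2, 1/4, ..., 1/256):

import torch

attn = ALiBiAttention(num_heads=8)
print(attn.slopes)                  # tensor([0.5000, 0.2500, ..., 0.0039])

q = torch.randn(1, 8, 16, 64)       # (batch, num_heads, seq_len, head_dim)
out = attn(q, q, q)
print(out.shape)                    # torch.Size([1, 8, 16, 64])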

3. Efficient Attention Mechanisms

3.1 Sparse Attention

Compute attention for only a subset of position pairs:

graph TB
    subgraph "Full attention"
        A1[Each position] --> B1[Attends to all positions]
        B1 --> C1["O(n²)"]
    end
    subgraph "Sparse attention"
        A2[Each position] --> B2[Attends to a subset of positions]
        B2 --> C2["O(n√n) or O(n)"]
    end

Common sparsity patterns

class SparseAttentionPatterns:
    """Sparse attention mask patterns"""
    
    @staticmethod
    def local_attention(seq_len, window_size):
        """Local attention: each token only attends to its neighbors"""
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        for i in range(seq_len):
            start = max(0, i - window_size)
            end = min(seq_len, i + window_size + 1)
            mask[i, start:end] = True
        return mask
    
    @staticmethod
    def strided_attention(seq_len, stride):
        """Strided attention: also attend to every stride-th token"""
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        for i in range(seq_len):
            # Local neighborhood
            mask[i, max(0, i-stride):i+1] = True
            # Global (every stride-th position)
            mask[i, ::stride] = True
        return mask
    
    @staticmethod
    def global_local_attention(seq_len, num_global_tokens, local_window):
        """Global + local attention"""
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        
        # Global tokens attend to, and are attended by, every position
        mask[:num_global_tokens, :] = True
        mask[:, :num_global_tokens] = True
        
        # All other tokens only attend locally
        for i in range(num_global_tokens, seq_len):
            start = max(num_global_tokens, i - local_window)
            end = min(seq_len, i + local_window + 1)
            mask[i, start:end] = True
        
        return mask
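
A rough density check (assuming the class above): with a 64-token window over 4,096 positions, the local mask touches only about 3% of the entries that full attention would compute.

import torch

seq_len, window = 4096, 64
mask = SparseAttentionPatterns.local_attention(seq_len, window)
print(f"density: {mask.sum().item() / (seq_len * seq_len):.3%}")   # ~3.1% of the full n x n matrix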

3.2 Longformer and BigBird

class LongformerAttention(nn.Module):
    """Longformer-style attention"""
    
    def __init__(self, hidden_size, num_heads, window_size=512, num_global_tokens=2):
        super().__init__()
        self.window_size = window_size
        self.num_global_tokens = num_global_tokens
        
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, hidden_states, global_attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape
        
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        
        # Local attention (sliding window)
        local_output = self.sliding_window_attention(q, k, v)
        
        # Global attention (for designated special tokens)
        if global_attention_mask is not None:
            global_output = self.global_attention(q, k, v, global_attention_mask)
            # Combine both paths
            output = local_output + global_output
        else:
            output = local_output
        
        return output
    
    def sliding_window_attention(self, q, k, v):
        """Sliding-window local attention"""
        # Use an efficient implementation (e.g. xformers)
        from xformers.ops import memory_efficient_attention
        
        # Build the local attention mask
        # ...
        
        return output

3.3 Flash Attention

Exploit the GPU memory hierarchy to optimize the attention computation:

# Using Flash Attention (covered in detail in Part 2 of this series)
import math

from flash_attn import flash_attn_func


def efficient_attention(q, k, v, causal=True):
    """Attention via Flash Attention"""
    # q, k, v: (batch, seq_len, num_heads, head_dim)
    
    output = flash_attn_func(
        q, k, v,
        causal=causal,
        softmax_scale=1.0 / math.sqrt(q.shape[-1]),
    )
    
    return output


# Flash Attention 2 also supports variable-length (packed) sequences
from flash_attn import flash_attn_varlen_func

def variable_length_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
    """Flash Attention over packed variable-length sequences"""
    output = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        causal=True,
    )
    return output
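
For the varlen API, the batch is packed into one long token dimension and described by cumulative sequence lengths. The sketch below shows how those arguments are typically built for two sequences of length 300 and 700 (this assumes a CUDA device and the flash-attn package; the head counts are illustrative):

import torch

lens = [300, 700]
num_heads, head_dim = 32, 128

# All tokens of the batch packed along one dimension: (total_tokens, num_heads, head_dim)
q = torch.randn(sum(lens), num_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Cumulative sequence lengths, int32, one entry per sequence boundary: [0, 300, 1000]
cu_seqlens = torch.tensor([0, lens[0], lens[0] + lens[1]], dtype=torch.int32, device="cuda")

out = variable_length_attention(q, k, v, cu_seqlens, cu_seqlens, max(lens), max(lens))
print(out.shape)   # (1000, 32, 128)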

3.4 Ring Attention

Process extremely long sequences across multiple devices:

graph LR
    subgraph "Ring Attention"
        GPU1["GPU 1<br/>Block 1"] --> GPU2["GPU 2<br/>Block 2"]
        GPU2 --> GPU3["GPU 3<br/>Block 3"]
        GPU3 --> GPU4["GPU 4<br/>Block 4"]
        GPU4 --> GPU1
    end

import torch
import torch.distributed as dist


class RingAttention:
    """Ring Attention: distributed long-sequence processing"""
    
    def __init__(self, world_size, rank):
        self.world_size = world_size
        self.rank = rank
    
    def forward(self, q, k, v):
        """
        Each GPU holds one block of the sequence; the full attention
        result is assembled through ring communication.
        """
        local_seq_len = q.shape[1]
        
        # Initialize the output and the softmax normalizer
        output = torch.zeros_like(q)
        lse = torch.full((q.shape[0], q.shape[1]), float('-inf'))  # log-sum-exp
        
        # K and V travel around the ring
        k_recv, v_recv = k, v
        
        for step in range(self.world_size):
            # Attention of the local queries against the current K/V block
            block_output, block_lse = self.compute_block_attention(q, k_recv, v_recv)
            
            # Merge into the running result (online softmax)
            output, lse = self.update_output(output, lse, block_output, block_lse)
            
            # Ring step: send K, V to the next GPU, receive from the previous one
            k_recv = self.ring_send_recv(k_recv)
            v_recv = self.ring_send_recv(v_recv)
        
        return output
    
    def ring_send_recv(self, tensor):
        """One ring communication step"""
        send_to = (self.rank + 1) % self.world_size
        recv_from = (self.rank - 1) % self.world_size
        
        recv_tensor = torch.empty_like(tensor)
        
        # Asynchronous send and receive
        send_op = dist.isend(tensor, send_to)
        recv_op = dist.irecv(recv_tensor, recv_from)
        
        send_op.wait()
        recv_op.wait()
        
        return recv_tensor
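
The `update_output` step referenced above is the standard online-softmax merge. A sketch of what it might look like (shapes assumed to be (batch, seq_len, dim) for the partial outputs and (batch, seq_len) for the log-sum-exp terms; not the exact Ring Attention code):

import torch

def merge_block_outputs(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention results, each normalized within its own block."""
    lse = torch.logaddexp(lse_a, lse_b)            # combined log normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)     # how much block a contributes
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)     # how much block b contributes
    return w_a * out_a + w_b * out_b, lse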

4. Long-Context Training Techniques

4.1 Progressive Length Training

def progressive_length_training(model, dataset, target_length=32768):
    """Progressively increase the training sequence length"""
    
    length_stages = [4096, 8192, 16384, 32768]
    
    for stage_length in length_stages:
        print(f"Training with context length: {stage_length}")
        
        # Keep only examples that fit the current stage
        stage_dataset = dataset.filter(lambda x: len(x['text']) <= stage_length)
        
        # Adjust the position encoding
        if hasattr(model, 'rope'):
            model.rope.max_seq_len = stage_length
        
        # Train this stage
        trainer = Trainer(
            model=model,
            train_dataset=stage_dataset,
            args=TrainingArguments(
                max_seq_length=stage_length,
                # Longer sequences usually need a smaller batch size
                per_device_train_batch_size=max(1, 32 // (stage_length // 4096)),
                gradient_accumulation_steps=stage_length // 4096,
            ),
        )
        
        trainer.train()

4.2 Fine-Tuning for Length Extrapolation

def length_extrapolation_finetune(
    model,
    original_length=4096,
    target_length=32768,
    num_steps=1000,
):
    """Fine-tune a model to extrapolate to a longer context"""
    
    # 1. Swap in a scaled position encoding
    scaling_factor = target_length / original_length
    
    # Use YaRN (or NTK) interpolation
    for layer in model.layers:
        layer.self_attn.rotary_emb = YaRNRotaryEmbedding(
            dim=layer.self_attn.head_dim,
            max_seq_len=target_length,
            original_max_seq_len=original_length,
        )
    
    # 2. Prepare long-text data
    long_dataset = prepare_long_context_data(target_length)
    
    # 3. Fine-tune (usually only a small number of steps is needed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    for step in range(num_steps):
        batch = next(iter(long_dataset))
        
        outputs = model(**batch)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")

4.3 Data Preparation

def prepare_long_context_data(max_length=32768):
    """Prepare long-context training data"""
    
    from datasets import load_dataset
    
    # Long-text datasets
    datasets_to_use = [
        "emozilla/pg19",           # books
        "togethercomputer/Long-Data-Collections",  # long-text collection
        "THUDM/LongBench",         # long-context benchmark
    ]
    
    all_texts = []
    
    for ds_name in datasets_to_use:
        ds = load_dataset(ds_name, split="train")
        
        for item in ds:
            text = item.get('text', item.get('content', ''))
            if len(text) >= max_length // 2:  # keep only sufficiently long texts (character-level heuristic)
                all_texts.append(text[:max_length])
    
    return all_texts


def create_long_context_examples(texts, tokenizer, max_length):
    """Build long-context training examples"""
    
    examples = []
    
    for text in texts:
        tokens = tokenizer.encode(text, truncation=True, max_length=max_length)
        
        if len(tokens) >= max_length * 0.8:  # at least 80% of the target length
            examples.append({
                'input_ids': tokens,
                'attention_mask': [1] * len(tokens),
            })
    
    return examples

5. Long-Context Inference Optimizations

5.1 KV Cache Compression

class CompressedKVCache:
    """Compress the KV cache"""
    
    def __init__(self, compression_ratio=4):
        self.compression_ratio = compression_ratio
        self.cache = {}
    
    def compress(self, k, v, layer_idx):
        """Compress the KV cache for one layer"""
        batch, num_heads, seq_len, head_dim = k.shape
        
        if seq_len <= 1024:
            # Short sequences are left uncompressed
            self.cache[layer_idx] = (k, v)
            return k, v
        
        # Keep the most recent tokens at full resolution
        recent_len = seq_len // self.compression_ratio
        
        # Compress older tokens with average pooling
        old_len = seq_len - recent_len
        compressed_old_len = old_len // self.compression_ratio
        
        k_old = k[:, :, :old_len, :].view(
            batch, num_heads, compressed_old_len, self.compression_ratio, head_dim
        ).mean(dim=3)
        
        v_old = v[:, :, :old_len, :].view(
            batch, num_heads, compressed_old_len, self.compression_ratio, head_dim
        ).mean(dim=3)
        
        # Concatenate the pooled prefix with the full-resolution suffix
        k_compressed = torch.cat([k_old, k[:, :, old_len:, :]], dim=2)
        v_compressed = torch.cat([v_old, v[:, :, old_len:, :]], dim=2)
        
        self.cache[layer_idx] = (k_compressed, v_compressed)
        
        return k_compressed, v_compressed
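
A quick sanity check of the class above (note the pooled prefix only reshapes cleanly when the old segment length is divisible by the compression ratio):

import torch

cache = CompressedKVCache(compression_ratio=4)
k = torch.randn(1, 8, 2048, 64)
v = torch.randn(1, 8, 2048, 64)

k_c, v_c = cache.compress(k, v, layer_idx=0)
print(k_c.shape)   # torch.Size([1, 8, 896, 64]): 384 pooled old slots + 512 recent tokens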

5.2 Streaming LLM

Keep the KV cache at a fixed length:

class StreamingLLM:
    """StreamingLLM: generation of effectively unbounded length"""
    
    def __init__(self, model, window_size=4096, sink_size=4):
        self.model = model
        self.window_size = window_size
        self.sink_size = sink_size  # "attention sink" tokens at the start that are always kept
    
    def generate(self, input_ids, max_new_tokens=1000):
        """Streaming generation"""
        
        kv_cache = None
        generated_ids = input_ids.clone()
        
        for _ in range(max_new_tokens):
            # Pick the current input
            if kv_cache is not None:
                # Only feed the most recent token
                current_input = generated_ids[:, -1:]
            else:
                current_input = generated_ids
            
            # Forward pass
            outputs = self.model(
                input_ids=current_input,
                past_key_values=kv_cache,
                use_cache=True,
            )
            
            kv_cache = outputs.past_key_values
            
            # Truncate the KV cache back to a fixed length
            kv_cache = self.compress_kv_cache(kv_cache)
            
            # Sample the next token
            next_token = self.sample(outputs.logits[:, -1, :])
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            
            if next_token.item() == self.model.config.eos_token_id:
                break
        
        return generated_ids
    
    def compress_kv_cache(self, kv_cache):
        """Truncate the KV cache to a fixed window"""
        
        compressed_cache = []
        
        for layer_kv in kv_cache:
            k, v = layer_kv
            seq_len = k.shape[2]
            
            if seq_len <= self.window_size:
                compressed_cache.append((k, v))
                continue
            
            # Keep the sink tokens plus the most recent tokens
            k_sink = k[:, :, :self.sink_size, :]
            v_sink = v[:, :, :self.sink_size, :]
            
            k_recent = k[:, :, -(self.window_size - self.sink_size):, :]
            v_recent = v[:, :, -(self.window_size - self.sink_size):, :]
            
            k_compressed = torch.cat([k_sink, k_recent], dim=2)
            v_compressed = torch.cat([v_sink, v_recent], dim=2)
            
            compressed_cache.append((k_compressed, v_compressed))
        
        return compressed_cache
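
The `sample` helper is not shown above; a minimal temperature plus top-k sketch that would fit the call site (returning a (batch, 1) tensor of token ids) could look like this:

import torch
import torch.nn.functional as F

def sample(logits, temperature=0.7, top_k=50):
    """Sample one token id per batch element from the last-step logits."""
    logits = logits / temperature
    top_logits, top_idx = torch.topk(logits, top_k, dim=-1)
    probs = F.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)     # index into the top-k set
    return torch.gather(top_idx, -1, choice)             # map back to vocabulary ids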

5.3 Chunked Processing

def chunked_generation(model, tokenizer, long_prompt, chunk_size=4096):
    """Process a long prompt chunk by chunk"""
    
    tokens = tokenizer.encode(long_prompt)
    total_len = len(tokens)
    
    if total_len <= chunk_size:
        # Short prompts can be handled directly
        return model.generate(torch.tensor([tokens]))
    
    # Prefill all but the last token, chunk by chunk, building up the KV cache
    kv_cache = None
    
    for i in range(0, total_len - 1, chunk_size):
        chunk = tokens[i:min(i + chunk_size, total_len - 1)]
        
        with torch.no_grad():
            outputs = model(
                input_ids=torch.tensor([chunk]),
                past_key_values=kv_cache,
                use_cache=True,
            )
            kv_cache = outputs.past_key_values
    
    # Generate, feeding the last prompt token together with the prefilled cache
    generated = model.generate(
        input_ids=torch.tensor([[tokens[-1]]]),  # the last prompt token
        past_key_values=kv_cache,
        max_new_tokens=500,
    )
    
    return generated

6. Hands-On: Using Long-Context Models

6.1 Using Claude 3 (200K)

import anthropic


def use_claude_long_context():
    """Use Claude 3's long-context capability"""
    
    client = anthropic.Anthropic()
    
    # Read a long document
    with open("long_document.txt", "r") as f:
        long_text = f.read()  # assume this is an entire novel
    
    # Build the prompt
    prompt = f"""Below is a complete novel:

{long_text}

Please answer the following questions:
1. Who are the main characters?
2. What is the main storyline?
3. How does it end?
"""
    
    response = client.messages.create(
        model="claude-3-opus-20240229",
        max_tokens=4096,
        messages=[
            {"role": "user", "content": prompt}
        ],
    )
    
    print(response.content[0].text)

6.2 Using Gemini 1.5 (1M)

import google.generativeai as genai


def use_gemini_long_context():
    """Use Gemini 1.5's ultra-long context"""
    
    genai.configure(api_key="your-api-key")
    model = genai.GenerativeModel('gemini-1.5-pro')
    
    # Can handle several books or videos at once
    long_content = load_multiple_documents()  # load a large amount of content
    
    response = model.generate_content([
        long_content,
        "Based on all of the content above, give a comprehensive analysis..."
    ])
    
    print(response.text)

6.3 Deploying a Long-Context Model Locally

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def load_long_context_model():
    """Load an open-source model that supports long contexts"""
    
    # A model with an extended context window
    model_name = "gradientai/Llama-3-8B-Instruct-262k"  # 262K context
    # or "Qwen/Qwen2-7B-Instruct"  # supports 32K
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        # Flash Attention is usually needed at these lengths
        attn_implementation="flash_attention_2",
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    return model, tokenizer


def process_long_document(model, tokenizer, document, question):
    """Answer a question about a long document"""
    
    prompt = f"""Document:
{document}

Question: {question}

Answer:"""
    
    inputs = tokenizer(prompt, return_tensors="pt", truncation=False)
    
    print(f"Input length: {inputs.input_ids.shape[1]} tokens")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs.to(model.device),
            max_new_tokens=500,
            temperature=0.7,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

7. Summary

The Essentials of Long Context

mindmap
  root((Long context))
    Challenges
      O(n²) complexity
      Memory blow-up
      Position-encoding extrapolation
    Better position encodings
      RoPE
      Position interpolation
      YaRN
      ALiBi
    Efficient attention
      Sparse attention
      Flash Attention
      Ring Attention
    Inference optimizations
      KV cache compression
      StreamingLLM
      Chunked processing

Key Takeaways

  1. The core challenge of long context is the O(n²) cost: efficient attention mechanisms are essential
  2. Position encoding is critical: RoPE plus interpolation/YaRN can extrapolate to much longer contexts
  3. Flash Attention is the foundation: it drastically reduces memory usage
  4. Inference needs KV cache optimizations: StreamingLLM, compression, and the like
  5. Commercial models already support very long contexts: Claude 200K, Gemini 1M
  6. Open-source models are catching up: many already support 32K-128K

Long-Context Capability Comparison

| Model | Context Length | Technique |
| --- | --- | --- |
| Claude 3 | 200K | not disclosed |
| Gemini 1.5 | 1M | Ring Attention? |
| GPT-4 Turbo | 128K | not disclosed |
| Llama 3 | 8K (native) | RoPE |
| Qwen2 | 32K / 128K | YaRN |
| Mistral | 32K | Sliding Window |

References

  1. RoPE paper - Rotary Position Embedding
  2. ALiBi paper - Attention with Linear Biases
  3. YaRN paper - Efficient Context Window Extension
  4. Flash Attention 2
  5. StreamingLLM
