搜 索

预训练:如何喂出一个大模型

  • 4阅读
  • 2025年03月08日
  • 0评论
首页 / AI/大数据 / 正文

前言:从"白纸"到"博学"

一个刚初始化的神经网络,参数是随机的,什么都不会。

经过预训练后,它能:

  • 写诗、写代码、写论文
  • 回答问题、做数学题
  • 理解上下文、遵循指令

这个从"白纸"到"博学"的过程,就是 预训练(Pre-training)

graph LR A[随机初始化的模型] --> B[海量文本数据] B --> C[预训练] C --> D[预训练模型] D --> E[具备语言理解能力] D --> F[具备知识储备] D --> G[具备推理能力] style A fill:#ff6b6b style D fill:#4ecdc4

本文我们来搞懂:

  • 预训练的目标是什么?
  • 需要什么样的数据?
  • 训练过程是怎样的?
  • Scaling Law 是怎么回事?
  • 为什么预测下一个词就能产生"智能"?

一、预训练的本质:压缩即智能

1.1 预训练任务回顾

上一篇我们讲过,GPT 的预训练任务是 预测下一个 Token

$$ \mathcal{L} = -\sum_{t=1}^{T} \log P(x_t | x_1, x_2, ..., x_{t-1}) $$

用人话说:给定前面的词,预测下一个词。预测对了,损失就小;预测错了,损失就大。

# 预训练的核心:Next Token Prediction
def compute_loss(model, input_ids):
    """
    input_ids: [今天, 天气, 真, 好]
    
    模型需要预测:
    - 给定 [今天],预测 [天气]
    - 给定 [今天, 天气],预测 [真]
    - 给定 [今天, 天气, 真],预测 [好]
    """
    # 输入是 [:-1],目标是 [1:]
    inputs = input_ids[:, :-1]   # [今天, 天气, 真]
    targets = input_ids[:, 1:]   # [天气, 真, 好]
    
    logits = model(inputs)       # 模型预测
    loss = cross_entropy(logits, targets)
    
    return loss

1.2 为什么预测下一个词能产生"智能"?

这个问题困扰了很多人:不就是预测下一个词吗?怎么就"智能"了?

Ilya Sutskever(OpenAI 联合创始人)的观点

"预测下一个词,本质上是在做数据压缩。要完美预测下一个词,你需要理解这段文本在说什么,需要具备相关的世界知识,需要进行逻辑推理。"

举个例子:

文本: "中国的首都是___"

要正确预测"北京",模型需要:
1. 理解"首都"的概念
2. 知道中国的地理知识
3. 理解问句的结构

再看一个更复杂的:

文本: "小明有 3 个苹果,小红给了他 2 个,现在小明有___"

要正确预测"5",模型需要:
1. 理解数学加法
2. 跟踪实体状态(小明的苹果数)
3. 理解"给"这个动作的含义

压缩即智能假说

如果一个模型能完美预测任意文本的下一个词,它必须:

  • 理解语法和语义
  • 具备世界知识
  • 能进行推理
  • 理解因果关系

这些能力加在一起,就是我们所说的"智能"。

graph TB subgraph 预测下一个词需要的能力 A[语法理解] --> P[准确预测] B[语义理解] --> P C[世界知识] --> P D[逻辑推理] --> P E[常识推断] --> P F[上下文记忆] --> P end P --> I[这些能力的总和 ≈ 智能] style I fill:#4ecdc4

1.3 预训练 vs 从头训练

在预训练范式之前,NLP 是这样做的:

graph TB subgraph 传统方式 T1[任务1数据] --> M1[模型1] T2[任务2数据] --> M2[模型2] T3[任务3数据] --> M3[模型3] end subgraph 预训练范式 D[海量无标注数据] --> PT[预训练] PT --> Base[基础模型] Base --> FT1[微调→任务1] Base --> FT2[微调→任务2] Base --> FT3[微调→任务3] end style Base fill:#4ecdc4

预训练的优势:

  1. 利用无标注数据:标注数据贵,无标注数据免费
  2. 知识迁移:预训练学到的知识可以迁移到各种任务
  3. 数据效率:下游任务只需要少量数据微调

二、预训练数据:大模型的"粮食"

2.1 数据决定上限

有一句话在 AI 圈很流行:

"数据决定了模型的上限,算法只是逼近这个上限。"

预训练数据的质量和多样性,直接决定了模型的能力边界。

2.2 主要数据来源

pie showData title 典型预训练数据配比(参考 LLaMA) "网页数据 (Common Crawl)" : 67 "代码 (GitHub)" : 4.5 "Wikipedia" : 4.5 "书籍" : 4.5 "学术论文 (ArXiv)" : 2.5 "问答数据 (StackExchange)" : 2 "其他" : 15

各类数据的特点

数据源规模质量特点
Common Crawl超大(PB级)参差不齐覆盖广,需要大量清洗
Wikipedia中等结构化知识,事实准确
书籍中等长文本,逻辑连贯
GitHub 代码中高编程能力的关键
学术论文中等专业知识,推理能力
社交媒体超大口语化,噪音多

2.3 数据处理流程

原始数据不能直接用,需要经过一系列处理:

flowchart TB A[原始数据] --> B[语言识别] B --> C[质量过滤] C --> D[去重] D --> E[敏感信息过滤] E --> F[格式标准化] F --> G[Tokenization] G --> H[打包成训练样本] subgraph 质量过滤 C1[长度过滤] C2[困惑度过滤] C3[规则过滤] C4[分类器过滤] end subgraph 去重 D1[精确去重] D2[模糊去重 MinHash] D3[跨文档去重] end

2.4 数据质量过滤

def quality_filter(text):
    """数据质量过滤示例"""
    
    # 1. 长度过滤
    if len(text) < 100:
        return False, "too_short"
    if len(text) > 100000:
        return False, "too_long"
    
    # 2. 语言检测
    if not is_target_language(text, target="en"):
        return False, "wrong_language"
    
    # 3. 特殊字符比例
    special_ratio = count_special_chars(text) / len(text)
    if special_ratio > 0.3:
        return False, "too_many_special_chars"
    
    # 4. 重复内容检测
    if has_excessive_repetition(text):
        return False, "repetitive"
    
    # 5. 广告/垃圾内容检测
    if is_spam_or_ad(text):
        return False, "spam"
    
    # 6. 困惑度过滤(用小模型打分)
    perplexity = compute_perplexity(text, reference_model)
    if perplexity > 1000:
        return False, "high_perplexity"
    
    return True, "passed"


def has_excessive_repetition(text, threshold=0.3):
    """检测过度重复"""
    lines = text.split('\n')
    unique_lines = set(lines)
    
    if len(unique_lines) / len(lines) < threshold:
        return True
    
    # 检测 n-gram 重复
    words = text.split()
    ngrams = [tuple(words[i:i+5]) for i in range(len(words)-4)]
    unique_ngrams = set(ngrams)
    
    if len(ngrams) > 0 and len(unique_ngrams) / len(ngrams) < threshold:
        return True
    
    return False

2.5 数据去重

去重非常重要——重复数据会导致:

  • 模型"记住"而非"理解"
  • 测试集污染
  • 训练效率下降
from datasketch import MinHash, MinHashLSH

def deduplicate_minhash(documents, threshold=0.8, num_perm=128):
    """
    使用 MinHash LSH 进行模糊去重
    """
    # 创建 LSH 索引
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    
    minhashes = {}
    duplicates = set()
    
    for doc_id, doc in enumerate(documents):
        # 创建 MinHash
        mh = MinHash(num_perm=num_perm)
        
        # 使用 5-gram 作为特征
        words = doc.split()
        for i in range(len(words) - 4):
            ngram = ' '.join(words[i:i+5])
            mh.update(ngram.encode('utf-8'))
        
        # 查询相似文档
        similar = lsh.query(mh)
        
        if similar:
            duplicates.add(doc_id)
        else:
            lsh.insert(doc_id, mh)
            minhashes[doc_id] = mh
    
    # 返回去重后的文档
    return [doc for i, doc in enumerate(documents) if i not in duplicates]


# 使用示例
documents = [
    "This is the first document about machine learning.",
    "This is the first document about machine learning.",  # 完全重复
    "This is the first document about deep learning.",     # 相似
    "A completely different document about cooking.",      # 不同
]

unique_docs = deduplicate_minhash(documents, threshold=0.8)
print(f"去重前: {len(documents)}, 去重后: {len(unique_docs)}")

2.6 数据配比的艺术

不同类型数据的配比会显著影响模型能力:

# LLaMA 的数据配比(参考)
data_mixture = {
    "CommonCrawl": 0.67,      # 通用网页
    "C4": 0.15,               # 清洗后的网页
    "GitHub": 0.045,          # 代码
    "Wikipedia": 0.045,       # 百科
    "Books": 0.045,           # 书籍
    "ArXiv": 0.025,           # 论文
    "StackExchange": 0.02,    # 问答
}

# 代码数据的重要性
# - 提升逻辑推理能力
# - 提升指令遵循能力
# - 结构化思维

# 高质量数据的重要性
# - Wikipedia 虽然只占 4.5%,但对知识准确性影响很大
# - 书籍数据提升长文本理解能力

三、Scaling Law:大力出奇迹的科学

3.1 什么是 Scaling Law?

2020 年,OpenAI 发表了著名的 Scaling Law 论文,发现了一个惊人的规律:

模型性能可以用参数量、数据量、计算量精确预测!

$$ L(N, D, C) = \left(\frac{N_c}{N}\right)^{\alpha_N} + \left(\frac{D_c}{D}\right)^{\alpha_D} + L_\infty $$

其中:

  • $L$ 是测试损失(越低越好)
  • $N$ 是模型参数量
  • $D$ 是训练数据量
  • $C$ 是计算量(FLOPs)
  • $\alpha_N, \alpha_D$ 是幂指数(约 0.076 和 0.095)
  • $L_\infty$ 是不可降低的损失下限

3.2 Scaling Law 的含义

graph TB subgraph Scaling Law A[参数量 ↑] --> D[损失 ↓] B[数据量 ↑] --> D C[计算量 ↑] --> D end D --> E[能力 ↑] E --> F[更强的语言理解] E --> G[更多的知识] E --> H[更好的推理]

关键发现

  1. 性能提升是平滑的:10 倍参数 → 可预测的性能提升
  2. 没有明显的"天花板":目前还没看到尽头
  3. 三者可以互换:更多参数 or 更多数据 or 更多计算

3.3 Chinchilla Scaling Law

2022 年,DeepMind 的 Chinchilla 论文修正了 Scaling Law:

核心发现:之前的模型"太大了",数据"太少了"。

最优配比:

$$ N_{opt} \propto C^{0.5}, \quad D_{opt} \propto C^{0.5} $$

这意味着:参数量和数据量应该同比例增长

模型参数量训练 Token参数/Token
GPT-3175B300B0.58
Chinchilla70B1.4T0.05
LLaMA65B1.4T0.046
LLaMA 270B2T0.035

Chinchilla 的启示

  • 用更少的参数、更多的数据,可以达到同样的效果
  • 推理成本更低(参数少)
  • 训练时间可能更长(数据多)

3.4 计算量估算

预训练需要多少算力?

FLOPs 估算公式

$$ C \approx 6 \times N \times D $$

其中 $N$ 是参数量,$D$ 是训练 token 数。

def estimate_training_flops(params_billions, tokens_billions):
    """估算训练所需的 FLOPs"""
    N = params_billions * 1e9
    D = tokens_billions * 1e9
    
    flops = 6 * N * D
    
    return flops


def estimate_training_cost(flops, gpu_type="A100"):
    """估算训练成本"""
    # GPU 性能(FLOPs/秒)和成本
    gpu_specs = {
        "A100": {"tflops": 312e12, "cost_per_hour": 3.0},
        "H100": {"tflops": 1000e12, "cost_per_hour": 5.0},
    }
    
    spec = gpu_specs[gpu_type]
    
    # 假设 40% 利用率(实际可能更低)
    effective_tflops = spec["tflops"] * 0.4
    
    # 训练时间(秒)
    training_seconds = flops / effective_tflops
    training_hours = training_seconds / 3600
    
    # GPU 数量假设(1000 张)
    num_gpus = 1000
    wall_clock_hours = training_hours / num_gpus
    
    # 成本
    total_cost = num_gpus * wall_clock_hours * spec["cost_per_hour"]
    
    return {
        "total_flops": flops,
        "gpu_hours": training_hours,
        "wall_clock_hours": wall_clock_hours,
        "wall_clock_days": wall_clock_hours / 24,
        "estimated_cost": total_cost,
        "num_gpus": num_gpus,
    }


# 估算 LLaMA-70B 的训练成本
flops = estimate_training_flops(params_billions=70, tokens_billions=2000)
cost = estimate_training_cost(flops, gpu_type="A100")

print(f"训练 LLaMA-70B (2T tokens):")
print(f"  总 FLOPs: {cost['total_flops']:.2e}")
print(f"  GPU 小时: {cost['gpu_hours']:,.0f}")
print(f"  使用 {cost['num_gpus']} 张 A100:")
print(f"    - 训练天数: {cost['wall_clock_days']:.1f} 天")
print(f"    - 预估成本: ${cost['estimated_cost']:,.0f}")

输出:

训练 LLaMA-70B (2T tokens):
  总 FLOPs: 8.40e+23
  GPU 小时: 6,730,769
  使用 1000 张 A100:
    - 训练天数: 280.4 天
    - 预估成本: $20,192,308

注意:实际成本可能更高(考虑通信开销、失败重试等)。

3.5 涌现能力

当模型足够大时,会出现一些"涌现能力"——小模型完全不会,大模型突然就会了。

graph TB subgraph 涌现能力示例 A[In-Context Learning
从例子中学习] B[Chain-of-Thought
思维链推理] C[Instruction Following
指令遵循] D[Code Generation
代码生成] end Small[小模型
< 10B] -.->|不具备| A Large[大模型
> 100B] -->|突然具备| A style Large fill:#4ecdc4

涌现能力的特点:

  • 非线性:不是逐渐变好,而是突然出现
  • 不可预测:无法提前知道会涌现什么能力
  • 规模依赖:只在足够大的模型上出现

四、预训练的工程实践

4.1 训练框架选择

框架特点适用场景
DeepSpeed微软出品,ZeRO 优化通用大模型训练
Megatron-LMNVIDIA 出品,极致优化超大规模(100B+)
FSDPPyTorch 原生中小规模,易用
Colossal-AI国产,一站式快速上手

4.2 并行策略

训练大模型需要多种并行策略组合:

graph TB subgraph 并行策略 DP[数据并行
Data Parallel] TP[张量并行
Tensor Parallel] PP[流水线并行
Pipeline Parallel] end DP --> 3D[3D 并行] TP --> 3D PP --> 3D 3D --> Train[大模型训练]

简单解释

  • 数据并行:每张卡有完整模型,数据不同
  • 张量并行:把矩阵切开,每张卡算一部分
  • 流水线并行:把层切开,不同卡负责不同层

4.3 混合精度训练

使用 FP16/BF16 代替 FP32,可以:

  • 显存减半
  • 训练加速 2-3x
  • 几乎不损失精度
import torch
from torch.cuda.amp import autocast, GradScaler

def train_with_mixed_precision(model, dataloader, optimizer):
    """混合精度训练"""
    scaler = GradScaler()
    
    for batch in dataloader:
        optimizer.zero_grad()
        
        # 自动混合精度
        with autocast():
            loss = model(batch)
        
        # 缩放梯度,防止下溢
        scaler.scale(loss).backward()
        
        # 更新参数
        scaler.step(optimizer)
        scaler.update()

4.4 训练稳定性

大模型训练最怕的就是 训练崩溃——几天的训练白费。

常见问题和解决方案

# 1. 梯度裁剪:防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 2. 学习率预热:开始时用小学习率
def get_lr(step, warmup_steps=2000, max_lr=3e-4):
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    else:
        # Cosine decay
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

# 3. Loss Spike 处理
def handle_loss_spike(loss, loss_history, threshold=10.0):
    if len(loss_history) > 0:
        avg_loss = sum(loss_history[-100:]) / len(loss_history[-100:])
        if loss > avg_loss * threshold:
            print(f"Loss spike detected: {loss:.4f} vs avg {avg_loss:.4f}")
            # 可以选择:跳过这个 batch、降低学习率、回滚到上一个 checkpoint
            return True
    return False

# 4. Checkpoint 策略
def save_checkpoint(model, optimizer, step, loss):
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    
    # 保存多个 checkpoint,以防最新的损坏
    torch.save(checkpoint, f'checkpoint_step_{step}.pt')
    
    # 只保留最近 5 个
    cleanup_old_checkpoints(keep=5)

4.5 学习率调度

import math

class CosineAnnealingWithWarmup:
    """带预热的余弦退火学习率"""
    
    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-5, max_lr=3e-4):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.current_step = 0
    
    def step(self):
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def get_lr(self):
        if self.current_step < self.warmup_steps:
            # 线性预热
            return self.max_lr * self.current_step / self.warmup_steps
        else:
            # 余弦退火
            progress = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress))


# 可视化学习率曲线
import matplotlib.pyplot as plt

def plot_lr_schedule(warmup_steps=2000, total_steps=100000):
    scheduler = CosineAnnealingWithWarmup(
        optimizer=None,  # 仅用于可视化
        warmup_steps=warmup_steps,
        total_steps=total_steps
    )
    
    lrs = []
    for _ in range(total_steps):
        lrs.append(scheduler.get_lr())
        scheduler.current_step += 1
    
    plt.figure(figsize=(10, 4))
    plt.plot(lrs)
    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Cosine Annealing with Warmup')
    plt.axvline(x=warmup_steps, color='r', linestyle='--', label='Warmup End')
    plt.legend()
    plt.savefig('lr_schedule.png')

五、实战:从零预训练一个小模型

让我们从头开始预训练一个小型语言模型。

5.1 完整训练代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
from tqdm import tqdm


# ========== 模型定义 ==========

class GPTConfig:
    """模型配置"""
    vocab_size: int = 50257
    block_size: int = 1024    # 最大序列长度
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        
        # 因果掩码
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )
    
    def forward(self, x):
        B, T, C = x.size()
        
        # 一次性计算 Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        
        # 分头
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        
        # 注意力计算
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
    
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
    
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # 权重共享
        self.transformer.wte.weight = self.lm_head.weight
        
        # 初始化
        self.apply(self._init_weights)
        
        # 参数量统计
        n_params = sum(p.numel() for p in self.parameters())
        print(f"模型参数量: {n_params/1e6:.2f}M")
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size
        
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        
        for block in self.transformer.h:
            x = block(x)
        
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
        
        return logits, loss


# ========== 数据集 ==========

class TextDataset(Dataset):
    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size
    
    def __len__(self):
        return len(self.data) - self.block_size
    
    def __getitem__(self, idx):
        chunk = self.data[idx:idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


# ========== 训练循环 ==========

def train(model, train_loader, optimizer, scheduler, config, device):
    model.train()
    total_loss = 0
    
    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, (x, y) in enumerate(pbar):
        x, y = x.to(device), y.to(device)
        
        # 前向传播
        logits, loss = model(x, y)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # 更新参数
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        
        # 更新进度条
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'lr': f'{scheduler.get_lr():.2e}'
        })
    
    return total_loss / len(train_loader)


@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """自回归生成"""
    model.eval()
    
    for _ in range(max_new_tokens):
        # 截断到 block_size
        idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]
        
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        
        # Top-K 采样
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    
    return idx


# ========== 主函数 ==========

def main():
    # 配置
    config = GPTConfig()
    config.vocab_size = 50257
    config.block_size = 256    # 小一点,便于训练
    config.n_layer = 6         # 6 层
    config.n_head = 6
    config.n_embd = 384
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    # 加载数据(这里用随机数据演示)
    # 实际应该用真实文本数据
    print("准备数据...")
    data = torch.randint(0, config.vocab_size, (1000000,)).tolist()
    
    dataset = TextDataset(data, config.block_size)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    
    # 创建模型
    print("创建模型...")
    model = GPT(config).to(device)
    
    # 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
    
    # 学习率调度
    total_steps = len(train_loader) * 10  # 10 个 epoch
    scheduler = CosineAnnealingWithWarmup(
        optimizer, 
        warmup_steps=1000,
        total_steps=total_steps
    )
    
    # 训练
    print("开始训练...")
    for epoch in range(10):
        loss = train(model, train_loader, optimizer, scheduler, config, device)
        print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
        
        # 生成示例
        prompt = torch.tensor([[0]], dtype=torch.long, device=device)  # 起始 token
        generated = generate(model, prompt, max_new_tokens=50, temperature=0.8)
        print(f"Generated: {generated[0].tolist()[:20]}...")
        
        # 保存 checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, f'checkpoint_epoch_{epoch}.pt')


if __name__ == "__main__":
    main()

5.2 使用真实数据训练

import tiktoken
from datasets import load_dataset

def prepare_real_data():
    """准备真实训练数据"""
    
    # 加载数据集
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
    
    # 使用 GPT-2 tokenizer
    enc = tiktoken.get_encoding("gpt2")
    
    # Tokenize 所有文本
    all_tokens = []
    for item in tqdm(dataset, desc="Tokenizing"):
        text = item['text']
        if text.strip():  # 跳过空行
            tokens = enc.encode(text)
            all_tokens.extend(tokens)
    
    print(f"总 token 数: {len(all_tokens):,}")
    
    return all_tokens


# 使用 HuggingFace datasets 的流式处理(大数据集)
def prepare_streaming_data():
    """流式处理大数据集"""
    from datasets import load_dataset
    import tiktoken
    
    enc = tiktoken.get_encoding("gpt2")
    
    # 流式加载,不会一次性加载到内存
    dataset = load_dataset("c4", "en", split="train", streaming=True)
    
    def tokenize_function(examples):
        tokens = enc.encode(examples['text'])
        return {'input_ids': tokens}
    
    tokenized = dataset.map(tokenize_function)
    
    return tokenized

六、预训练的替代方案

6.1 从开源模型继续训练

不一定要从头开始!可以从开源模型继续训练:

from transformers import AutoModelForCausalLM, AutoTokenizer

def continue_pretraining(base_model_name, train_data, output_dir):
    """从开源模型继续预训练"""
    
    # 加载预训练模型
    model = AutoModelForCausalLM.from_pretrained(base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    
    # 准备数据
    # ...
    
    # 继续训练
    # 通常使用更小的学习率
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    
    # 训练循环
    # ...
    
    # 保存
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

适用场景

  • 适配特定领域(医疗、法律、金融)
  • 适配特定语言
  • 注入新知识

6.2 知识蒸馏

用大模型"教"小模型:

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """知识蒸馏损失"""
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_predictions = F.log_softmax(student_logits / temperature, dim=-1)
    
    return F.kl_div(soft_predictions, soft_targets, reduction='batchmean') * (temperature ** 2)


def train_with_distillation(student, teacher, dataloader, optimizer, alpha=0.5):
    """带蒸馏的训练"""
    student.train()
    teacher.eval()
    
    for batch in dataloader:
        inputs, targets = batch
        
        # 学生模型前向
        student_logits, ce_loss = student(inputs, targets)
        
        # 教师模型前向
        with torch.no_grad():
            teacher_logits, _ = teacher(inputs)
        
        # 蒸馏损失
        distill_loss = distillation_loss(student_logits, teacher_logits)
        
        # 总损失 = CE 损失 + 蒸馏损失
        loss = alpha * ce_loss + (1 - alpha) * distill_loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

七、总结

预训练核心要点

mindmap root((预训练)) 目标 预测下一个Token 压缩即智能 数据 网页/书籍/代码 质量>数量 去重很重要 规模 Scaling Law 参数/数据/算力 涌现能力 工程 并行策略 混合精度 训练稳定性

关键 Takeaway

  1. 预训练本质是数据压缩:预测下一个词需要理解语言、知识、推理
  2. 数据决定上限:高质量、多样化的数据是关键
  3. Scaling Law 提供了路线图:更大的模型 + 更多的数据 = 更好的效果
  4. Chinchilla 法则:参数和数据应该同比例增长
  5. 工程挑战巨大:需要大量 GPU、复杂的并行策略、稳定的训练流程
  6. 不一定要从头开始:可以从开源模型继续训练或蒸馏

预训练成本参考

模型参数量训练数据估计成本
GPT-21.5B40GB~$50K
GPT-3175B570GB~$5M
LLaMA-65B65B1.4T tokens~$2M
GPT-4~1.8T??~$100M?

下一步学习

  • [ ] SFT:让模型学会听话
  • [ ] RLHF/DPO:人类偏好对齐
  • [ ] 分布式训练深入

参考资料

  1. Scaling Laws for Neural Language Models - OpenAI Scaling Law 论文
  2. Training Compute-Optimal Large Language Models - Chinchilla 论文
  3. LLaMA: Open and Efficient Foundation Language Models - LLaMA 论文
  4. The Pile: An 800GB Dataset of Diverse Text - 数据集构建
  5. nanoGPT - Karpathy 的教学实现

评论区
暂无评论
avatar