搜 索

SFT:让大模型学会听话

  • 6阅读
  • 2025年03月15日
  • 0评论
首页 / AI/大数据 / 正文

前言:从"续写机器"到"AI 助手"

预训练模型很强大,但有个问题:它只会续写,不会对话

用户: 请帮我写一首关于春天的诗
GPT (预训练版): 请帮我写一首关于春天的诗,这是很多人在春天来临时的想法。
              春天是一年四季中最美的季节...
              [继续滔滔不绝地介绍春天]

用户: 请帮我写一首关于春天的诗
GPT (SFT版): 春风拂面柳丝长,
            桃花灼灼映斜阳。
            燕子归来寻旧巢,
            一江春水向东流。

看出区别了吗?

  • 预训练模型:把你的问题当成文章开头,继续往下写
  • SFT 模型:理解你在"提问",然后"回答"

SFT(Supervised Fine-Tuning,监督微调) 就是教会模型这个区别的过程。

graph LR A[预训练模型] --> B[SFT 微调] B --> C[指令遵循模型] subgraph 能力变化 D[续写文本] --> E[理解指令] E --> F[生成回答] end style A fill:#ff6b6b style C fill:#4ecdc4

一、为什么预训练模型不够用?

1.1 预训练模型的"怪癖"

预训练模型在海量文本上学习"预测下一个词",它学到了:

  • 语法和语义
  • 世界知识
  • 一定的推理能力

但它没有学到

  • 什么是"问题",什么是"回答"
  • 用户的意图是什么
  • 如何组织有帮助的回复

1.2 看看预训练模型的"奇葩"回答

# 模拟预训练模型的行为
prompt = "What is the capital of France?"

# 预训练模型可能的输出(续写模式)
output_1 = "What is the capital of France? This is a question often asked in geography classes..."
output_2 = "What is the capital of France?\nA) Paris\nB) London\nC) Berlin\nD) Madrid"
output_3 = "What is the capital of France? The capital of France is a topic that..."

# 我们期望的输出
expected = "The capital of France is Paris."

预训练模型不知道你想要一个简洁的答案,它只是在续写看起来合理的文本。

1.3 InstructGPT 的启示

2022 年,OpenAI 发表了 InstructGPT 论文,提出了三步训练法:

graph TB subgraph InstructGPT三步法 Step1[Step 1: SFT
监督微调] Step2[Step 2: RM
训练奖励模型] Step3[Step 3: RLHF
强化学习优化] end Step1 --> Step2 --> Step3 Base[预训练模型] --> Step1 Step3 --> Final[ChatGPT] style Final fill:#4ecdc4

这篇文章聚焦 Step 1: SFT——用人工标注的指令数据,教模型学会"对话"。


二、指令数据:SFT 的燃料

2.1 什么是指令数据?

指令数据通常是 (指令, 输入, 输出) 三元组:

{
  "instruction": "将以下句子翻译成英文",
  "input": "今天天气真好",
  "output": "The weather is really nice today."
}

或者更简单的 (指令, 输出) 二元组:

{
  "instruction": "写一首关于春天的诗",
  "output": "春风拂面柳丝长,\n桃花灼灼映斜阳。\n燕子归来寻旧巢,\n一江春水向东流。"
}

多轮对话格式:

{
  "conversations": [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好!有什么我可以帮助你的吗?"},
    {"role": "user", "content": "帮我写一段 Python 代码,计算斐波那契数列"},
    {"role": "assistant", "content": "好的,这是一个计算斐波那契数列的 Python 函数:\n\n```python\ndef fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)\n```\n\n你也可以使用迭代方式来提高效率..."}
  ]
}

2.2 指令数据的来源

mindmap root((指令数据来源)) 人工标注 专业标注团队 众包平台 成本高质量好 Self-Instruct 用 GPT-4 生成 Alpaca 方法 成本低规模大 开源数据集 ShareGPT FLAN Collection Dolly 用户反馈 真实对话日志 thumbs up/down

2.3 Self-Instruct:用大模型生成训练数据

Stanford 的 Alpaca 项目用 GPT-3.5 生成了 52K 指令数据,成本只有 $500:

# Self-Instruct 的核心思路
def generate_instruction_data(seed_tasks, num_generate=52000):
    """用大模型生成指令数据"""
    
    generated_data = []
    
    for i in range(num_generate):
        # 从种子任务中随机选几个作为示例
        examples = random.sample(seed_tasks, k=3)
        
        # 构造 prompt
        prompt = f"""
你是一个任务生成器。请根据以下示例,生成一个新的任务。

示例 1:
指令: {examples[0]['instruction']}
输入: {examples[0].get('input', '')}
输出: {examples[0]['output']}

示例 2:
指令: {examples[1]['instruction']}
输入: {examples[1].get('input', '')}
输出: {examples[1]['output']}

示例 3:
指令: {examples[2]['instruction']}
输入: {examples[2].get('input', '')}
输出: {examples[2]['output']}

现在请生成一个新的、不同的任务:
指令:"""
        
        # 调用 GPT-4 生成
        response = call_gpt4(prompt)
        
        # 解析响应
        new_task = parse_response(response)
        generated_data.append(new_task)
    
    return generated_data


# 种子任务示例(人工编写的高质量示例)
seed_tasks = [
    {
        "instruction": "将以下句子改写成更正式的表达",
        "input": "这玩意儿太棒了!",
        "output": "这个产品的质量非常出色。"
    },
    {
        "instruction": "解释什么是机器学习",
        "input": "",
        "output": "机器学习是人工智能的一个分支,它使计算机能够从数据中学习模式,而无需明确编程..."
    },
    # ... 更多种子任务
]

2.4 开源指令数据集

数据集规模语言特点
Alpaca52K英文GPT-3.5 生成,经典入门
ShareGPT90K+多语言真实用户对话
FLAN Collection1.8M+英文Google 出品,任务多样
Dolly15K英文人工标注,高质量
BELLE200K+中文中文指令数据
Firefly1.1M中文多任务中文数据
OpenAssistant160K+多语言众包标注,多轮对话

2.5 数据质量 vs 数量

一个重要发现:高质量的小数据集往往比低质量的大数据集效果更好

graph LR subgraph LIMA论文发现 A[1000条高质量数据] --> B[效果优秀] C[100000条低质量数据] --> D[效果一般] end style A fill:#4ecdc4 style B fill:#4ecdc4

LIMA 论文的核心观点:

"预训练已经学到了几乎所有的知识和能力,SFT 只是在教模型用正确的'格式'输出。"

所以,与其堆数据量,不如精心设计高质量的指令数据。


三、Chat Template:对话的"格式"

3.1 为什么需要 Chat Template?

模型只认识 token 序列,不认识"用户"和"助手"的概念。我们需要一个格式来表示对话结构。

原始对话:
User: 你好
Assistant: 你好!有什么可以帮你的?
User: 今天天气怎么样?

转换成 token 序列:
<|im_start|>user
你好<|im_end|>
<|im_start|>assistant
你好!有什么可以帮你的?<|im_end|>
<|im_start|>user
今天天气怎么样?<|im_end|>
<|im_start|>assistant

3.2 常见的 Chat Template

1. ChatML 格式(OpenAI / Qwen)

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant
Hi there! How can I help you today?<|im_end|>

2. Llama 2 格式

<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

Hello! [/INST] Hi there! How can I help you today? </s><s>[INST] What's the weather? [/INST]

3. Alpaca 格式

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Hello!

### Response:
Hi there! How can I help you today?

4. Vicuna 格式

A chat between a curious user and an artificial intelligence assistant.

USER: Hello!
ASSISTANT: Hi there! How can I help you today?
USER: What's the weather?
ASSISTANT:

3.3 使用 HuggingFace 的 Chat Template

from transformers import AutoTokenizer

# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

# 对话消息
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help you?"},
    {"role": "user", "content": "What is machine learning?"},
]

# 应用 chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True  # 添加助手的开始标记,用于生成
)

print(text)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
# Hi! How can I help you?<|im_end|>
# <|im_start|>user
# What is machine learning?<|im_end|>
# <|im_start|>assistant

# Tokenize
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
)

3.4 自定义 Chat Template

# 设置自定义 chat template(Jinja2 格式)
custom_template = """{% for message in messages %}
{% if message['role'] == 'system' %}
<|system|>{{ message['content'] }}</s>
{% elif message['role'] == 'user' %}
<|user|>{{ message['content'] }}</s>
{% elif message['role'] == 'assistant' %}
<|assistant|>{{ message['content'] }}</s>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
<|assistant|>
{% endif %}"""

tokenizer.chat_template = custom_template

# 使用
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

四、SFT 训练原理

4.1 训练目标

SFT 的训练目标和预训练一样——预测下一个 token

但有一个关键区别:只在"回答"部分计算损失

输入: <user>今天天气怎么样?</user><assistant>今天天气很好,阳光明媚。</assistant>

损失计算:
[用户部分: 不计算损失] [助手部分: 计算损失]

为什么?因为我们想让模型学习"如何回答",而不是学习"如何提问"。

def compute_sft_loss(model, input_ids, labels):
    """
    SFT 损失计算
    labels 中,用户部分被标记为 -100(忽略)
    """
    # 前向传播
    outputs = model(input_ids)
    logits = outputs.logits
    
    # Shift: 预测下一个 token
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    
    # 计算损失(-100 会被自动忽略)
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100
    )
    
    return loss

4.2 数据预处理:构造 Labels

def preprocess_sft_data(tokenizer, messages, max_length=2048):
    """
    预处理 SFT 数据
    关键:只在 assistant 回复部分计算损失
    """
    input_ids = []
    labels = []
    
    for message in messages:
        role = message['role']
        content = message['content']
        
        # 添加角色标记
        if role == 'user':
            role_tokens = tokenizer.encode("<|user|>", add_special_tokens=False)
        else:
            role_tokens = tokenizer.encode("<|assistant|>", add_special_tokens=False)
        
        # 编码内容
        content_tokens = tokenizer.encode(content, add_special_tokens=False)
        end_tokens = tokenizer.encode("</s>", add_special_tokens=False)
        
        # 完整的消息 tokens
        message_tokens = role_tokens + content_tokens + end_tokens
        
        # 构造 labels
        if role == 'user':
            # 用户部分:全部标记为 -100(不计算损失)
            message_labels = [-100] * len(message_tokens)
        else:
            # 助手部分:计算损失
            # 但角色标记部分也不计算
            message_labels = [-100] * len(role_tokens) + content_tokens + end_tokens
        
        input_ids.extend(message_tokens)
        labels.extend(message_labels)
    
    # 截断
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    
    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': [1] * len(input_ids)
    }


# 示例
messages = [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好!有什么可以帮你的?"},
    {"role": "user", "content": "1+1等于几?"},
    {"role": "assistant", "content": "1+1等于2。"},
]

result = preprocess_sft_data(tokenizer, messages)

# 可视化哪些部分计算损失
for i, (token_id, label) in enumerate(zip(result['input_ids'], result['labels'])):
    token = tokenizer.decode([token_id])
    loss_marker = "✓" if label != -100 else "✗"
    print(f"{i:3d}: {token:10s} | label={label:6d} | 计算损失: {loss_marker}")

4.3 完整的数据处理流程

flowchart TB A[原始指令数据] --> B[统一格式] B --> C[应用 Chat Template] C --> D[Tokenize] D --> E[构造 Labels] E --> F[Padding/截断] F --> G[构建 DataLoader] subgraph Labels构造 E1[用户部分 → -100] E2[助手部分 → token_id] end

五、实战:使用 Transformers 进行 SFT

5.1 完整训练代码

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, TaskType
import json
from tqdm import tqdm


# ========== 数据集定义 ==========

class SFTDataset(Dataset):
    """SFT 数据集"""
    
    def __init__(self, data_path, tokenizer, max_length=2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # 加载数据
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        
        print(f"加载了 {len(self.data)} 条训练数据")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # 支持两种格式
        if 'conversations' in item:
            messages = item['conversations']
        else:
            # Alpaca 格式转换
            messages = []
            if item.get('instruction'):
                content = item['instruction']
                if item.get('input'):
                    content += f"\n\n{item['input']}"
                messages.append({"role": "user", "content": content})
            if item.get('output'):
                messages.append({"role": "assistant", "content": item['output']})
        
        return self.process_messages(messages)
    
    def process_messages(self, messages):
        """处理对话消息"""
        # 使用 chat template
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        
        # Tokenize
        tokenized = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors=None,
        )
        
        input_ids = tokenized['input_ids']
        
        # 构造 labels:找到 assistant 回复的位置
        labels = self.create_labels(messages, input_ids)
        
        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': tokenized['attention_mask'],
        }
    
    def create_labels(self, messages, input_ids):
        """创建 labels,只在 assistant 回复处计算损失"""
        labels = [-100] * len(input_ids)
        
        # 找到每个 assistant 回复的 token 范围
        current_pos = 0
        full_text = self.tokenizer.decode(input_ids)
        
        for msg in messages:
            if msg['role'] == 'assistant':
                # 找到这条回复在 full_text 中的位置
                content = msg['content']
                start_idx = full_text.find(content, current_pos)
                if start_idx != -1:
                    # 转换为 token 位置
                    prefix = full_text[:start_idx]
                    prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
                    content_tokens = self.tokenizer.encode(content, add_special_tokens=False)
                    
                    start_token = len(prefix_tokens)
                    end_token = start_token + len(content_tokens)
                    
                    # 设置 labels
                    for i in range(start_token, min(end_token, len(labels))):
                        labels[i] = input_ids[i]
                    
                    current_pos = start_idx + len(content)
        
        return labels


# ========== 训练配置 ==========

def train_sft(
    model_name: str,
    data_path: str,
    output_dir: str,
    use_lora: bool = True,
    num_epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 2e-5,
    max_length: int = 2048,
):
    """SFT 训练主函数"""
    
    print(f"加载模型: {model_name}")
    
    # 加载 tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        padding_side='right',
    )
    
    # 确保有 pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # 加载模型
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    
    # 使用 LoRA
    if use_lora:
        print("应用 LoRA...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=64,                     # LoRA 秩
            lora_alpha=16,            # 缩放因子
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", 
                          "gate_proj", "up_proj", "down_proj"],
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    
    # 加载数据集
    print(f"加载数据集: {data_path}")
    train_dataset = SFTDataset(data_path, tokenizer, max_length)
    
    # 数据整理器
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )
    
    # 训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_steps=500,
        save_total_limit=3,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",
    )
    
    # 创建 Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    
    # 开始训练
    print("开始训练...")
    trainer.train()
    
    # 保存模型
    print(f"保存模型到: {output_dir}")
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)
    
    print("训练完成!")


# ========== 使用示例 ==========

if __name__ == "__main__":
    train_sft(
        model_name="Qwen/Qwen-7B",
        data_path="sft_data.json",
        output_dir="./sft_model",
        use_lora=True,
        num_epochs=3,
        batch_size=4,
        learning_rate=2e-5,
    )

5.2 使用 TRL 库(更简洁)

from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig

def train_with_trl():
    """使用 TRL 的 SFTTrainer"""
    
    # 加载模型和 tokenizer
    model_name = "meta-llama/Llama-2-7b-hf"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # 加载数据集
    dataset = load_dataset("tatsu-lab/alpaca", split="train")
    
    # 格式化函数
    def formatting_func(example):
        text = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{example['instruction']}

### Input:
{example['input']}

### Response:
{example['output']}"""
        return text
    
    # LoRA 配置
    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    
    # 训练参数
    training_args = TrainingArguments(
        output_dir="./sft_output",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        bf16=True,
        logging_steps=10,
        save_steps=500,
    )
    
    # 创建 SFTTrainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        formatting_func=formatting_func,
        peft_config=peft_config,
        max_seq_length=2048,
    )
    
    # 训练
    trainer.train()
    
    # 保存
    trainer.save_model()


if __name__ == "__main__":
    train_with_trl()

5.3 准备指令数据

import json

def create_sft_dataset():
    """创建示例 SFT 数据集"""
    
    data = [
        # 简单问答
        {
            "conversations": [
                {"role": "user", "content": "什么是机器学习?"},
                {"role": "assistant", "content": "机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习和改进,而无需进行明确的编程。通过分析大量数据,机器学习算法可以识别模式并做出决策或预测。"}
            ]
        },
        # 多轮对话
        {
            "conversations": [
                {"role": "user", "content": "你好"},
                {"role": "assistant", "content": "你好!我是一个AI助手,有什么可以帮助你的吗?"},
                {"role": "user", "content": "帮我写一段 Python 代码计算阶乘"},
                {"role": "assistant", "content": "好的,这是一个计算阶乘的 Python 函数:\n\n```python\ndef factorial(n):\n    if n <= 1:\n        return 1\n    return n * factorial(n - 1)\n\n# 测试\nprint(factorial(5))  # 输出: 120\n```\n\n这个函数使用递归方式计算阶乘。你也可以使用循环实现:\n\n```python\ndef factorial_iterative(n):\n    result = 1\n    for i in range(2, n + 1):\n        result *= i\n    return result\n```"}
            ]
        },
        # Alpaca 格式
        {
            "instruction": "将以下文本翻译成英文",
            "input": "人工智能正在改变世界",
            "output": "Artificial intelligence is changing the world."
        },
        {
            "instruction": "写一首关于编程的俳句",
            "input": "",
            "output": "代码如流水\n调试中日月交替\n终见绿灯亮"
        },
        # 更多数据...
    ]
    
    # 保存
    with open('sft_data.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    
    print(f"创建了 {len(data)} 条训练数据")
    return data


# 转换其他格式的数据
def convert_alpaca_to_conversations(alpaca_data):
    """将 Alpaca 格式转换为 conversations 格式"""
    converted = []
    
    for item in alpaca_data:
        conversations = []
        
        # 构建用户消息
        user_content = item['instruction']
        if item.get('input'):
            user_content += f"\n\n{item['input']}"
        
        conversations.append({
            "role": "user",
            "content": user_content
        })
        
        # 助手回复
        conversations.append({
            "role": "assistant", 
            "content": item['output']
        })
        
        converted.append({"conversations": conversations})
    
    return converted


if __name__ == "__main__":
    create_sft_dataset()

六、SFT 的高级技巧

6.1 数据混合策略

不同类型的任务应该按比例混合:

def create_mixed_dataset(datasets_config):
    """
    混合不同来源的数据集
    
    datasets_config = {
        "general_chat": {"path": "chat.json", "ratio": 0.3},
        "coding": {"path": "code.json", "ratio": 0.25},
        "math": {"path": "math.json", "ratio": 0.15},
        "writing": {"path": "writing.json", "ratio": 0.15},
        "knowledge": {"path": "knowledge.json", "ratio": 0.15},
    }
    """
    all_data = []
    
    for name, config in datasets_config.items():
        with open(config['path'], 'r') as f:
            data = json.load(f)
        
        # 按比例采样
        target_size = int(len(data) * config['ratio'])
        sampled = random.sample(data, min(target_size, len(data)))
        
        # 添加来源标记
        for item in sampled:
            item['source'] = name
        
        all_data.extend(sampled)
        print(f"{name}: {len(sampled)} 条")
    
    random.shuffle(all_data)
    return all_data

6.2 动态 Padding

避免浪费计算在 padding token 上:

from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase

@dataclass
class DynamicPaddingCollator:
    """动态 padding 到 batch 内最大长度"""
    tokenizer: PreTrainedTokenizerBase
    padding: bool = True
    max_length: int = None
    
    def __call__(self, features):
        # 找到 batch 内最大长度
        max_len = max(len(f['input_ids']) for f in features)
        
        if self.max_length:
            max_len = min(max_len, self.max_length)
        
        batch = {
            'input_ids': [],
            'attention_mask': [],
            'labels': [],
        }
        
        for f in features:
            # Padding
            padding_length = max_len - len(f['input_ids'])
            
            input_ids = f['input_ids'] + [self.tokenizer.pad_token_id] * padding_length
            attention_mask = f['attention_mask'] + [0] * padding_length
            labels = f['labels'] + [-100] * padding_length
            
            batch['input_ids'].append(input_ids[:max_len])
            batch['attention_mask'].append(attention_mask[:max_len])
            batch['labels'].append(labels[:max_len])
        
        # 转换为 tensor
        return {k: torch.tensor(v) for k, v in batch.items()}

6.3 NEFTune:训练时加噪声

NEFTune 是一个简单但有效的技巧:在 embedding 层加入噪声。

def neftune_forward(self, input_ids):
    """NEFTune: 在训练时给 embedding 加噪声"""
    embeddings = self.original_forward(input_ids)
    
    if self.training:
        # 添加均匀分布噪声
        noise = torch.rand_like(embeddings) * 2 - 1  # [-1, 1]
        noise = noise * self.neftune_alpha / (embeddings.size(1) ** 0.5)
        embeddings = embeddings + noise
    
    return embeddings


def apply_neftune(model, alpha=5.0):
    """应用 NEFTune"""
    embed_layer = model.get_input_embeddings()
    embed_layer.original_forward = embed_layer.forward
    embed_layer.neftune_alpha = alpha
    embed_layer.forward = lambda x: neftune_forward(embed_layer, x)
    return model

6.4 Packing:提高训练效率

把多个短样本打包成一个长序列,减少 padding:

def pack_sequences(examples, max_length=2048, sep_token_id=None):
    """
    将多个短序列打包成长序列
    
    原本:
    [seq1] [pad] [pad] [pad]
    [seq2] [pad] [pad]
    [seq3] [pad] [pad] [pad] [pad]
    
    打包后:
    [seq1] [sep] [seq2] [sep] [seq3] [pad]
    """
    packed_input_ids = []
    packed_labels = []
    
    current_input_ids = []
    current_labels = []
    
    for ex in examples:
        input_ids = ex['input_ids']
        labels = ex['labels']
        
        # 如果加上当前样本会超长,先保存当前打包结果
        if len(current_input_ids) + len(input_ids) + 1 > max_length:
            if current_input_ids:
                packed_input_ids.append(current_input_ids)
                packed_labels.append(current_labels)
            current_input_ids = []
            current_labels = []
        
        # 添加分隔符
        if current_input_ids and sep_token_id:
            current_input_ids.append(sep_token_id)
            current_labels.append(-100)
        
        current_input_ids.extend(input_ids)
        current_labels.extend(labels)
    
    # 保存最后一个
    if current_input_ids:
        packed_input_ids.append(current_input_ids)
        packed_labels.append(current_labels)
    
    return packed_input_ids, packed_labels

七、SFT 的局限性

7.1 SFT 只是"模仿"

SFT 让模型学会了输出"看起来像"人类标注的回答,但它不理解什么是"好"回答

问题: 1+1等于几?

可能的回答 A: 1+1等于2。
可能的回答 B: 1+1等于3。
可能的回答 C: 这是一个很好的数学问题,让我们来分析一下...

SFT 模型可能会输出 B 或 C,因为它只学会了"格式",
不知道 A 才是"正确"的回答。

7.2 无法学习偏好

有些回答没有明确的"对错",只有"好坏":

问题: 写一首诗

回答 A: 春眠不觉晓,处处闻啼鸟。
回答 B: 床前明月光,疑是地上霜。

哪个更好?这取决于用户的偏好。
SFT 无法学习这种偏好。

7.3 幻觉问题

SFT 模型可能会"一本正经地胡说八道":

用户: 阿姆斯特朗是什么时候登月的?

SFT 模型: 尼尔·阿姆斯特朗于 1969 年 7 月 21 日成为第一个登上月球的人。
         (可能还会编造更多细节,即使它并不确定)

7.4 引出 RLHF

这些问题,需要 RLHF(Reinforcement Learning from Human Feedback) 来解决:

graph TB subgraph SFT的局限 A[只学格式] B[不懂好坏] C[容易幻觉] end subgraph RLHF解决 D[学习人类偏好] E[奖励好的回答] F[惩罚坏的回答] end A --> D B --> E C --> F style D fill:#4ecdc4 style E fill:#4ecdc4 style F fill:#4ecdc4

八、总结

SFT 核心流程

flowchart TB A[预训练模型] --> B[准备指令数据] B --> C[应用 Chat Template] C --> D[构造 Labels
只在回复处计算损失] D --> E[训练] E --> F[SFT 模型] F --> G[能够理解指令] F --> H[能够对话] F --> I[但仍有局限...] style A fill:#ff6b6b style F fill:#4ecdc4

关键 Takeaway

  1. SFT 教模型"对话格式":从续写转变为问答
  2. 只在回复部分计算损失:让模型学习"如何回答"
  3. Chat Template 很重要:统一的格式才能正确训练
  4. 数据质量 > 数量:LIMA 论文证明 1000 条高质量数据就够
  5. SFT 的局限:只会模仿,不懂好坏,需要 RLHF 进一步优化

实践建议

场景建议
快速上手使用 TRL 库的 SFTTrainer
资源有限使用 LoRA/QLoRA
中文模型选择 Qwen、ChatGLM 等
数据准备优先保证质量,多样性
训练稳定使用梯度裁剪、warmup

下一步学习

  • [ ] RLHF/DPO:人类偏好对齐
  • [ ] LoRA:参数高效微调
  • [ ] 评估:如何衡量 SFT 效果

参考资料

  1. InstructGPT Paper - Training language models to follow instructions
  2. LIMA Paper - Less Is More for Alignment
  3. Alpaca - Stanford Alpaca
  4. TRL Library - HuggingFace TRL
  5. NEFTune Paper - Noisy Embeddings Improve Instruction Finetuning

评论区
暂无评论
avatar