前言:从"续写机器"到"AI 助手"
预训练模型很强大,但有个问题:它只会续写,不会对话。
用户: 请帮我写一首关于春天的诗
GPT (预训练版): 请帮我写一首关于春天的诗,这是很多人在春天来临时的想法。
春天是一年四季中最美的季节...
[继续滔滔不绝地介绍春天]
用户: 请帮我写一首关于春天的诗
GPT (SFT版): 春风拂面柳丝长,
桃花灼灼映斜阳。
燕子归来寻旧巢,
一江春水向东流。看出区别了吗?
- 预训练模型:把你的问题当成文章开头,继续往下写
- SFT 模型:理解你在"提问",然后"回答"
SFT(Supervised Fine-Tuning,监督微调) 就是教会模型这个区别的过程。
一、为什么预训练模型不够用?
1.1 预训练模型的"怪癖"
预训练模型在海量文本上学习"预测下一个词",它学到了:
- 语法和语义
- 世界知识
- 一定的推理能力
但它没有学到:
- 什么是"问题",什么是"回答"
- 用户的意图是什么
- 如何组织有帮助的回复
1.2 看看预训练模型的"奇葩"回答
# 模拟预训练模型的行为
prompt = "What is the capital of France?"
# 预训练模型可能的输出(续写模式)
output_1 = "What is the capital of France? This is a question often asked in geography classes..."
output_2 = "What is the capital of France?\nA) Paris\nB) London\nC) Berlin\nD) Madrid"
output_3 = "What is the capital of France? The capital of France is a topic that..."
# 我们期望的输出
expected = "The capital of France is Paris."预训练模型不知道你想要一个简洁的答案,它只是在续写看起来合理的文本。
1.3 InstructGPT 的启示
2022 年,OpenAI 发表了 InstructGPT 论文,提出了三步训练法:
监督微调] Step2[Step 2: RM
训练奖励模型] Step3[Step 3: RLHF
强化学习优化] end Step1 --> Step2 --> Step3 Base[预训练模型] --> Step1 Step3 --> Final[ChatGPT] style Final fill:#4ecdc4
这篇文章聚焦 Step 1: SFT——用人工标注的指令数据,教模型学会"对话"。
二、指令数据:SFT 的燃料
2.1 什么是指令数据?
指令数据通常是 (指令, 输入, 输出) 三元组:
{
"instruction": "将以下句子翻译成英文",
"input": "今天天气真好",
"output": "The weather is really nice today."
}或者更简单的 (指令, 输出) 二元组:
{
"instruction": "写一首关于春天的诗",
"output": "春风拂面柳丝长,\n桃花灼灼映斜阳。\n燕子归来寻旧巢,\n一江春水向东流。"
}多轮对话格式:
{
"conversations": [
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!有什么我可以帮助你的吗?"},
{"role": "user", "content": "帮我写一段 Python 代码,计算斐波那契数列"},
{"role": "assistant", "content": "好的,这是一个计算斐波那契数列的 Python 函数:\n\n```python\ndef fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n```\n\n你也可以使用迭代方式来提高效率..."}
]
}2.2 指令数据的来源
2.3 Self-Instruct:用大模型生成训练数据
Stanford 的 Alpaca 项目用 GPT-3.5 生成了 52K 指令数据,成本只有 $500:
# Self-Instruct 的核心思路
def generate_instruction_data(seed_tasks, num_generate=52000):
"""用大模型生成指令数据"""
generated_data = []
for i in range(num_generate):
# 从种子任务中随机选几个作为示例
examples = random.sample(seed_tasks, k=3)
# 构造 prompt
prompt = f"""
你是一个任务生成器。请根据以下示例,生成一个新的任务。
示例 1:
指令: {examples[0]['instruction']}
输入: {examples[0].get('input', '')}
输出: {examples[0]['output']}
示例 2:
指令: {examples[1]['instruction']}
输入: {examples[1].get('input', '')}
输出: {examples[1]['output']}
示例 3:
指令: {examples[2]['instruction']}
输入: {examples[2].get('input', '')}
输出: {examples[2]['output']}
现在请生成一个新的、不同的任务:
指令:"""
# 调用 GPT-4 生成
response = call_gpt4(prompt)
# 解析响应
new_task = parse_response(response)
generated_data.append(new_task)
return generated_data
# 种子任务示例(人工编写的高质量示例)
seed_tasks = [
{
"instruction": "将以下句子改写成更正式的表达",
"input": "这玩意儿太棒了!",
"output": "这个产品的质量非常出色。"
},
{
"instruction": "解释什么是机器学习",
"input": "",
"output": "机器学习是人工智能的一个分支,它使计算机能够从数据中学习模式,而无需明确编程..."
},
# ... 更多种子任务
]2.4 开源指令数据集
| 数据集 | 规模 | 语言 | 特点 |
|---|---|---|---|
| Alpaca | 52K | 英文 | GPT-3.5 生成,经典入门 |
| ShareGPT | 90K+ | 多语言 | 真实用户对话 |
| FLAN Collection | 1.8M+ | 英文 | Google 出品,任务多样 |
| Dolly | 15K | 英文 | 人工标注,高质量 |
| BELLE | 200K+ | 中文 | 中文指令数据 |
| Firefly | 1.1M | 中文 | 多任务中文数据 |
| OpenAssistant | 160K+ | 多语言 | 众包标注,多轮对话 |
2.5 数据质量 vs 数量
一个重要发现:高质量的小数据集往往比低质量的大数据集效果更好。
LIMA 论文的核心观点:
"预训练已经学到了几乎所有的知识和能力,SFT 只是在教模型用正确的'格式'输出。"
所以,与其堆数据量,不如精心设计高质量的指令数据。
三、Chat Template:对话的"格式"
3.1 为什么需要 Chat Template?
模型只认识 token 序列,不认识"用户"和"助手"的概念。我们需要一个格式来表示对话结构。
原始对话:
User: 你好
Assistant: 你好!有什么可以帮你的?
User: 今天天气怎么样?
转换成 token 序列:
<|im_start|>user
你好<|im_end|>
<|im_start|>assistant
你好!有什么可以帮你的?<|im_end|>
<|im_start|>user
今天天气怎么样?<|im_end|>
<|im_start|>assistant3.2 常见的 Chat Template
1. ChatML 格式(OpenAI / Qwen)
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant
Hi there! How can I help you today?<|im_end|>2. Llama 2 格式
<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
Hello! [/INST] Hi there! How can I help you today? </s><s>[INST] What's the weather? [/INST]3. Alpaca 格式
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Hello!
### Response:
Hi there! How can I help you today?4. Vicuna 格式
A chat between a curious user and an artificial intelligence assistant.
USER: Hello!
ASSISTANT: Hi there! How can I help you today?
USER: What's the weather?
ASSISTANT:3.3 使用 HuggingFace 的 Chat Template
from transformers import AutoTokenizer
# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
# 对话消息
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help you?"},
{"role": "user", "content": "What is machine learning?"},
]
# 应用 chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True # 添加助手的开始标记,用于生成
)
print(text)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
# Hi! How can I help you?<|im_end|>
# <|im_start|>user
# What is machine learning?<|im_end|>
# <|im_start|>assistant
# Tokenize
input_ids = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
)3.4 自定义 Chat Template
# 设置自定义 chat template(Jinja2 格式)
custom_template = """{% for message in messages %}
{% if message['role'] == 'system' %}
<|system|>{{ message['content'] }}</s>
{% elif message['role'] == 'user' %}
<|user|>{{ message['content'] }}</s>
{% elif message['role'] == 'assistant' %}
<|assistant|>{{ message['content'] }}</s>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
<|assistant|>
{% endif %}"""
tokenizer.chat_template = custom_template
# 使用
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)四、SFT 训练原理
4.1 训练目标
SFT 的训练目标和预训练一样——预测下一个 token。
但有一个关键区别:只在"回答"部分计算损失。
输入: <user>今天天气怎么样?</user><assistant>今天天气很好,阳光明媚。</assistant>
损失计算:
[用户部分: 不计算损失] [助手部分: 计算损失]为什么?因为我们想让模型学习"如何回答",而不是学习"如何提问"。
def compute_sft_loss(model, input_ids, labels):
"""
SFT 损失计算
labels 中,用户部分被标记为 -100(忽略)
"""
# 前向传播
outputs = model(input_ids)
logits = outputs.logits
# Shift: 预测下一个 token
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 计算损失(-100 会被自动忽略)
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
return loss4.2 数据预处理:构造 Labels
def preprocess_sft_data(tokenizer, messages, max_length=2048):
"""
预处理 SFT 数据
关键:只在 assistant 回复部分计算损失
"""
input_ids = []
labels = []
for message in messages:
role = message['role']
content = message['content']
# 添加角色标记
if role == 'user':
role_tokens = tokenizer.encode("<|user|>", add_special_tokens=False)
else:
role_tokens = tokenizer.encode("<|assistant|>", add_special_tokens=False)
# 编码内容
content_tokens = tokenizer.encode(content, add_special_tokens=False)
end_tokens = tokenizer.encode("</s>", add_special_tokens=False)
# 完整的消息 tokens
message_tokens = role_tokens + content_tokens + end_tokens
# 构造 labels
if role == 'user':
# 用户部分:全部标记为 -100(不计算损失)
message_labels = [-100] * len(message_tokens)
else:
# 助手部分:计算损失
# 但角色标记部分也不计算
message_labels = [-100] * len(role_tokens) + content_tokens + end_tokens
input_ids.extend(message_tokens)
labels.extend(message_labels)
# 截断
if len(input_ids) > max_length:
input_ids = input_ids[:max_length]
labels = labels[:max_length]
return {
'input_ids': input_ids,
'labels': labels,
'attention_mask': [1] * len(input_ids)
}
# 示例
messages = [
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!有什么可以帮你的?"},
{"role": "user", "content": "1+1等于几?"},
{"role": "assistant", "content": "1+1等于2。"},
]
result = preprocess_sft_data(tokenizer, messages)
# 可视化哪些部分计算损失
for i, (token_id, label) in enumerate(zip(result['input_ids'], result['labels'])):
token = tokenizer.decode([token_id])
loss_marker = "✓" if label != -100 else "✗"
print(f"{i:3d}: {token:10s} | label={label:6d} | 计算损失: {loss_marker}")4.3 完整的数据处理流程
五、实战:使用 Transformers 进行 SFT
5.1 完整训练代码
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, TaskType
import json
from tqdm import tqdm
# ========== 数据集定义 ==========
class SFTDataset(Dataset):
"""SFT 数据集"""
def __init__(self, data_path, tokenizer, max_length=2048):
self.tokenizer = tokenizer
self.max_length = max_length
# 加载数据
with open(data_path, 'r', encoding='utf-8') as f:
self.data = json.load(f)
print(f"加载了 {len(self.data)} 条训练数据")
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
# 支持两种格式
if 'conversations' in item:
messages = item['conversations']
else:
# Alpaca 格式转换
messages = []
if item.get('instruction'):
content = item['instruction']
if item.get('input'):
content += f"\n\n{item['input']}"
messages.append({"role": "user", "content": content})
if item.get('output'):
messages.append({"role": "assistant", "content": item['output']})
return self.process_messages(messages)
def process_messages(self, messages):
"""处理对话消息"""
# 使用 chat template
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
# Tokenize
tokenized = self.tokenizer(
text,
max_length=self.max_length,
truncation=True,
padding=False,
return_tensors=None,
)
input_ids = tokenized['input_ids']
# 构造 labels:找到 assistant 回复的位置
labels = self.create_labels(messages, input_ids)
return {
'input_ids': input_ids,
'labels': labels,
'attention_mask': tokenized['attention_mask'],
}
def create_labels(self, messages, input_ids):
"""创建 labels,只在 assistant 回复处计算损失"""
labels = [-100] * len(input_ids)
# 找到每个 assistant 回复的 token 范围
current_pos = 0
full_text = self.tokenizer.decode(input_ids)
for msg in messages:
if msg['role'] == 'assistant':
# 找到这条回复在 full_text 中的位置
content = msg['content']
start_idx = full_text.find(content, current_pos)
if start_idx != -1:
# 转换为 token 位置
prefix = full_text[:start_idx]
prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
content_tokens = self.tokenizer.encode(content, add_special_tokens=False)
start_token = len(prefix_tokens)
end_token = start_token + len(content_tokens)
# 设置 labels
for i in range(start_token, min(end_token, len(labels))):
labels[i] = input_ids[i]
current_pos = start_idx + len(content)
return labels
# ========== 训练配置 ==========
def train_sft(
model_name: str,
data_path: str,
output_dir: str,
use_lora: bool = True,
num_epochs: int = 3,
batch_size: int = 4,
learning_rate: float = 2e-5,
max_length: int = 2048,
):
"""SFT 训练主函数"""
print(f"加载模型: {model_name}")
# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
padding_side='right',
)
# 确保有 pad token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# 使用 LoRA
if use_lora:
print("应用 LoRA...")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=64, # LoRA 秩
lora_alpha=16, # 缩放因子
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# 加载数据集
print(f"加载数据集: {data_path}")
train_dataset = SFTDataset(data_path, tokenizer, max_length)
# 数据整理器
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding=True,
return_tensors="pt",
)
# 训练参数
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=4,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
logging_steps=10,
save_steps=500,
save_total_limit=3,
bf16=True,
gradient_checkpointing=True,
report_to="none",
)
# 创建 Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
)
# 开始训练
print("开始训练...")
trainer.train()
# 保存模型
print(f"保存模型到: {output_dir}")
trainer.save_model()
tokenizer.save_pretrained(output_dir)
print("训练完成!")
# ========== 使用示例 ==========
if __name__ == "__main__":
train_sft(
model_name="Qwen/Qwen-7B",
data_path="sft_data.json",
output_dir="./sft_model",
use_lora=True,
num_epochs=3,
batch_size=4,
learning_rate=2e-5,
)5.2 使用 TRL 库(更简洁)
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
def train_with_trl():
"""使用 TRL 的 SFTTrainer"""
# 加载模型和 tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# 加载数据集
dataset = load_dataset("tatsu-lab/alpaca", split="train")
# 格式化函数
def formatting_func(example):
text = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{example['instruction']}
### Input:
{example['input']}
### Response:
{example['output']}"""
return text
# LoRA 配置
peft_config = LoraConfig(
r=64,
lora_alpha=16,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
task_type="CAUSAL_LM",
)
# 训练参数
training_args = TrainingArguments(
output_dir="./sft_output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-5,
bf16=True,
logging_steps=10,
save_steps=500,
)
# 创建 SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
formatting_func=formatting_func,
peft_config=peft_config,
max_seq_length=2048,
)
# 训练
trainer.train()
# 保存
trainer.save_model()
if __name__ == "__main__":
train_with_trl()5.3 准备指令数据
import json
def create_sft_dataset():
"""创建示例 SFT 数据集"""
data = [
# 简单问答
{
"conversations": [
{"role": "user", "content": "什么是机器学习?"},
{"role": "assistant", "content": "机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习和改进,而无需进行明确的编程。通过分析大量数据,机器学习算法可以识别模式并做出决策或预测。"}
]
},
# 多轮对话
{
"conversations": [
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!我是一个AI助手,有什么可以帮助你的吗?"},
{"role": "user", "content": "帮我写一段 Python 代码计算阶乘"},
{"role": "assistant", "content": "好的,这是一个计算阶乘的 Python 函数:\n\n```python\ndef factorial(n):\n if n <= 1:\n return 1\n return n * factorial(n - 1)\n\n# 测试\nprint(factorial(5)) # 输出: 120\n```\n\n这个函数使用递归方式计算阶乘。你也可以使用循环实现:\n\n```python\ndef factorial_iterative(n):\n result = 1\n for i in range(2, n + 1):\n result *= i\n return result\n```"}
]
},
# Alpaca 格式
{
"instruction": "将以下文本翻译成英文",
"input": "人工智能正在改变世界",
"output": "Artificial intelligence is changing the world."
},
{
"instruction": "写一首关于编程的俳句",
"input": "",
"output": "代码如流水\n调试中日月交替\n终见绿灯亮"
},
# 更多数据...
]
# 保存
with open('sft_data.json', 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"创建了 {len(data)} 条训练数据")
return data
# 转换其他格式的数据
def convert_alpaca_to_conversations(alpaca_data):
"""将 Alpaca 格式转换为 conversations 格式"""
converted = []
for item in alpaca_data:
conversations = []
# 构建用户消息
user_content = item['instruction']
if item.get('input'):
user_content += f"\n\n{item['input']}"
conversations.append({
"role": "user",
"content": user_content
})
# 助手回复
conversations.append({
"role": "assistant",
"content": item['output']
})
converted.append({"conversations": conversations})
return converted
if __name__ == "__main__":
create_sft_dataset()六、SFT 的高级技巧
6.1 数据混合策略
不同类型的任务应该按比例混合:
def create_mixed_dataset(datasets_config):
"""
混合不同来源的数据集
datasets_config = {
"general_chat": {"path": "chat.json", "ratio": 0.3},
"coding": {"path": "code.json", "ratio": 0.25},
"math": {"path": "math.json", "ratio": 0.15},
"writing": {"path": "writing.json", "ratio": 0.15},
"knowledge": {"path": "knowledge.json", "ratio": 0.15},
}
"""
all_data = []
for name, config in datasets_config.items():
with open(config['path'], 'r') as f:
data = json.load(f)
# 按比例采样
target_size = int(len(data) * config['ratio'])
sampled = random.sample(data, min(target_size, len(data)))
# 添加来源标记
for item in sampled:
item['source'] = name
all_data.extend(sampled)
print(f"{name}: {len(sampled)} 条")
random.shuffle(all_data)
return all_data6.2 动态 Padding
避免浪费计算在 padding token 上:
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
@dataclass
class DynamicPaddingCollator:
"""动态 padding 到 batch 内最大长度"""
tokenizer: PreTrainedTokenizerBase
padding: bool = True
max_length: int = None
def __call__(self, features):
# 找到 batch 内最大长度
max_len = max(len(f['input_ids']) for f in features)
if self.max_length:
max_len = min(max_len, self.max_length)
batch = {
'input_ids': [],
'attention_mask': [],
'labels': [],
}
for f in features:
# Padding
padding_length = max_len - len(f['input_ids'])
input_ids = f['input_ids'] + [self.tokenizer.pad_token_id] * padding_length
attention_mask = f['attention_mask'] + [0] * padding_length
labels = f['labels'] + [-100] * padding_length
batch['input_ids'].append(input_ids[:max_len])
batch['attention_mask'].append(attention_mask[:max_len])
batch['labels'].append(labels[:max_len])
# 转换为 tensor
return {k: torch.tensor(v) for k, v in batch.items()}6.3 NEFTune:训练时加噪声
NEFTune 是一个简单但有效的技巧:在 embedding 层加入噪声。
def neftune_forward(self, input_ids):
"""NEFTune: 在训练时给 embedding 加噪声"""
embeddings = self.original_forward(input_ids)
if self.training:
# 添加均匀分布噪声
noise = torch.rand_like(embeddings) * 2 - 1 # [-1, 1]
noise = noise * self.neftune_alpha / (embeddings.size(1) ** 0.5)
embeddings = embeddings + noise
return embeddings
def apply_neftune(model, alpha=5.0):
"""应用 NEFTune"""
embed_layer = model.get_input_embeddings()
embed_layer.original_forward = embed_layer.forward
embed_layer.neftune_alpha = alpha
embed_layer.forward = lambda x: neftune_forward(embed_layer, x)
return model6.4 Packing:提高训练效率
把多个短样本打包成一个长序列,减少 padding:
def pack_sequences(examples, max_length=2048, sep_token_id=None):
"""
将多个短序列打包成长序列
原本:
[seq1] [pad] [pad] [pad]
[seq2] [pad] [pad]
[seq3] [pad] [pad] [pad] [pad]
打包后:
[seq1] [sep] [seq2] [sep] [seq3] [pad]
"""
packed_input_ids = []
packed_labels = []
current_input_ids = []
current_labels = []
for ex in examples:
input_ids = ex['input_ids']
labels = ex['labels']
# 如果加上当前样本会超长,先保存当前打包结果
if len(current_input_ids) + len(input_ids) + 1 > max_length:
if current_input_ids:
packed_input_ids.append(current_input_ids)
packed_labels.append(current_labels)
current_input_ids = []
current_labels = []
# 添加分隔符
if current_input_ids and sep_token_id:
current_input_ids.append(sep_token_id)
current_labels.append(-100)
current_input_ids.extend(input_ids)
current_labels.extend(labels)
# 保存最后一个
if current_input_ids:
packed_input_ids.append(current_input_ids)
packed_labels.append(current_labels)
return packed_input_ids, packed_labels七、SFT 的局限性
7.1 SFT 只是"模仿"
SFT 让模型学会了输出"看起来像"人类标注的回答,但它不理解什么是"好"回答。
问题: 1+1等于几?
可能的回答 A: 1+1等于2。
可能的回答 B: 1+1等于3。
可能的回答 C: 这是一个很好的数学问题,让我们来分析一下...
SFT 模型可能会输出 B 或 C,因为它只学会了"格式",
不知道 A 才是"正确"的回答。7.2 无法学习偏好
有些回答没有明确的"对错",只有"好坏":
问题: 写一首诗
回答 A: 春眠不觉晓,处处闻啼鸟。
回答 B: 床前明月光,疑是地上霜。
哪个更好?这取决于用户的偏好。
SFT 无法学习这种偏好。7.3 幻觉问题
SFT 模型可能会"一本正经地胡说八道":
用户: 阿姆斯特朗是什么时候登月的?
SFT 模型: 尼尔·阿姆斯特朗于 1969 年 7 月 21 日成为第一个登上月球的人。
(可能还会编造更多细节,即使它并不确定)7.4 引出 RLHF
这些问题,需要 RLHF(Reinforcement Learning from Human Feedback) 来解决:
八、总结
SFT 核心流程
只在回复处计算损失] D --> E[训练] E --> F[SFT 模型] F --> G[能够理解指令] F --> H[能够对话] F --> I[但仍有局限...] style A fill:#ff6b6b style F fill:#4ecdc4
关键 Takeaway
- SFT 教模型"对话格式":从续写转变为问答
- 只在回复部分计算损失:让模型学习"如何回答"
- Chat Template 很重要:统一的格式才能正确训练
- 数据质量 > 数量:LIMA 论文证明 1000 条高质量数据就够
- SFT 的局限:只会模仿,不懂好坏,需要 RLHF 进一步优化
实践建议
| 场景 | 建议 |
|---|---|
| 快速上手 | 使用 TRL 库的 SFTTrainer |
| 资源有限 | 使用 LoRA/QLoRA |
| 中文模型 | 选择 Qwen、ChatGLM 等 |
| 数据准备 | 优先保证质量,多样性 |
| 训练稳定 | 使用梯度裁剪、warmup |
下一步学习
- [ ] RLHF/DPO:人类偏好对齐
- [ ] LoRA:参数高效微调
- [ ] 评估:如何衡量 SFT 效果
参考资料
- InstructGPT Paper - Training language models to follow instructions
- LIMA Paper - Less Is More for Alignment
- Alpaca - Stanford Alpaca
- TRL Library - HuggingFace TRL
- NEFTune Paper - Noisy Embeddings Improve Instruction Finetuning