Day2:SFT 有监督微调原理
·
一、什么是 SFT?
SFT 全称是 Supervised Fine-Tuning(有监督微调),是大模型微调的基础步骤,也是目前工业界最主流的微调方式。
1. 大白话定义
给大模型喂「输入(prompt)+ 期望输出(response)」的配对数据,让模型学习这种固定的对应关系,从而学会按你的要求说话、输出固定格式或特定风格。
你可以把它理解为:
- 模型是个通用员工,会做所有事但没固定章法
- SFT 就是公司给它做岗前培训,教它「遇到这种问题就这么回答」
- 培训完,它就成了一个懂业务、守规矩的 “专属员工”
2. 为什么 SFT 是微调的基础?
- 它能把通用大模型对齐到你的任务目标上
- 是后续 RLHF(人类反馈强化学习)等高级微调的前置步骤
- 技术门槛相对低,效果立竿见影,适合入门
二、SFT 核心原理拆解
1. 核心思想:模仿学习
SFT 的本质,是让模型模仿数据里的 “标准答案”。 举个例子:
表格
| 输入(prompt) | 期望输出(response) |
|---|---|
| 你是一个客服,用户说 “退款” 该怎么回复? | 亲,请问您是要申请全额退款还是部分退款呢?可以跟我说一下具体原因吗? |
模型会通过学习大量这样的「输入 - 输出对」,慢慢形成 “客服话术” 的固定风格,之后用户说 “退款”,它就会自动按培训过的话术回复。
2. 训练过程(和预训练的区别)
- 预训练:用海量无标注数据,让模型学习语言规律和通用知识(成本极高,普通人做不了)
- SFT 微调:用少量有标注的业务数据,让模型学习特定任务的规则(成本可控,我们可以做)
3. 损失函数(简单了解即可)
SFT 用的是交叉熵损失函数,目标是让模型生成的回答,和数据里的标准答案尽可能接近。 训练时,模型会不断调整参数,让 “自己的输出” 和 “标准答案” 的差距越来越小,这就是微调的过程。
三、SFT 完整流程(面试高频考点)
-
数据准备
- 核心要求:数据格式必须是「prompt + response」配对
- 数据质量:必须准确、符合业务规范,脏数据会教坏模型
- 数据量:通常几千到几万条就有明显效果,不用像预训练那样百万级
-
数据格式转换 把原始数据转换成模型能识别的格式,比如 Alpaca 格式:
json
{ "instruction": "你是客服,用户说退款怎么回复?", "input": "", "output": "亲,请问您是要申请全额退款还是部分退款呢?可以跟我说一下具体原因吗?" } -
加载基础模型与配置
- 选择开源基础模型(比如 LLaMA、ChatGLM、Qwen 等)
- 配置训练参数(学习率、批次大小、训练轮数等)
-
训练微调
- 启动训练,让模型学习你的数据规律
- 监控训练损失,确保模型在正常收敛(损失值持续下降)
-
评估与验证
- 用训练时没见过的测试数据,测试模型的输出效果
- 重点看:回答是否符合业务要求、风格是否稳定、格式是否正确
四、SFT vs 提示词工程(面试必问)
表格
| 维度 | 提示词工程 | SFT 有监督微调 |
|---|---|---|
| 实现方式 | 写 prompt,临时指挥模型 | 用配对数据训练,把规则刻进模型 |
| 效果稳定性 | 不稳定,换个问题就容易跑偏 | 稳定,长期保持固定风格和格式 |
| 适用场景 | 临时需求、快速验证想法 | 固定业务场景、大规模统一风格 |
| 成本 | 低,只需要写 prompt | 中,需要准备数据 + 训练 |
| 可扩展性 | 差,prompt 太长会超出上下文限制 | 强,训练完模型自带能力,不占上下文 |
一句话总结:
- 提示词是 “临时抱佛脚”,适合快速试错
- SFT 是 “长期改造”,适合固定业务场景的落地
五、代码示例(简化版,帮你建立直观感受)
python
运行
# 伪代码,展示SFT的核心流程
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
# 1. 加载数据集(已经是prompt-response格式)
dataset = load_dataset("json", data_files="sft_data.json")
# 2. 加载基础模型和分词器
model_name = "qwen-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 3. 数据预处理:把prompt+response拼接成模型输入
def preprocess_function(examples):
inputs = tokenizer(
[f"问题:{q}\n回答:{a}" for q, a in zip(examples["instruction"], examples["output"])],
truncation=True,
max_length=512,
padding="max_length"
)
inputs["labels"] = inputs["input_ids"].copy() # 标签就是输入本身,让模型学习生成下一个token
return inputs
tokenized_dataset = dataset.map(preprocess_function, batched=True)
# 4. 配置训练参数
training_args = TrainingArguments(
output_dir="./sft_model",
per_device_train_batch_size=4,
learning_rate=2e-5,
num_train_epochs=3,
logging_steps=10,
save_strategy="epoch"
)
# 5. 启动训练
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"]
)
trainer.train()
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)