基于 DeepSeek 模型的知识蒸馏训练自定义知识库模型
·
一、核心思路
- 环境与依赖:安装必要的依赖库(如 transformers、torch、datasets 等) LoRA相关依赖,适配低显存环境;
- 数据层优化:增加数据清洗、格式校验、简单数据增强;
- 模型训练优化:引入PEFT(LoRA)实现增量蒸馏,大幅降低显存占用;
- 训练过程可控:增加损失监控、早停机制,避免过拟合;
- 效果评估:新增定量评估指标(困惑度、准确率),而非仅人工测试;
- 部署适配:补充模型量化导出、推理加速的示例。
二、完整实现代码(增强版)
1. 环境安装(补充LoRA/评估依赖)
# 基础依赖
pip install torch transformers datasets accelerate peft bitsandbytes sentencepiece
# 数据处理与评估
pip install pandas tqdm evaluate rouge-score numpy scikit-learn
# 模型量化部署(可选)
pip install optimum auto-gptq
2. 数据准备
(1)基础格式(JSONL)
knowledge_base.jsonl(新增多轮对话示例,适配复杂场景):
{"question": "什么是大语言模型的知识蒸馏?", "answer": "知识蒸馏是将大模型的知识迁移到小模型的过程,通过模仿教师模型的输出,让小模型保留核心能力,同时降低显存和算力消耗。"}
{"question": "DeepSeek模型的核心特点是什么?", "answer": "DeepSeek模型具备优秀的代码能力和通用对话能力,基于Transformer架构,支持长上下文,且有不同参数量版本(1.3B/7B/67B)适配不同场景。"}
{"question": "蒸馏后的模型相比原模型有什么优势?", "answer": "蒸馏后的模型参数量更小,推理速度更快,部署成本更低,同时保留了原模型的核心知识和推理能力,适合边缘设备或低算力服务器。"}
(2)数据增强脚本
# data_augmentation.py(用于扩充知识库数据)
import json
import random
def augment_data(input_path, output_path):
"""简单数据增强:同义词替换、句式微调"""
# 同义词词典(可根据领域扩展)
synonym_dict = {
"知识蒸馏": ["模型蒸馏", "知识迁移"],
"显存": ["内存", "显卡内存"],
"推理": ["推断", "推演"],
"核心特点": ["主要特征", "关键特性"]
}
with open(input_path, "r", encoding="utf-8") as f, open(output_path, "w", encoding="utf-8") as out_f:
for line in f:
data = json.loads(line.strip())
# 原始数据保留
out_f.write(json.dumps(data, ensure_ascii=False) + "\n")
# 同义词替换生成增强数据
aug_question = data["question"]
aug_answer = data["answer"]
for word, synonyms in synonym_dict.items():
if word in aug_question:
aug_question = aug_question.replace(word, random.choice(synonyms))
if word in aug_answer:
aug_answer = aug_answer.replace(word, random.choice(synonyms))
out_f.write(json.dumps({"question": aug_question, "answer": aug_answer}, ensure_ascii=False) + "\n")
# 使用示例
augment_data("knowledge_base.jsonl", "knowledge_base_aug.jsonl")
3. 蒸馏训练代码
import torch
import json
import numpy as np
from tqdm import tqdm
from datasets import Dataset, load_metric
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
TrainingArguments, Trainer, DataCollatorForLanguageModeling,
EarlyStoppingCallback, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType
import warnings
warnings.filterwarnings("ignore")
# ===================== 1. 基础配置(增强版) =====================
# 设备设置:优先使用GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# 教师模型(DeepSeek)
teacher_model_name = "deepseek-ai/deepseek-llm-7b-chat"
# 学生模型(轻量级,可根据需求调整)
student_model_name = "deepseek-ai/deepseek-llm-1.3b"
# 自定义知识库数据路径(使用增强后的数据)
data_path = "knowledge_base_aug.jsonl"
# 训练参数
output_dir = "./distilled_deepseek_model"
batch_size = 4 # 因LoRA优化,可适当增大批次
epochs = 5
learning_rate = 3e-5 # 降低学习率提升稳定性
gradient_accumulation_steps = 2 # 梯度累积,变相增大批次
# LoRA配置(核心:增量训练,降低显存)
lora_r = 8
lora_alpha = 32
lora_dropout = 0.05
# ===================== 2. 量化配置(精细化) =====================
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # 更适合LLM的量化类型
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True, # 二级量化,进一步节省显存
)
# ===================== 3. 加载tokenizer和模型(增强版) =====================
# 加载tokenizer(补充特殊token处理)
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # 避免推理时警告
# 加载教师模型(量化优化)
teacher_model = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
teacher_model.eval() # 教师模型仅用于生成软标签,不训练
# 加载学生模型 + LoRA(核心优化:仅训练LoRA权重,显存占用降低80%)
student_model = AutoModelForCausalLM.from_pretrained(
student_model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
# 配置LoRA
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=["q_proj", "v_proj"], # 针对Transformer的关键层
)
student_model = get_peft_model(student_model, peft_config)
# 打印可训练参数(验证LoRA生效)
student_model.print_trainable_parameters() # 输出:trainable params: ~0.1% of total params
# ===================== 4. 数据处理(增强版) =====================
def load_and_validate_data(path):
"""加载并校验数据格式,过滤无效数据"""
data = []
with open(path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f):
try:
line = json.loads(line.strip())
# 校验必填字段
if "question" not in line or "answer" not in line:
print(f"跳过无效数据(第{idx+1}行):缺少question/answer字段")
continue
# 过滤空内容
if not line["question"].strip() or not line["answer"].strip():
print(f"跳过无效数据(第{idx+1}行):内容为空")
continue
# 构建对话格式(匹配DeepSeek的输入格式)
prompt = f"用户:{line['question']}\n助手:"
data.append({
"prompt": prompt,
"answer": line["answer"],
"full_text": prompt + line["answer"]
})
except json.JSONDecodeError:
print(f"跳过无效数据(第{idx+1}行):JSON格式错误")
continue
return Dataset.from_list(data)
def preprocess_function(examples):
"""优化数据预处理,适配批量处理"""
# 编码文本(优化max_length,避免截断关键内容)
encodings = tokenizer(
examples["full_text"],
truncation=True,
max_length=1024, # 增大上下文长度
padding="max_length",
return_tensors="pt",
)
# 构建标签(精准mask prompt部分)
labels = []
for prompt, full_text in zip(examples["prompt"], examples["full_text"]):
# 计算prompt的token长度(避免截断导致长度计算错误)
prompt_ids = tokenizer.encode(prompt, truncation=True, max_length=1024)
prompt_len = len(prompt_ids)
# 前prompt_len个token设为-100(不计算损失)
label = [-100] * prompt_len + encodings["input_ids"][len(labels)][prompt_len:].tolist()
# 确保label长度与input_ids一致
label = label[:1024] + [-100] * (1024 - len(label))
labels.append(label)
encodings["labels"] = torch.tensor(labels)
return encodings
# 加载并处理数据(拆分训练/验证集,避免过拟合)
dataset = load_and_validate_data(data_path)
# 拆分训练集(90%)和验证集(10%)
dataset = dataset.train_test_split(test_size=0.1, seed=42)
tokenized_train = dataset["train"].map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)
tokenized_val = dataset["test"].map(preprocess_function, batched=True, remove_columns=dataset["test"].column_names)
# ===================== 5. 蒸馏损失函数(增强版) =====================
class DistillationTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
# 1. 学生模型前向传播
student_outputs = model(
input_ids=inputs["input_ids"].to(device),
attention_mask=inputs["attention_mask"].to(device),
labels=inputs["labels"].to(device)
)
student_loss = student_outputs.loss # 监督损失(硬标签)
# 2. 教师模型前向传播(生成软标签)
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids=inputs["input_ids"].to(device),
attention_mask=inputs["attention_mask"].to(device),
)
# 3. 计算蒸馏损失(优化KL散度计算,避免数值不稳定)
temperature = 3.0 # 调整温度系数,提升软标签效果
# 只计算answer部分的KL散度(mask掉prompt部分)
mask = (inputs["labels"] != -100).float()
student_logits = student_outputs.logits / temperature
teacher_logits = teacher_outputs.logits / temperature
# 稳定的KL散度计算
kl_div = torch.nn.functional.kl_div(
torch.log_softmax(student_logits, dim=-1),
torch.softmax(teacher_logits, dim=-1),
reduction="none",
log_target=False,
)
# 仅计算mask部分(answer)的KL散度
kl_loss = (kl_div.sum(-1) * mask).sum() / mask.sum() * (temperature ** 2)
# 4. 动态权重:训练初期侧重蒸馏损失,后期侧重监督损失
epoch = self.state.epoch if self.state.epoch else 0
distill_weight = max(0.3, 0.7 - epoch * 0.1) # 从0.7逐步降到0.3
supervise_weight = 1 - distill_weight
total_loss = supervise_weight * student_loss + distill_weight * kl_loss
return (total_loss, student_outputs) if return_outputs else total_loss
# ===================== 6. 训练配置(增强版) =====================
# 加载评估指标(困惑度)
perplexity_metric = load_metric("perplexity")
def compute_metrics(eval_pred):
"""自定义评估指标:困惑度(越低越好)"""
logits, labels = eval_pred
# 计算困惑度
predictions = torch.from_numpy(logits)
labels = torch.from_numpy(labels)
# 过滤mask部分(-100)
mask = labels != -100
filtered_logits = predictions[mask.unsqueeze(-1).expand(predictions.shape)]
filtered_labels = labels[mask]
# 计算困惑度
perplexity = torch.exp(torch.nn.functional.cross_entropy(filtered_logits, filtered_labels))
return {"perplexity": perplexity.item()}
training_args = TrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=epochs,
learning_rate=learning_rate,
fp16=True,
gradient_accumulation_steps=gradient_accumulation_steps,
logging_steps=5, # 更频繁的日志输出
eval_steps=20, # 定期验证
save_steps=20,
save_total_limit=3,
remove_unused_columns=False,
push_to_hub=False,
report_to="none",
# 早停机制(避免过拟合)
load_best_model_at_end=True,
metric_for_best_model="perplexity",
greater_is_better=False, # 困惑度越低越好
# 优化器配置(提升训练稳定性)
optim="paged_adamw_8bit", # 8bit优化器,进一步节省显存
weight_decay=0.01,
warmup_ratio=0.1, # 学习率预热
)
# 数据整理器
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
)
# ===================== 7. 开始训练(增强版) =====================
trainer = DistillationTrainer(
model=student_model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_val, # 验证集
data_collator=data_collator,
compute_metrics=compute_metrics, # 自定义评估指标
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)], # 早停
)
# 启动训练
trainer.train()
# ===================== 8. 保存模型(增强版) =====================
# 保存LoRA权重 + 完整模型
student_model.save_pretrained(f"{output_dir}/lora_weights")
# 合并LoRA权重到主模型(方便部署)
merged_model = student_model.merge_and_unload()
merged_model.save_pretrained(f"{output_dir}/merged_model")
tokenizer.save_pretrained(f"{output_dir}/merged_model")
print(f"蒸馏后的模型已保存至:{output_dir}")
# ===================== 9. 模型评估(定量+定性) =====================
def evaluate_model_qualitative():
"""定性评估:人工可感知的回答质量"""
test_questions = [
"什么是大语言模型的知识蒸馏?",
"DeepSeek模型的核心特点是什么?",
"蒸馏后的模型相比原模型有什么优势?"
]
print("\n===== 定性评估结果 =====")
for question in test_questions:
prompt = f"用户:{question}\n助手:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = merged_model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
print(f"问题:{question}\n回答:{response}\n")
def evaluate_model_quantitative():
"""定量评估:困惑度 + 准确率(简易版)"""
print("\n===== 定量评估结果 =====")
# 计算验证集困惑度
eval_results = trainer.evaluate()
print(f"验证集困惑度:{eval_results['eval_perplexity']:.2f}")
# 简易准确率计算(匹配关键词)
correct = 0
total = len(tokenized_val)
for idx, example in enumerate(dataset["test"]):
question = example["question"]
true_answer = example["answer"]
prompt = f"用户:{question}\n助手:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = merged_model.generate(**inputs, max_new_tokens=200, temperature=0.1)
pred_answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
# 关键词匹配(可根据领域调整)
if any(keyword in pred_answer for keyword in true_answer.split()[:3]):
correct += 1
accuracy = correct / total
print(f"简易准确率(关键词匹配):{accuracy:.2%}")
# 执行评估
evaluate_model_qualitative()
evaluate_model_quantitative()
# ===================== 10. 模型部署优化(可选) =====================
def export_quantized_model():
"""导出量化模型(GPTQ),进一步加速推理"""
from optimum.gptq import GPTQQuantizer, load_quantized_model
quantizer = GPTQQuantizer(
bits=4,
dataset="c4",
tokenizer=tokenizer,
group_size=128,
)
# 量化并保存
quantized_model = quantizer.quantize_model(merged_model, tokenizer)
quantized_model.save_pretrained(f"{output_dir}/quantized_model")
tokenizer.save_pretrained(f"{output_dir}/quantized_model")
print("量化模型已保存,推理速度提升3-5倍")
# 导出量化模型(按需执行)
# export_quantized_model()
三、关键增强点解释
1. LoRA轻量化训练(核心)
- 原理:仅训练模型的少量适配层(LoRA层),而非全量参数,显存占用降低80%以上;
- 效果:原本需要12GB显存的训练,现在仅需2-4GB即可运行,普通消费级显卡(如RTX 3060/3070)也能训练;
- 代码关键:
peft_config配置LoRA参数,get_peft_model封装学生模型,print_trainable_parameters()验证仅0.1%参数可训练。
2. 数据层优化
- 数据校验:过滤格式错误、空内容的无效数据,避免训练崩溃;
- 数据增强:通过同义词替换扩充数据集,提升模型泛化能力;
- 训练/验证集拆分:避免过拟合,通过验证集监控训练效果。
3. 训练过程可控
- 早停机制:
EarlyStoppingCallback当验证集困惑度连续3次上升时停止训练,避免过拟合; - 动态损失权重:训练初期侧重学习教师模型的通用能力(高蒸馏权重),后期侧重学习知识库(高监督权重);
- 定量评估:引入困惑度(Perplexity)指标,客观衡量模型效果(越低越好),而非仅人工测试。
4. 部署优化
- 模型合并:将LoRA权重合并到主模型,避免推理时依赖PEFT库;
- GPTQ量化:导出4bit量化模型,推理速度提升3-5倍,适配低算力部署场景。
四、适配场景与参数调整指南
| 场景 | 显存要求 | 核心参数调整建议 |
|---|---|---|
| 高端显卡(24GB+) | 24GB+ | batch_size=8,epochs=5,lora_r=16 |
| 中端显卡(8-12GB) | 8-12GB | batch_size=2,gradient_accumulation_steps=4,lora_r=8 |
| 低端显卡(4-8GB) | 4-8GB | 改用student_model_name=“distilbert-base-chinese”,batch_size=1 |
| CPU训练(应急) | 16GB+内存 | 关闭FP16,batch_size=1,epochs=2,仅用于测试 |
效果调优建议
- 若模型回答偏离知识库:增大监督损失权重(如
distill_weight=0.2,supervise_weight=0.8); - 若模型通用能力差:增大蒸馏损失权重(如
distill_weight=0.6,supervise_weight=0.4); - 若过拟合(验证集困惑度上升):减少epochs、增大lora_dropout、扩充数据集。
五、常见问题解决方案
| 问题现象 | 原因 | 解决方案 |
|---|---|---|
| 显存溢出(OOM) | 批量过大/未启用LoRA | 启用LoRA、降低batch_size、开启梯度累积 |
| 训练损失为NaN | 数值不稳定 | 降低学习率、调整温度系数、使用paged_adamw_8bit优化器 |
| 模型生成重复内容 | 过拟合/温度系数过低 | 增大temperature、增加数据量、启用早停 |
| 加载模型时报trust_remote_code错误 | 模型需要自定义代码 | 增加trust_remote_code=True参数 |
总结
- 核心优化:引入LoRA实现增量蒸馏,显存占用降低80%,普通显卡也能训练;补充数据校验、早停、定量评估,解决训练不稳定、效果不可控问题。
- 关键参数:LoRA的
r(8-16)、温度系数(2-3)、损失权重(动态调整)是影响效果的核心,需根据显存和数据量适配。 - 落地建议:优先用增强版数据处理保证数据质量,训练后通过困惑度和关键词准确率评估效果,最后导出量化模型用于部署。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)