一、核心思路

  1. 环境与依赖:安装必要的依赖库(如 transformers、torch、datasets 等) LoRA相关依赖,适配低显存环境;
  2. 数据层优化:增加数据清洗、格式校验、简单数据增强;
  3. 模型训练优化:引入PEFT(LoRA)实现增量蒸馏,大幅降低显存占用;
  4. 训练过程可控:增加损失监控、早停机制,避免过拟合;
  5. 效果评估:新增定量评估指标(困惑度、准确率),而非仅人工测试;
  6. 部署适配:补充模型量化导出、推理加速的示例。

二、完整实现代码(增强版)

1. 环境安装(补充LoRA/评估依赖)

# 基础依赖
pip install torch transformers datasets accelerate peft bitsandbytes sentencepiece
# 数据处理与评估
pip install pandas tqdm evaluate rouge-score numpy scikit-learn
# 模型量化部署(可选)
pip install optimum auto-gptq

2. 数据准备

(1)基础格式(JSONL)

knowledge_base.jsonl(新增多轮对话示例,适配复杂场景):

{"question": "什么是大语言模型的知识蒸馏?", "answer": "知识蒸馏是将大模型的知识迁移到小模型的过程,通过模仿教师模型的输出,让小模型保留核心能力,同时降低显存和算力消耗。"}
{"question": "DeepSeek模型的核心特点是什么?", "answer": "DeepSeek模型具备优秀的代码能力和通用对话能力,基于Transformer架构,支持长上下文,且有不同参数量版本(1.3B/7B/67B)适配不同场景。"}
{"question": "蒸馏后的模型相比原模型有什么优势?", "answer": "蒸馏后的模型参数量更小,推理速度更快,部署成本更低,同时保留了原模型的核心知识和推理能力,适合边缘设备或低算力服务器。"}
(2)数据增强脚本
# data_augmentation.py(用于扩充知识库数据)
import json
import random

def augment_data(input_path, output_path):
    """简单数据增强:同义词替换、句式微调"""
    # 同义词词典(可根据领域扩展)
    synonym_dict = {
        "知识蒸馏": ["模型蒸馏", "知识迁移"],
        "显存": ["内存", "显卡内存"],
        "推理": ["推断", "推演"],
        "核心特点": ["主要特征", "关键特性"]
    }
    
    with open(input_path, "r", encoding="utf-8") as f, open(output_path, "w", encoding="utf-8") as out_f:
        for line in f:
            data = json.loads(line.strip())
            # 原始数据保留
            out_f.write(json.dumps(data, ensure_ascii=False) + "\n")
            
            # 同义词替换生成增强数据
            aug_question = data["question"]
            aug_answer = data["answer"]
            for word, synonyms in synonym_dict.items():
                if word in aug_question:
                    aug_question = aug_question.replace(word, random.choice(synonyms))
                if word in aug_answer:
                    aug_answer = aug_answer.replace(word, random.choice(synonyms))
            out_f.write(json.dumps({"question": aug_question, "answer": aug_answer}, ensure_ascii=False) + "\n")

# 使用示例
augment_data("knowledge_base.jsonl", "knowledge_base_aug.jsonl")

3. 蒸馏训练代码

import torch
import json
import numpy as np
from tqdm import tqdm
from datasets import Dataset, load_metric
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling,
    EarlyStoppingCallback, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType
import warnings
warnings.filterwarnings("ignore")

# ===================== 1. 基础配置(增强版) =====================
# 设备设置:优先使用GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# 教师模型(DeepSeek)
teacher_model_name = "deepseek-ai/deepseek-llm-7b-chat"
# 学生模型(轻量级,可根据需求调整)
student_model_name = "deepseek-ai/deepseek-llm-1.3b"
# 自定义知识库数据路径(使用增强后的数据)
data_path = "knowledge_base_aug.jsonl"
# 训练参数
output_dir = "./distilled_deepseek_model"
batch_size = 4  # 因LoRA优化,可适当增大批次
epochs = 5
learning_rate = 3e-5  # 降低学习率提升稳定性
gradient_accumulation_steps = 2  # 梯度累积,变相增大批次
# LoRA配置(核心:增量训练,降低显存)
lora_r = 8
lora_alpha = 32
lora_dropout = 0.05

# ===================== 2. 量化配置(精细化) =====================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # 更适合LLM的量化类型
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,  # 二级量化,进一步节省显存
)

# ===================== 3. 加载tokenizer和模型(增强版) =====================
# 加载tokenizer(补充特殊token处理)
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # 避免推理时警告

# 加载教师模型(量化优化)
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()  # 教师模型仅用于生成软标签,不训练

# 加载学生模型 + LoRA(核心优化:仅训练LoRA权重,显存占用降低80%)
student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
# 配置LoRA
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=["q_proj", "v_proj"],  # 针对Transformer的关键层
)
student_model = get_peft_model(student_model, peft_config)
# 打印可训练参数(验证LoRA生效)
student_model.print_trainable_parameters()  # 输出:trainable params: ~0.1% of total params

# ===================== 4. 数据处理(增强版) =====================
def load_and_validate_data(path):
    """加载并校验数据格式,过滤无效数据"""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            try:
                line = json.loads(line.strip())
                # 校验必填字段
                if "question" not in line or "answer" not in line:
                    print(f"跳过无效数据(第{idx+1}行):缺少question/answer字段")
                    continue
                # 过滤空内容
                if not line["question"].strip() or not line["answer"].strip():
                    print(f"跳过无效数据(第{idx+1}行):内容为空")
                    continue
                # 构建对话格式(匹配DeepSeek的输入格式)
                prompt = f"用户:{line['question']}\n助手:"
                data.append({
                    "prompt": prompt,
                    "answer": line["answer"],
                    "full_text": prompt + line["answer"]
                })
            except json.JSONDecodeError:
                print(f"跳过无效数据(第{idx+1}行):JSON格式错误")
                continue
    return Dataset.from_list(data)

def preprocess_function(examples):
    """优化数据预处理,适配批量处理"""
    # 编码文本(优化max_length,避免截断关键内容)
    encodings = tokenizer(
        examples["full_text"],
        truncation=True,
        max_length=1024,  # 增大上下文长度
        padding="max_length",
        return_tensors="pt",
    )
    # 构建标签(精准mask prompt部分)
    labels = []
    for prompt, full_text in zip(examples["prompt"], examples["full_text"]):
        # 计算prompt的token长度(避免截断导致长度计算错误)
        prompt_ids = tokenizer.encode(prompt, truncation=True, max_length=1024)
        prompt_len = len(prompt_ids)
        # 前prompt_len个token设为-100(不计算损失)
        label = [-100] * prompt_len + encodings["input_ids"][len(labels)][prompt_len:].tolist()
        # 确保label长度与input_ids一致
        label = label[:1024] + [-100] * (1024 - len(label))
        labels.append(label)
    encodings["labels"] = torch.tensor(labels)
    return encodings

# 加载并处理数据(拆分训练/验证集,避免过拟合)
dataset = load_and_validate_data(data_path)
# 拆分训练集(90%)和验证集(10%)
dataset = dataset.train_test_split(test_size=0.1, seed=42)
tokenized_train = dataset["train"].map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)
tokenized_val = dataset["test"].map(preprocess_function, batched=True, remove_columns=dataset["test"].column_names)

# ===================== 5. 蒸馏损失函数(增强版) =====================
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # 1. 学生模型前向传播
        student_outputs = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            labels=inputs["labels"].to(device)
        )
        student_loss = student_outputs.loss  # 监督损失(硬标签)
        
        # 2. 教师模型前向传播(生成软标签)
        with torch.no_grad():
            teacher_outputs = teacher_model(
                input_ids=inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
            )
        
        # 3. 计算蒸馏损失(优化KL散度计算,避免数值不稳定)
        temperature = 3.0  # 调整温度系数,提升软标签效果
        # 只计算answer部分的KL散度(mask掉prompt部分)
        mask = (inputs["labels"] != -100).float()
        student_logits = student_outputs.logits / temperature
        teacher_logits = teacher_outputs.logits / temperature
        
        # 稳定的KL散度计算
        kl_div = torch.nn.functional.kl_div(
            torch.log_softmax(student_logits, dim=-1),
            torch.softmax(teacher_logits, dim=-1),
            reduction="none",
            log_target=False,
        )
        # 仅计算mask部分(answer)的KL散度
        kl_loss = (kl_div.sum(-1) * mask).sum() / mask.sum() * (temperature ** 2)
        
        # 4. 动态权重:训练初期侧重蒸馏损失,后期侧重监督损失
        epoch = self.state.epoch if self.state.epoch else 0
        distill_weight = max(0.3, 0.7 - epoch * 0.1)  # 从0.7逐步降到0.3
        supervise_weight = 1 - distill_weight
        total_loss = supervise_weight * student_loss + distill_weight * kl_loss
        
        return (total_loss, student_outputs) if return_outputs else total_loss

# ===================== 6. 训练配置(增强版) =====================
# 加载评估指标(困惑度)
perplexity_metric = load_metric("perplexity")

def compute_metrics(eval_pred):
    """自定义评估指标:困惑度(越低越好)"""
    logits, labels = eval_pred
    # 计算困惑度
    predictions = torch.from_numpy(logits)
    labels = torch.from_numpy(labels)
    # 过滤mask部分(-100)
    mask = labels != -100
    filtered_logits = predictions[mask.unsqueeze(-1).expand(predictions.shape)]
    filtered_labels = labels[mask]
    # 计算困惑度
    perplexity = torch.exp(torch.nn.functional.cross_entropy(filtered_logits, filtered_labels))
    return {"perplexity": perplexity.item()}

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    learning_rate=learning_rate,
    fp16=True,
    gradient_accumulation_steps=gradient_accumulation_steps,
    logging_steps=5,  # 更频繁的日志输出
    eval_steps=20,    # 定期验证
    save_steps=20,
    save_total_limit=3,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none",
    # 早停机制(避免过拟合)
    load_best_model_at_end=True,
    metric_for_best_model="perplexity",
    greater_is_better=False,  # 困惑度越低越好
    # 优化器配置(提升训练稳定性)
    optim="paged_adamw_8bit",  # 8bit优化器,进一步节省显存
    weight_decay=0.01,
    warmup_ratio=0.1,  # 学习率预热
)

# 数据整理器
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# ===================== 7. 开始训练(增强版) =====================
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,  # 验证集
    data_collator=data_collator,
    compute_metrics=compute_metrics,  # 自定义评估指标
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # 早停
)

# 启动训练
trainer.train()

# ===================== 8. 保存模型(增强版) =====================
# 保存LoRA权重 + 完整模型
student_model.save_pretrained(f"{output_dir}/lora_weights")
# 合并LoRA权重到主模型(方便部署)
merged_model = student_model.merge_and_unload()
merged_model.save_pretrained(f"{output_dir}/merged_model")
tokenizer.save_pretrained(f"{output_dir}/merged_model")
print(f"蒸馏后的模型已保存至:{output_dir}")

# ===================== 9. 模型评估(定量+定性) =====================
def evaluate_model_qualitative():
    """定性评估:人工可感知的回答质量"""
    test_questions = [
        "什么是大语言模型的知识蒸馏?",
        "DeepSeek模型的核心特点是什么?",
        "蒸馏后的模型相比原模型有什么优势?"
    ]
    print("\n===== 定性评估结果 =====")
    for question in test_questions:
        prompt = f"用户:{question}\n助手:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = merged_model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
        print(f"问题:{question}\n回答:{response}\n")

def evaluate_model_quantitative():
    """定量评估:困惑度 + 准确率(简易版)"""
    print("\n===== 定量评估结果 =====")
    # 计算验证集困惑度
    eval_results = trainer.evaluate()
    print(f"验证集困惑度:{eval_results['eval_perplexity']:.2f}")
    
    # 简易准确率计算(匹配关键词)
    correct = 0
    total = len(tokenized_val)
    for idx, example in enumerate(dataset["test"]):
        question = example["question"]
        true_answer = example["answer"]
        prompt = f"用户:{question}\n助手:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = merged_model.generate(**inputs, max_new_tokens=200, temperature=0.1)
        pred_answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
        # 关键词匹配(可根据领域调整)
        if any(keyword in pred_answer for keyword in true_answer.split()[:3]):
            correct += 1
    accuracy = correct / total
    print(f"简易准确率(关键词匹配):{accuracy:.2%}")

# 执行评估
evaluate_model_qualitative()
evaluate_model_quantitative()

# ===================== 10. 模型部署优化(可选) =====================
def export_quantized_model():
    """导出量化模型(GPTQ),进一步加速推理"""
    from optimum.gptq import GPTQQuantizer, load_quantized_model
    
    quantizer = GPTQQuantizer(
        bits=4,
        dataset="c4",
        tokenizer=tokenizer,
        group_size=128,
    )
    # 量化并保存
    quantized_model = quantizer.quantize_model(merged_model, tokenizer)
    quantized_model.save_pretrained(f"{output_dir}/quantized_model")
    tokenizer.save_pretrained(f"{output_dir}/quantized_model")
    print("量化模型已保存,推理速度提升3-5倍")

# 导出量化模型(按需执行)
# export_quantized_model()

三、关键增强点解释

1. LoRA轻量化训练(核心)

  • 原理:仅训练模型的少量适配层(LoRA层),而非全量参数,显存占用降低80%以上;
  • 效果:原本需要12GB显存的训练,现在仅需2-4GB即可运行,普通消费级显卡(如RTX 3060/3070)也能训练;
  • 代码关键peft_config配置LoRA参数,get_peft_model封装学生模型,print_trainable_parameters()验证仅0.1%参数可训练。

2. 数据层优化

  • 数据校验:过滤格式错误、空内容的无效数据,避免训练崩溃;
  • 数据增强:通过同义词替换扩充数据集,提升模型泛化能力;
  • 训练/验证集拆分:避免过拟合,通过验证集监控训练效果。

3. 训练过程可控

  • 早停机制EarlyStoppingCallback当验证集困惑度连续3次上升时停止训练,避免过拟合;
  • 动态损失权重:训练初期侧重学习教师模型的通用能力(高蒸馏权重),后期侧重学习知识库(高监督权重);
  • 定量评估:引入困惑度(Perplexity)指标,客观衡量模型效果(越低越好),而非仅人工测试。

4. 部署优化

  • 模型合并:将LoRA权重合并到主模型,避免推理时依赖PEFT库;
  • GPTQ量化:导出4bit量化模型,推理速度提升3-5倍,适配低算力部署场景。

四、适配场景与参数调整指南

场景 显存要求 核心参数调整建议
高端显卡(24GB+) 24GB+ batch_size=8,epochs=5,lora_r=16
中端显卡(8-12GB) 8-12GB batch_size=2,gradient_accumulation_steps=4,lora_r=8
低端显卡(4-8GB) 4-8GB 改用student_model_name=“distilbert-base-chinese”,batch_size=1
CPU训练(应急) 16GB+内存 关闭FP16,batch_size=1,epochs=2,仅用于测试

效果调优建议

  • 若模型回答偏离知识库:增大监督损失权重(如distill_weight=0.2supervise_weight=0.8);
  • 若模型通用能力差:增大蒸馏损失权重(如distill_weight=0.6supervise_weight=0.4);
  • 若过拟合(验证集困惑度上升):减少epochs、增大lora_dropout、扩充数据集。

五、常见问题解决方案

问题现象 原因 解决方案
显存溢出(OOM) 批量过大/未启用LoRA 启用LoRA、降低batch_size、开启梯度累积
训练损失为NaN 数值不稳定 降低学习率、调整温度系数、使用paged_adamw_8bit优化器
模型生成重复内容 过拟合/温度系数过低 增大temperature、增加数据量、启用早停
加载模型时报trust_remote_code错误 模型需要自定义代码 增加trust_remote_code=True参数

总结

  1. 核心优化:引入LoRA实现增量蒸馏,显存占用降低80%,普通显卡也能训练;补充数据校验、早停、定量评估,解决训练不稳定、效果不可控问题。
  2. 关键参数:LoRA的r(8-16)、温度系数(2-3)、损失权重(动态调整)是影响效果的核心,需根据显存和数据量适配。
  3. 落地建议:优先用增强版数据处理保证数据质量,训练后通过困惑度和关键词准确率评估效果,最后导出量化模型用于部署。
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐