题记

在Datawhale微信公众号上看到6月有关大模型专题学习,其中AMD中文教程:Hello-ROCm,正好近来空闲时间比较多,可以利用这个机会,深入实操有关大模型 的微调相关的事宜,深入理解数据处理、训练、效果对比。

体验AI训练师工作

本次的任务是将把一个只会通用聊天的模型,通过特定领域的数据训练成能够精准识别6 种人类情绪的“情感分析专家” 。这个把通用模型改造成特定领域专家的过程,就叫 模型微调(Fine-Tuning) 。
真实的 AI 微调涉及复杂的代码、环境配置和硬件调优,即便是计算机专业的同学也可能被各种报错劝退。
为了让你专注于 理解微调核心逻辑 ,我们已经将所有复杂操作打包优化,准备了“一键顺滑运行”的实验报告(Notebook)。

任务简介

本次的任务在 AI-ModelScope/emotion上单卡微调Gemma 4 E4B-it(魔搭 ModelScope 版本)。由于都是已经实现写好的代码,只需要按部就班的执行这些代码,体验一把AI微调工程师的工作。根据个人体验将整个体验总结如下

一、AI微调过程总结

该 Jupyter Notebook 完整实现了一个基于 ModelScope 的单卡 LoRA 微调流程,目标是将 Google Gemma 4 E4B-it(指令微调版本)适配到情绪分类任务(dair-ai/emotion 数据集)。整体流程分为以下关键阶段:

1. 环境配置与依赖安装

  • 安装 modelscopetransformersdatasetstrlpeftscikit-learn 等核心库。
  • 设置全局配置:输出目录、数据量限制(训练4000/验证400/测试400)、随机种子42、BF16精度、系统提示词等。
  • 固定随机种子确保可复现性。

2. 数据准备(ModelScope 下载与本地加载)

  • 模型下载:通过 modelscope.snapshot_downloadgoogle/gemma-4-E4B-it 仓库拉取到本地 ./models 目录。
  • 数据集下载:通过 dataset_snapshot_downloadAI-ModelScope/emotion 仓库拉取到本地,然后使用 datasets.load_dataset("parquet", ...) 从本地 parquet 文件加载,并显式将 label 字段转换为 ClassLabel(避免版本兼容问题)。
  • 数据转换:将原始 (text, label) 样本转换为指令微调格式 —— prompt 包含 system 和 user 消息,completion 仅包含 assistant 的情绪标签(如 "joy")。

3. 模型与 Tokenizer 加载

  • 从本地模型路径加载 tokenizer,设置 pad_token,若 chat_template 缺失则从 ModelScope 拉取官方模板并注入。
  • 加载基础模型(AutoModelForCausalLM),指定 torch_dtype=BF16,关闭 use_cache,移至 GPU。
  • 验证 apply_chat_template 工作正常。

4. 微调前评估

  • 实现 evaluate_model:对测试集逐条生成预测,提取标签(正则匹配或首 token),计算准确率、macro F1、无效预测数,并输出分类报告和混淆矩阵。
  • 基础模型在测试集上准确率 62.5%,macro F1 0.482,部分类别(love、surprise)表现较差。

5. LoRA 配置与训练

  • LoRA 参数r=16, lora_alpha=32, dropout=0.05, target_modules="all-linear"
  • 训练参数(SFTConfig)
    • per_device_train_batch_size=4, gradient_accumulation_steps=4(有效 batch=16)
    • 学习率 1e-4,线性调度,warmup 50 步
    • 训练 1 个 epoch,max_length=256
    • 启用梯度检查点,BF16,优化器 adamw_torch(避免 bitsandbytes ROCm 问题)
  • SFTTrainer:接收模型、数据集、LoRA 配置、tokenizer,自动处理 chat template 和损失计算(completion_only_loss=True)。
  • 可训练参数约 5,050 万,占总参数 0.63%。
  • 训练过程记录 loss 和验证 loss,最终 training_loss ≈ 0.314

6. 微调后评估与保存

  • 使用训练后的 LoRA 模型再次评估测试集(指标应显著提升)。
  • 保存 adapter 权重、tokenizer、训练指标 JSON、对比 CSV、分类报告、混淆矩阵等。
  • 提供示例推理和重新加载 adapter 的代码。

7. 常见问题处理

  • 针对 ModelScope 授权、数据集 parquet 路径、ROCm 环境显存不足、全量数据使用等给出了具体建议。

二、LLM 微调过程注意事项(通用 + 本案例经验)

1. 数据质量与格式

  • 标签一致性:确保模型输出严格限定在预定义标签集合内(通过系统提示、正则后处理、completion_only_loss 等)。
  • 指令格式匹配:使用模型原生的 chat_template(如 Gemma 的 <|turn|> 格式),避免手动拼接导致模板错误。
  • 数据平衡:检查类别分布,若某些样本极少(如 love 仅 26 条),可考虑过采样或调整损失权重。
  • 数据量控制:先用小样本(如 4000 条)验证流程,再扩大至全量(1.6 万)。

2. 模型与 Tokenizer

  • Tokenizer 完整性:确认 pad_tokeneos_tokenchat_template 已正确设置,否则训练时可能产生无限生成或格式错误。
  • 模型加载
    • 明确 torch_dtype(BF16/FP16),避免默认 FP32 显存爆炸。
    • 设置 use_cache=False 以启用梯度检查点。
    • 在 ROCm 环境下避免使用 bitsandbytes 4bit 量化(兼容性差),改用纯 BF16/FP16。
  • 设备管理:单卡训练直接 .to(device),不用 device_map="auto"

3. LoRA 配置

  • target_modules
    • "all-linear" 简单但可能增加显存;可改为 ["q_proj", "v_proj"] 等关键层平衡效果与资源。
    • 验证 LoRA 参数是否真的被附加(打印可训练参数量)。
  • Rank 与 Alphar=16, alpha=32 是常见起点;若过拟合可降低 r,若欠拟合可增大 alpha
  • Dropout0.05 适合小数据集;大数据集可适当提高。

4. 训练超参数

  • Batch Size:受显存限制,单卡可用小 batch + 梯度累积。例如 per_device_batch=4, grad_accum=4 等效 batch 16。
  • 学习率:LoRA 常用 1e-45e-5;使用线性调度 + warmup 稳定训练。
  • 序列长度(max_length:情绪分类文本短,256 足够;若处理长文本需增加并注意显存。
  • Epoch 数:分类任务收敛快,1-3 轮即可;过多易过拟合(尤其小数据集)。
  • 优化器adamw_torchadamw_bnb_8bit 更稳定(尤其在非 CUDA 环境)。

5. 评估与监控

  • 训练中评估:设置 eval_strategy="steps"eval_steps,监控验证 loss 和下游指标。
  • 生成评估:注意 max_new_tokens 设小(如 4),关闭采样(do_sample=False)以获得确定性输出。
  • 后处理:用正则或集合过滤无效输出(如 "INVALID" 类),否则会拉低准确率。
  • 对比基线:始终保存微调前的评估结果,量化提升幅度。

6. 资源与显存优化

  • 梯度检查点:必须开启,以时间换显存。
  • 混合精度:BF16(若硬件支持)优于 FP16(梯度溢出风险低)。
  • packing=False:对短文本任务关闭 packing 更简单;长文本可开启提高效率。
  • 数据加载dataloader_num_workers 根据 CPU 核心数设置,避免 I/O 瓶颈。

7. 可复现性

  • 固定 Python、NumPy、PyTorch 随机种子。
  • 固定 data_seedseed 参数。
  • 记录所有超参数和库版本(可用 pip freeze)。

8. 模型保存与推理

  • 保存 LoRA adapter 而非全量权重,节省空间。后续推理时需先加载基础模型再加载 adapter。
  • 若需部署,可将 LoRA 合并到基础模型(merge_and_unload())以避免额外加载开销。
  • 推理时恢复 use_cache=True 加速生成。

9. 环境与平台特定问题

  • ModelScope 替代 HF
    • 需在网页端接受模型许可协议(如 Gemma)。
    • 数据集用 dataset_snapshot_download + 本地 parquet 加载可绕过 MsDataset 的版本兼容 bug。
  • ROCm(AMD GPU)
    • PyTorch 通过 torch.cuda 接口访问 ROCm,无需修改代码。
    • 避免使用 bitsandbytes,注意某些算子可能较慢。
  • 网络:首次运行需下载模型(~8GB)和数据集,确保稳定连接,或提前缓存。

10. 常见陷阱

  • 未设置 pad_token 导致训练时 attention mask 出错。
  • chat_template 缺失导致 prompt 格式错误,模型输出混乱。
  • LoRA 的 target_modules 未匹配到任何层(可训练参数为 0)。
  • 评估时未限制 max_new_tokens,模型生成长文本导致标签提取失败。
  • 多卡环境强制单卡时需注意 CUDA_VISIBLE_DEVICES 设置。

以上注意事项基于该 Notebook 的实践经验,可推广至大多数指令微调或分类任务的 LLM 微调场景。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐