AMD Hello-ROCm 学习活动笔记-任务二初体验微调
题记
在Datawhale微信公众号上看到6月有关大模型专题学习,其中AMD中文教程:Hello-ROCm,正好近来空闲时间比较多,可以利用这个机会,深入实操有关大模型 的微调相关的事宜,深入理解数据处理、训练、效果对比。
体验AI训练师工作
本次的任务是将把一个只会通用聊天的模型,通过特定领域的数据训练成能够精准识别6 种人类情绪的“情感分析专家” 。这个把通用模型改造成特定领域专家的过程,就叫 模型微调(Fine-Tuning) 。
真实的 AI 微调涉及复杂的代码、环境配置和硬件调优,即便是计算机专业的同学也可能被各种报错劝退。
为了让你专注于 理解微调核心逻辑 ,我们已经将所有复杂操作打包优化,准备了“一键顺滑运行”的实验报告(Notebook)。
任务简介
本次的任务在 AI-ModelScope/emotion上单卡微调Gemma 4 E4B-it(魔搭 ModelScope 版本)。由于都是已经实现写好的代码,只需要按部就班的执行这些代码,体验一把AI微调工程师的工作。根据个人体验将整个体验总结如下
一、AI微调过程总结
该 Jupyter Notebook 完整实现了一个基于 ModelScope 的单卡 LoRA 微调流程,目标是将 Google Gemma 4 E4B-it(指令微调版本)适配到情绪分类任务(dair-ai/emotion 数据集)。整体流程分为以下关键阶段:
1. 环境配置与依赖安装
- 安装
modelscope、transformers、datasets、trl、peft、scikit-learn等核心库。 - 设置全局配置:输出目录、数据量限制(训练4000/验证400/测试400)、随机种子42、BF16精度、系统提示词等。
- 固定随机种子确保可复现性。
2. 数据准备(ModelScope 下载与本地加载)
- 模型下载:通过
modelscope.snapshot_download将google/gemma-4-E4B-it仓库拉取到本地./models目录。 - 数据集下载:通过
dataset_snapshot_download将AI-ModelScope/emotion仓库拉取到本地,然后使用datasets.load_dataset("parquet", ...)从本地 parquet 文件加载,并显式将label字段转换为ClassLabel(避免版本兼容问题)。 - 数据转换:将原始
(text, label)样本转换为指令微调格式 ——prompt包含 system 和 user 消息,completion仅包含 assistant 的情绪标签(如"joy")。
3. 模型与 Tokenizer 加载
- 从本地模型路径加载 tokenizer,设置
pad_token,若chat_template缺失则从 ModelScope 拉取官方模板并注入。 - 加载基础模型(
AutoModelForCausalLM),指定torch_dtype=BF16,关闭use_cache,移至 GPU。 - 验证
apply_chat_template工作正常。
4. 微调前评估
- 实现
evaluate_model:对测试集逐条生成预测,提取标签(正则匹配或首 token),计算准确率、macro F1、无效预测数,并输出分类报告和混淆矩阵。 - 基础模型在测试集上准确率 62.5%,macro F1 0.482,部分类别(love、surprise)表现较差。
5. LoRA 配置与训练
- LoRA 参数:
r=16, lora_alpha=32, dropout=0.05, target_modules="all-linear"。 - 训练参数(SFTConfig):
per_device_train_batch_size=4, gradient_accumulation_steps=4(有效 batch=16)- 学习率
1e-4,线性调度,warmup 50 步 - 训练 1 个 epoch,
max_length=256 - 启用梯度检查点,BF16,优化器
adamw_torch(避免 bitsandbytes ROCm 问题)
- SFTTrainer:接收模型、数据集、LoRA 配置、tokenizer,自动处理 chat template 和损失计算(
completion_only_loss=True)。 - 可训练参数约 5,050 万,占总参数 0.63%。
- 训练过程记录 loss 和验证 loss,最终
training_loss ≈ 0.314。
6. 微调后评估与保存
- 使用训练后的 LoRA 模型再次评估测试集(指标应显著提升)。
- 保存 adapter 权重、tokenizer、训练指标 JSON、对比 CSV、分类报告、混淆矩阵等。
- 提供示例推理和重新加载 adapter 的代码。
7. 常见问题处理
- 针对 ModelScope 授权、数据集 parquet 路径、ROCm 环境显存不足、全量数据使用等给出了具体建议。
二、LLM 微调过程注意事项(通用 + 本案例经验)
1. 数据质量与格式
- 标签一致性:确保模型输出严格限定在预定义标签集合内(通过系统提示、正则后处理、
completion_only_loss等)。 - 指令格式匹配:使用模型原生的
chat_template(如 Gemma 的<|turn|>格式),避免手动拼接导致模板错误。 - 数据平衡:检查类别分布,若某些样本极少(如
love仅 26 条),可考虑过采样或调整损失权重。 - 数据量控制:先用小样本(如 4000 条)验证流程,再扩大至全量(1.6 万)。
2. 模型与 Tokenizer
- Tokenizer 完整性:确认
pad_token、eos_token、chat_template已正确设置,否则训练时可能产生无限生成或格式错误。 - 模型加载:
- 明确
torch_dtype(BF16/FP16),避免默认 FP32 显存爆炸。 - 设置
use_cache=False以启用梯度检查点。 - 在 ROCm 环境下避免使用
bitsandbytes4bit 量化(兼容性差),改用纯 BF16/FP16。
- 明确
- 设备管理:单卡训练直接
.to(device),不用device_map="auto"。
3. LoRA 配置
target_modules:"all-linear"简单但可能增加显存;可改为["q_proj", "v_proj"]等关键层平衡效果与资源。- 验证 LoRA 参数是否真的被附加(打印可训练参数量)。
- Rank 与 Alpha:
r=16, alpha=32是常见起点;若过拟合可降低r,若欠拟合可增大alpha。 - Dropout:
0.05适合小数据集;大数据集可适当提高。
4. 训练超参数
- Batch Size:受显存限制,单卡可用小 batch + 梯度累积。例如
per_device_batch=4, grad_accum=4等效 batch 16。 - 学习率:LoRA 常用
1e-4到5e-5;使用线性调度 + warmup 稳定训练。 - 序列长度(
max_length):情绪分类文本短,256 足够;若处理长文本需增加并注意显存。 - Epoch 数:分类任务收敛快,1-3 轮即可;过多易过拟合(尤其小数据集)。
- 优化器:
adamw_torch比adamw_bnb_8bit更稳定(尤其在非 CUDA 环境)。
5. 评估与监控
- 训练中评估:设置
eval_strategy="steps"和eval_steps,监控验证 loss 和下游指标。 - 生成评估:注意
max_new_tokens设小(如 4),关闭采样(do_sample=False)以获得确定性输出。 - 后处理:用正则或集合过滤无效输出(如
"INVALID"类),否则会拉低准确率。 - 对比基线:始终保存微调前的评估结果,量化提升幅度。
6. 资源与显存优化
- 梯度检查点:必须开启,以时间换显存。
- 混合精度:BF16(若硬件支持)优于 FP16(梯度溢出风险低)。
packing=False:对短文本任务关闭 packing 更简单;长文本可开启提高效率。- 数据加载:
dataloader_num_workers根据 CPU 核心数设置,避免 I/O 瓶颈。
7. 可复现性
- 固定 Python、NumPy、PyTorch 随机种子。
- 固定
data_seed和seed参数。 - 记录所有超参数和库版本(可用
pip freeze)。
8. 模型保存与推理
- 保存 LoRA adapter 而非全量权重,节省空间。后续推理时需先加载基础模型再加载 adapter。
- 若需部署,可将 LoRA 合并到基础模型(
merge_and_unload())以避免额外加载开销。 - 推理时恢复
use_cache=True加速生成。
9. 环境与平台特定问题
- ModelScope 替代 HF:
- 需在网页端接受模型许可协议(如 Gemma)。
- 数据集用
dataset_snapshot_download+ 本地 parquet 加载可绕过MsDataset的版本兼容 bug。
- ROCm(AMD GPU):
- PyTorch 通过
torch.cuda接口访问 ROCm,无需修改代码。 - 避免使用
bitsandbytes,注意某些算子可能较慢。
- PyTorch 通过
- 网络:首次运行需下载模型(~8GB)和数据集,确保稳定连接,或提前缓存。
10. 常见陷阱
- 未设置
pad_token导致训练时 attention mask 出错。 chat_template缺失导致 prompt 格式错误,模型输出混乱。- LoRA 的
target_modules未匹配到任何层(可训练参数为 0)。 - 评估时未限制
max_new_tokens,模型生成长文本导致标签提取失败。 - 多卡环境强制单卡时需注意
CUDA_VISIBLE_DEVICES设置。
以上注意事项基于该 Notebook 的实践经验,可推广至大多数指令微调或分类任务的 LLM 微调场景。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)