从零手搓大语言模型:LoRA 参数高效微调篇
本文为"从零手搓大语言模型"系列第 4 篇,介绍 LoRA 的原理、实现细节与应用场景。
一、LoRA 的动机
SFT 后的模型是一个通用对话模型。若需适配特定领域(医疗、法律、客服等),全参数微调面临两个问题:
- 计算成本高:需更新全部参数
- 灾难性遗忘:容易覆盖通用能力
LoRA(Low-Rank Adaptation)的解决方案:冻结原模型全部参数,仅训练额外插入的极少量参数。
二、核心原理
2.1 低秩分解
原始线性层:y = Wx(W 为 768×768 矩阵,589,824 个参数)
LoRA 在旁边插入一条低秩旁路:
y = Wx + B(A(x))
↑ ↑
原始 LoRA旁路
(冻结)(可训练)
其中:
- A:768 → 16(降维矩阵,12,288 个参数)
- B:16 → 768(升维矩阵,12,288 个参数)
- 合计 24,576 个参数,为原矩阵的 4.2%
2.2 初始化策略
self.A.weight.data.normal_(mean=0.0, std=0.02) # A 正态随机初始化
self.B.weight.data.zero_() # B 零初始化
B 初始化为零意味着训练开始时 LoRA 输出恒为零,模型行为与原模型完全一致。训练过程中 B 逐步偏离零值,渐进式引入新能力。
三、代码实现
3.1 LoRA 模块定义
class LoRA(nn.Module):
def __init__(self, in_features, out_features, rank=16):
self.A = nn.Linear(in_features, rank, bias=False)
self.B = nn.Linear(rank, out_features, bias=False)
def forward(self, x):
return self.B(self.A(x))
3.2 动态插入 LoRA
def apply_lora(model, rank=16):
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and module.in_features == module.out_features:
lora = LoRA(module.in_features, module.out_features, rank)
setattr(module, "lora", lora)
original_forward = module.forward
def forward_with_lora(x, layer1=original_forward, layer2=lora):
return layer1(x) + layer2(x)
module.forward = forward_with_lora
通过 Python 运行时方法替换(monkey patching),在不修改模型源码的前提下将 LoRA 分支注入每个目标层。
3.3 冻结原模型
for name, param in model.named_parameters():
if 'lora' in name:
param.requires_grad = True # LoRA 参数可训练
else:
param.requires_grad = False # 原模型参数冻结
requires_grad = False 使 PyTorch 在反向传播时跳过该参数的梯度计算,节省计算和显存。
四、训练配置对比
| 对比项 | Full SFT | LoRA |
|---|---|---|
| 可训练参数 | 64M(100%) | ~1M(1.5%) |
| 起始权重 | pretrain | full_sft |
| 学习率 | 1e-5 | 1e-4 |
| epochs | 2 | 10 |
| 训练数据量 | 1.6GB | 数 KB ~ 数 MB |
| 保存文件大小 | ~131MB | ~几 MB |
LoRA 学习率更大(1e-4 vs 1e-5),因为可训练参数少、梯度噪声低,可以使用更大步幅。训练轮数更多(10 epochs),因为数据量极小,需要多次遍历充分学习。
五、LoRA 权重的使用
5.1 推理时叠加
python eval_llm.py --weight full_sft --lora_weight lora_medical
加载基础模型后,将 LoRA 权重叠加到对应层上。一个基础模型可搭配多个 LoRA 权重服务不同场景。
5.2 合并导出
def merge_lora(model, lora_path, save_path):
# W_merged = W_original + B @ A
state_dict[f'{name}.weight'] += (module.lora.B.weight @ module.lora.A.weight)
torch.save(state_dict, save_path)
将 LoRA 权重永久合并回基础模型,导出为新的完整权重,推理时无需额外加载。
六、适用场景
| 场景 | 示例 |
|---|---|
| 垂直领域适配 | 医疗问答、法律咨询、金融分析 |
| 自我认知定制 | 让模型回答"我是 XX 公司的 AI 助手" |
| 风格调整 | 更正式/更口语化/特定语言风格 |
| 多租户服务 | 一个基础模型 + N 个 LoRA 服务 N 个客户 |
七、小结
LoRA 的本质是用低秩矩阵近似全参数更新。它以极小的参数开销(< 2%)实现领域适配,保留基础模型的通用能力。这种"外挂补丁"的思路是当前生产环境中最常用的微调方案。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)