本文为"从零手搓大语言模型"系列第 4 篇,介绍 LoRA 的原理、实现细节与应用场景。

一、LoRA 的动机

SFT 后的模型是一个通用对话模型。若需适配特定领域(医疗、法律、客服等),全参数微调面临两个问题:

  1. 计算成本高:需更新全部参数
  2. 灾难性遗忘:容易覆盖通用能力

LoRA(Low-Rank Adaptation)的解决方案:冻结原模型全部参数,仅训练额外插入的极少量参数。

二、核心原理

2.1 低秩分解

原始线性层:y = Wx(W 为 768×768 矩阵,589,824 个参数)

LoRA 在旁边插入一条低秩旁路:

y = Wx + B(A(x))
     ↑     ↑
   原始   LoRA旁路
  (冻结)(可训练)

其中:

  • A:768 → 16(降维矩阵,12,288 个参数)
  • B:16 → 768(升维矩阵,12,288 个参数)
  • 合计 24,576 个参数,为原矩阵的 4.2%

2.2 初始化策略

self.A.weight.data.normal_(mean=0.0, std=0.02)  # A 正态随机初始化
self.B.weight.data.zero_()                       # B 零初始化

B 初始化为零意味着训练开始时 LoRA 输出恒为零,模型行为与原模型完全一致。训练过程中 B 逐步偏离零值,渐进式引入新能力。

三、代码实现

3.1 LoRA 模块定义

class LoRA(nn.Module):
    def __init__(self, in_features, out_features, rank=16):
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        return self.B(self.A(x))

3.2 动态插入 LoRA

def apply_lora(model, rank=16):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and module.in_features == module.out_features:
            lora = LoRA(module.in_features, module.out_features, rank)
            setattr(module, "lora", lora)
            original_forward = module.forward
            
            def forward_with_lora(x, layer1=original_forward, layer2=lora):
                return layer1(x) + layer2(x)
            
            module.forward = forward_with_lora

通过 Python 运行时方法替换(monkey patching),在不修改模型源码的前提下将 LoRA 分支注入每个目标层。

3.3 冻结原模型

for name, param in model.named_parameters():
    if 'lora' in name:
        param.requires_grad = True   # LoRA 参数可训练
    else:
        param.requires_grad = False  # 原模型参数冻结

requires_grad = False 使 PyTorch 在反向传播时跳过该参数的梯度计算,节省计算和显存。

四、训练配置对比

对比项 Full SFT LoRA
可训练参数 64M(100%) ~1M(1.5%)
起始权重 pretrain full_sft
学习率 1e-5 1e-4
epochs 2 10
训练数据量 1.6GB 数 KB ~ 数 MB
保存文件大小 ~131MB ~几 MB

LoRA 学习率更大(1e-4 vs 1e-5),因为可训练参数少、梯度噪声低,可以使用更大步幅。训练轮数更多(10 epochs),因为数据量极小,需要多次遍历充分学习。

五、LoRA 权重的使用

5.1 推理时叠加

python eval_llm.py --weight full_sft --lora_weight lora_medical

加载基础模型后,将 LoRA 权重叠加到对应层上。一个基础模型可搭配多个 LoRA 权重服务不同场景。

5.2 合并导出

def merge_lora(model, lora_path, save_path):
    # W_merged = W_original + B @ A
    state_dict[f'{name}.weight'] += (module.lora.B.weight @ module.lora.A.weight)
    torch.save(state_dict, save_path)

将 LoRA 权重永久合并回基础模型,导出为新的完整权重,推理时无需额外加载。

六、适用场景

场景 示例
垂直领域适配 医疗问答、法律咨询、金融分析
自我认知定制 让模型回答"我是 XX 公司的 AI 助手"
风格调整 更正式/更口语化/特定语言风格
多租户服务 一个基础模型 + N 个 LoRA 服务 N 个客户

七、小结

LoRA 的本质是用低秩矩阵近似全参数更新。它以极小的参数开销(< 2%)实现领域适配,保留基础模型的通用能力。这种"外挂补丁"的思路是当前生产环境中最常用的微调方案。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐