一 LoRA介绍

LoRA 的灵感来自一个学术发现:大模型的权重矩阵其实是“低秩”的。 也就是说,虽然参数多,但模型学习新知识时,真正起作用的变化其实可以用很小的矩阵来表达。

1.1 工作原理:

假设模型里有一个原始参数矩阵 W₀(维度是 d × k):

  1. 冻结主干:微调时,W₀ 完全不动,不参与梯度更新。

  2. 旁路矩阵:在 W₀ 旁边增加两个极小的矩阵 A 和 B。

    • 矩阵 A 的维度是 d × r

    • 矩阵 B 的维度是 r × k

    • 这里的 r 就是 Rank(秩),通常设得很小(如 8, 16, 64)。

  3. 计算逻辑

    原本的输出是 h = W₀x,现在变成了:

    h = W_0x + BAx

  4. 训练阶段:只训练 A 和 B。因为 r 很小,参数量比 W₀ 小了几个数量级。

1.2 LoRA 的核心优势

  • 显存需求骤减:因为训练的参数极少,原本需要多张 A100 才能跑的模型,现在一张消费级显卡(如 RTX 4090)可能就能跑。

  • 无推理延迟:这是 LoRA 最大的优势。在推理阶段,你可以直接将训练好的 B × A 的结果加回到 W₀ 中:

W_{new} = W_0 + BA

  • 灵活切换:主模型(底座)只有一个,针对不同任务(如翻译、写代码、角色扮演)只需加载几兆到几百兆的 LoRA 权重包,实现“一主多适配器”。

二 具体实现

2.1 LoRA网络结构

根据上一节的提到的公式,我们可知,LoRA最主要的实现就是两个低秩矩阵,但是在标准的 LoRA 论文中,旁路的计算公式不仅是 $BAx$,而是 $\frac{\alpha}{r} BAx$

所以我们还需要引入一个缩放因子 alpha

实现代码如下:

class LoRA(nn.Module):
    def __init__(self, in_features, out_features, rank, alpha=16):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank  # 计算缩放因子

        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)

        # 矩阵A高斯初始化
        self.A.weight.data.normal_(mean=0.0, std=0.02)
        # 矩阵B全0初始化
        self.B.weight.data.zero_()

    def forward(self, x):
        return self.B(self.A(x)) * self.scaling

2.2 应用LoRA

主流的应用方式是通过层名称(如匹配 q_proj, v_proj)来接入LoRA,所以在代码的最开始,我们需要实现层名称的匹配。

最后因为我们只需要训练LoRA网络层,所以我们需要冻结原始层,只训练LoRA网络层。

以下是完整代码:

def apply_lora(model, rank=16, alpha=16, target_modules=None):
    target_layers = []

    # 先冻结目标列表,再做注入
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if ".lora." in name or hasattr(module, "lora"):
            continue
        if target_modules is not None and not any(target in name for target in target_modules):
            continue
        target_layers.append(module)

    for module in target_layers:
        lora = LoRA(
            module.in_features,
            module.out_features,
            rank=rank,
            alpha=alpha,
        ).to(device=module.weight.device, dtype=module.weight.dtype)
        setattr(module, "lora", lora)
        original_forward = module.forward

        # 显式绑定当前层,避免闭包引用错位。
        def forward_with_lora(x, base_forward=original_forward, lora_layer=lora):
            return base_forward(x) + lora_layer(x)

        module.forward = forward_with_lora

    for name, param in model.named_parameters():
        param.requires_grad = "lora" in name

    print(f"LoRA applied. Rank: {rank}, Alpha: {alpha}, Layers: {len(target_layers)}")

2.3 加载LoRA

这一部分主要就是从总的 state_dict 里筛出某个模块对应的 LoRA 参数,改名后,再加载到 module.lora 这个子模块里。

比如总的 state_dict 里可能有这些键:

{
    "layers.0.attn.lora.A.weight": ...,
    "layers.0.attn.lora.B.weight": ...,
}

我们通过字段匹配然后进行喜欢成适配在2.1重定义的矩阵名称,最终 lora_state 会变成:

{
    "A.weight": ...,
    "B.weight": ...
}

完整代码如下:

def load_lora(model, path):
    device = next(model.parameters()).device
    state_dict = torch.load(path, map_location=device)

    for name, module in model.named_modules():
        if not hasattr(module, "lora"):
            continue

        prefix = f"{name}.lora."
        lora_state = {}
        for key, value in state_dict.items():
            if key.startswith(prefix):
                lora_state[key.replace(prefix, "")] = value

        module.lora.load_state_dict(lora_state, strict=False)

2.4 保存LoRA

因为 LoRA 的思想就是:

  • 原模型参数冻结
  • 只训练额外插入的小矩阵参数

所以训练结束后,真正“变化”的只有 LoRA 那部分。

如果保存整个模型会有很多的缺点,比如:文件大、冗余、不方便迁移,但如果只保存 LoRA,保存的模型文件变小,加载更快,并且方便复用到同一个基座模型上。

又因为我们在前面进行加载的时候改变了原来模型的网络层名称,所以这里保存的时候我们也需要还原回来,完成代码如下:

def save_lora(model, path):
    raw_model = getattr(model, "_orig_mod", model)
    state_dict = {}

    for name, module in raw_model.named_modules():
        if not hasattr(module, "lora"):
            continue

        clean_name = name[7:] if name.startswith("module.") else name
        lora_state = {
            f"{clean_name}.lora.{key}": value.detach().cpu().half()
            for key, value in module.lora.state_dict().items()
        }
        state_dict.update(lora_state)

    torch.save(state_dict, path)

2.5 合并LoRA

这一部分总体流程:

  • 把 LoRA 参数加载进模型
  • 取出原始模型本体
  • 构造一个新的 state_dict,去掉所有 .lora. 参数
  • 对每个线性层,如果它带 LoRA,就把 LoRA 的增量加到原始 weight 上,再保存

即先把 LoRA 参数加载到模型里,再把 LoRA 对原线性层的增量权重真正“合并”到 Linear.weight 中,最后保存成一个不再依赖 LoRA 的普通模型权重文件。

也就是把:原始权重 W + LoRA增量 ΔW,合成一个新的最终权重:W_merged = W + ΔW

这样保存后的模型,推理时就不需要再额外挂 lora 模块了。

以下是完整代码:

def merge_lora(model, lora_path, save_path):
    load_lora(model, lora_path)
    raw_model = getattr(model, "_orig_mod", model)

    state_dict = {
        key: value.detach().cpu().half()
        for key, value in raw_model.state_dict().items()
        if ".lora." not in key
    }

    with torch.no_grad():
        for name, module in raw_model.named_modules():
            if not isinstance(module, nn.Linear) or ".lora." in name:
                continue

            merged_weight = module.weight.detach().float().cpu().clone()
            if hasattr(module, "lora"):
                delta = (
                    module.lora.B.weight.detach().float()
                    @ module.lora.A.weight.detach().float()
                ) * float(module.lora.scaling)
                merged_weight += delta.cpu()

            state_dict[f"{name}.weight"] = merged_weight.half()

    torch.save(state_dict, save_path)

2.6 训练代码

整体的训练代码与预训练训练代码其实差距不大,最主要的就是我们不再进行全参数的训练,需要冻结那些不需要进行lora微调的网络层,并且数据适配器不再使用PretrainDataset,而是是使用SFTDataset数据适配器,完整代码如下:

import os
import sys

__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import time
import warnings
from contextlib import nullcontext

import torch
from torch import optim
from torch.utils.data import DataLoader

from dataset.dataloader import SFTDataset
from model.model_lora import apply_lora, save_lora
from model.model_pocketllm import PocketLLMConfig
from trainer.trainer_utils import Logger, SkipBatchSampler, get_lr, init_model, lm_checkpoint, setup_seed

warnings.filterwarnings("ignore")

CHECKPOINT_DIR = "../checkpoints"


def resolve_device(device: str) -> str:
    if device.startswith("cuda") and not torch.cuda.is_available():
        Logger("CUDA 不可用,已回退到 CPU")
        return "cpu"
    if device == "cuda" and torch.cuda.is_available():
        return "cuda:0"
    return device


def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
    start_time = time.time()
    last_step = start_step

    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        last_step = step

        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        with autocast_ctx:
            res = model(input_ids, labels=labels)
            loss = res.loss + res.aux_loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if step % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0 or step == iters:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
            current_logits_loss = current_loss - current_aux_loss
            current_lr = optimizer.param_groups[-1]["lr"]
            eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
            Logger(
                f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
                f"loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, "
                f"aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min"
            )
            if wandb:
                wandb.log(
                    {
                        "loss": current_loss,
                        "logits_loss": current_logits_loss,
                        "aux_loss": current_aux_loss,
                        "learning_rate": current_lr,
                        "epoch_time": eta_min,
                    }
                )

        if step % args.save_interval == 0 or step == iters:
            model.eval()
            moe_suffix = "_moe" if lm_config.use_moe else ""
            # 保存 LoRA 权重
            lora_save_path = f"{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}{moe_suffix}.pth"
            save_lora(model, lora_save_path)
            lm_checkpoint(
                lm_config,
                weight=args.lora_name,
                model=model,
                optimizer=optimizer,
                scaler=scaler,
                epoch=epoch,
                step=step,
                wandb=wandb,
                save_dir=CHECKPOINT_DIR,
            )
            model.train()

        del input_ids, labels, res, loss

    if last_step > start_step and last_step % args.accumulation_steps != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PocketLLM LoRA 单卡微调")
    parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
    parser.add_argument("--lora_name", type=str, default="lora_medical", help="LoRA 权重名称")
    parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=32, help="批大小")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
    parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
    parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
    parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
    parser.add_argument("--hidden_size", default=768, type=int, help="隐藏层维度")
    parser.add_argument("--num_hidden_layers", default=8, type=int, help="隐藏层数量")
    parser.add_argument("--max_seq_len", default=340, type=int, help="训练截断长度")
    parser.add_argument("--use_moe", default=0, type=int, choices=[0, 1], help="是否使用 MoE 结构")
    parser.add_argument(
        "--data_path",
        type=str,
        default="../dataset/data/lora_medical.jsonl",
        help="LoRA 训练数据路径",
    )
    parser.add_argument("--from_weight", default="full_sft", type=str, help="基座权重名称")
    parser.add_argument("--from_resume", default=0, type=int, choices=[0, 1], help="是否自动续训")
    parser.add_argument("--use_wandb", action="store_true", help="是否启用实验日志平台")
    parser.add_argument("--wandb_project", type=str, default="PocketLLM-LoRA", help="实验项目名称")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否启用 torch.compile")
    args = parser.parse_args()

    # ========== 1. 初始化环境和随机种子 ==========
    args.device = resolve_device(args.device)
    setup_seed(42)
    os.makedirs(args.save_dir, exist_ok=True)

    # ========== 2. 配置模型参数、检查ckp ==========
    lm_config = PocketLLMConfig(
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layers,
        use_moe=bool(args.use_moe),
    )
    ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir=CHECKPOINT_DIR) if args.from_resume == 1 else None

    # ========== 3. 设置混合精度 ==========
    device_type = "cuda" if args.device.startswith("cuda") else "cpu"
    amp_dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=amp_dtype)

    # ========== 4. 配置实验日志平台 ==========
    wandb = None
    if args.use_wandb:
        import swanlab as wandb

        wandb_id = ckp_data.get("wandb_id") if ckp_data else None
        resume = "must" if wandb_id else None
        wandb_run_name = (
            f"PocketLLM-LoRA-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
        )
        wandb.init(
            project=args.wandb_project,
            name=wandb_run_name,
            id=wandb_id,
            resume=resume,
        )

    # ========== 5. 定义模型、数据、优化器 ==========
    model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
    apply_lora(model)

    total_params = sum(p.numel() for p in model.parameters())
    lora_params = [param for name, param in model.named_parameters() if "lora" in name]
    lora_params_count = sum(param.numel() for param in lora_params)

    if not lora_params:
        raise RuntimeError("没有检测到可训练的 LoRA 参数,请检查 apply_lora 是否成功注入。")

    Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
    Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
    Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")

    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    scaler = torch.amp.GradScaler(device_type, enabled=(args.dtype == "float16"))
    optimizer = optim.AdamW(lora_params, lr=args.learning_rate)

    # ========== 6. 从ckp恢复状态 ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data["model"])
        if ckp_data.get("optimizer") is not None:
            optimizer.load_state_dict(ckp_data["optimizer"])
        if ckp_data.get("scaler") is not None:
            scaler.load_state_dict(ckp_data["scaler"])
        start_epoch = ckp_data["epoch"]
        start_step = ckp_data.get("step", 0)

    # ========== 7. 可选编译模型 ==========
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger("已启用 torch.compile")

    # ========== 8. 开始训练 ==========
    for epoch in range(start_epoch, args.epochs):
        setup_seed(42 + epoch)
        indices = torch.randperm(len(train_ds)).tolist()
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(indices, args.batch_size, skip)
        loader = DataLoader(
            train_ds,
            batch_sampler=batch_sampler,
            num_workers=args.num_workers,
            pin_memory=(device_type == "cuda"),
        )

        if skip > 0:
            Logger(f"Epoch [{epoch + 1}/{args.epochs}]:跳过前 {start_step} 个 step,从 step {start_step + 1} 开始")
            train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
        else:
            train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐