从0完成轻量级大模型全链路训练与对齐框架——Lora微调
一 LoRA介绍
LoRA 的灵感来自一个学术发现:大模型的权重矩阵其实是“低秩”的。 也就是说,虽然参数多,但模型学习新知识时,真正起作用的变化其实可以用很小的矩阵来表达。
1.1 工作原理:
假设模型里有一个原始参数矩阵 W₀(维度是 d × k):
-
冻结主干:微调时,W₀ 完全不动,不参与梯度更新。
-
旁路矩阵:在 W₀ 旁边增加两个极小的矩阵 A 和 B。
-
矩阵 A 的维度是 d × r
-
矩阵 B 的维度是 r × k
-
这里的 r 就是 Rank(秩),通常设得很小(如 8, 16, 64)。
-
-
计算逻辑:
原本的输出是 h = W₀x,现在变成了:
-
训练阶段:只训练 A 和 B。因为 r 很小,参数量比 W₀ 小了几个数量级。
1.2 LoRA 的核心优势
-
显存需求骤减:因为训练的参数极少,原本需要多张 A100 才能跑的模型,现在一张消费级显卡(如 RTX 4090)可能就能跑。
-
无推理延迟:这是 LoRA 最大的优势。在推理阶段,你可以直接将训练好的 B × A 的结果加回到 W₀ 中:
-
灵活切换:主模型(底座)只有一个,针对不同任务(如翻译、写代码、角色扮演)只需加载几兆到几百兆的 LoRA 权重包,实现“一主多适配器”。
二 具体实现
2.1 LoRA网络结构
根据上一节的提到的公式,我们可知,LoRA最主要的实现就是两个低秩矩阵,但是在标准的 LoRA 论文中,旁路的计算公式不仅是 ,而是
。
所以我们还需要引入一个缩放因子 alpha
实现代码如下:
class LoRA(nn.Module):
def __init__(self, in_features, out_features, rank, alpha=16):
super().__init__()
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank # 计算缩放因子
self.A = nn.Linear(in_features, rank, bias=False)
self.B = nn.Linear(rank, out_features, bias=False)
# 矩阵A高斯初始化
self.A.weight.data.normal_(mean=0.0, std=0.02)
# 矩阵B全0初始化
self.B.weight.data.zero_()
def forward(self, x):
return self.B(self.A(x)) * self.scaling
2.2 应用LoRA
主流的应用方式是通过层名称(如匹配 q_proj, v_proj)来接入LoRA,所以在代码的最开始,我们需要实现层名称的匹配。
最后因为我们只需要训练LoRA网络层,所以我们需要冻结原始层,只训练LoRA网络层。
以下是完整代码:
def apply_lora(model, rank=16, alpha=16, target_modules=None):
target_layers = []
# 先冻结目标列表,再做注入
for name, module in list(model.named_modules()):
if not isinstance(module, nn.Linear):
continue
if ".lora." in name or hasattr(module, "lora"):
continue
if target_modules is not None and not any(target in name for target in target_modules):
continue
target_layers.append(module)
for module in target_layers:
lora = LoRA(
module.in_features,
module.out_features,
rank=rank,
alpha=alpha,
).to(device=module.weight.device, dtype=module.weight.dtype)
setattr(module, "lora", lora)
original_forward = module.forward
# 显式绑定当前层,避免闭包引用错位。
def forward_with_lora(x, base_forward=original_forward, lora_layer=lora):
return base_forward(x) + lora_layer(x)
module.forward = forward_with_lora
for name, param in model.named_parameters():
param.requires_grad = "lora" in name
print(f"LoRA applied. Rank: {rank}, Alpha: {alpha}, Layers: {len(target_layers)}")
2.3 加载LoRA
这一部分主要就是从总的 state_dict 里筛出某个模块对应的 LoRA 参数,改名后,再加载到 module.lora 这个子模块里。
比如总的 state_dict 里可能有这些键:
{
"layers.0.attn.lora.A.weight": ...,
"layers.0.attn.lora.B.weight": ...,
}
我们通过字段匹配然后进行喜欢成适配在2.1重定义的矩阵名称,最终 lora_state 会变成:
{
"A.weight": ...,
"B.weight": ...
}
完整代码如下:
def load_lora(model, path):
device = next(model.parameters()).device
state_dict = torch.load(path, map_location=device)
for name, module in model.named_modules():
if not hasattr(module, "lora"):
continue
prefix = f"{name}.lora."
lora_state = {}
for key, value in state_dict.items():
if key.startswith(prefix):
lora_state[key.replace(prefix, "")] = value
module.lora.load_state_dict(lora_state, strict=False)
2.4 保存LoRA
因为 LoRA 的思想就是:
- 原模型参数冻结
- 只训练额外插入的小矩阵参数
所以训练结束后,真正“变化”的只有 LoRA 那部分。
如果保存整个模型会有很多的缺点,比如:文件大、冗余、不方便迁移,但如果只保存 LoRA,保存的模型文件变小,加载更快,并且方便复用到同一个基座模型上。
又因为我们在前面进行加载的时候改变了原来模型的网络层名称,所以这里保存的时候我们也需要还原回来,完成代码如下:
def save_lora(model, path):
raw_model = getattr(model, "_orig_mod", model)
state_dict = {}
for name, module in raw_model.named_modules():
if not hasattr(module, "lora"):
continue
clean_name = name[7:] if name.startswith("module.") else name
lora_state = {
f"{clean_name}.lora.{key}": value.detach().cpu().half()
for key, value in module.lora.state_dict().items()
}
state_dict.update(lora_state)
torch.save(state_dict, path)
2.5 合并LoRA
这一部分总体流程:
- 把 LoRA 参数加载进模型
- 取出原始模型本体
- 构造一个新的
state_dict,去掉所有.lora.参数 - 对每个线性层,如果它带 LoRA,就把 LoRA 的增量加到原始
weight上,再保存
即先把 LoRA 参数加载到模型里,再把 LoRA 对原线性层的增量权重真正“合并”到 Linear.weight 中,最后保存成一个不再依赖 LoRA 的普通模型权重文件。
也就是把:原始权重 W + LoRA增量 ΔW,合成一个新的最终权重:W_merged = W + ΔW
这样保存后的模型,推理时就不需要再额外挂 lora 模块了。
以下是完整代码:
def merge_lora(model, lora_path, save_path):
load_lora(model, lora_path)
raw_model = getattr(model, "_orig_mod", model)
state_dict = {
key: value.detach().cpu().half()
for key, value in raw_model.state_dict().items()
if ".lora." not in key
}
with torch.no_grad():
for name, module in raw_model.named_modules():
if not isinstance(module, nn.Linear) or ".lora." in name:
continue
merged_weight = module.weight.detach().float().cpu().clone()
if hasattr(module, "lora"):
delta = (
module.lora.B.weight.detach().float()
@ module.lora.A.weight.detach().float()
) * float(module.lora.scaling)
merged_weight += delta.cpu()
state_dict[f"{name}.weight"] = merged_weight.half()
torch.save(state_dict, save_path)
2.6 训练代码
整体的训练代码与预训练训练代码其实差距不大,最主要的就是我们不再进行全参数的训练,需要冻结那些不需要进行lora微调的网络层,并且数据适配器不再使用PretrainDataset,而是是使用SFTDataset数据适配器,完整代码如下:
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import argparse
import time
import warnings
from contextlib import nullcontext
import torch
from torch import optim
from torch.utils.data import DataLoader
from dataset.dataloader import SFTDataset
from model.model_lora import apply_lora, save_lora
from model.model_pocketllm import PocketLLMConfig
from trainer.trainer_utils import Logger, SkipBatchSampler, get_lr, init_model, lm_checkpoint, setup_seed
warnings.filterwarnings("ignore")
CHECKPOINT_DIR = "../checkpoints"
def resolve_device(device: str) -> str:
if device.startswith("cuda") and not torch.cuda.is_available():
Logger("CUDA 不可用,已回退到 CPU")
return "cpu"
if device == "cuda" and torch.cuda.is_available():
return "cuda:0"
return device
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]["lr"]
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
Logger(
f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
f"loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, "
f"aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min"
)
if wandb:
wandb.log(
{
"loss": current_loss,
"logits_loss": current_logits_loss,
"aux_loss": current_aux_loss,
"learning_rate": current_lr,
"epoch_time": eta_min,
}
)
if step % args.save_interval == 0 or step == iters:
model.eval()
moe_suffix = "_moe" if lm_config.use_moe else ""
# 保存 LoRA 权重
lora_save_path = f"{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}{moe_suffix}.pth"
save_lora(model, lora_save_path)
lm_checkpoint(
lm_config,
weight=args.lora_name,
model=model,
optimizer=optimizer,
scaler=scaler,
epoch=epoch,
step=step,
wandb=wandb,
save_dir=CHECKPOINT_DIR,
)
model.train()
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="PocketLLM LoRA 单卡微调")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument("--lora_name", type=str, default="lora_medical", help="LoRA 权重名称")
parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32, help="批大小")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument("--hidden_size", default=768, type=int, help="隐藏层维度")
parser.add_argument("--num_hidden_layers", default=8, type=int, help="隐藏层数量")
parser.add_argument("--max_seq_len", default=340, type=int, help="训练截断长度")
parser.add_argument("--use_moe", default=0, type=int, choices=[0, 1], help="是否使用 MoE 结构")
parser.add_argument(
"--data_path",
type=str,
default="../dataset/data/lora_medical.jsonl",
help="LoRA 训练数据路径",
)
parser.add_argument("--from_weight", default="full_sft", type=str, help="基座权重名称")
parser.add_argument("--from_resume", default=0, type=int, choices=[0, 1], help="是否自动续训")
parser.add_argument("--use_wandb", action="store_true", help="是否启用实验日志平台")
parser.add_argument("--wandb_project", type=str, default="PocketLLM-LoRA", help="实验项目名称")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否启用 torch.compile")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
args.device = resolve_device(args.device)
setup_seed(42)
os.makedirs(args.save_dir, exist_ok=True)
# ========== 2. 配置模型参数、检查ckp ==========
lm_config = PocketLLMConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
use_moe=bool(args.use_moe),
)
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir=CHECKPOINT_DIR) if args.from_resume == 1 else None
# ========== 3. 设置混合精度 ==========
device_type = "cuda" if args.device.startswith("cuda") else "cpu"
amp_dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=amp_dtype)
# ========== 4. 配置实验日志平台 ==========
wandb = None
if args.use_wandb:
import swanlab as wandb
wandb_id = ckp_data.get("wandb_id") if ckp_data else None
resume = "must" if wandb_id else None
wandb_run_name = (
f"PocketLLM-LoRA-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
)
wandb.init(
project=args.wandb_project,
name=wandb_run_name,
id=wandb_id,
resume=resume,
)
# ========== 5. 定义模型、数据、优化器 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
apply_lora(model)
total_params = sum(p.numel() for p in model.parameters())
lora_params = [param for name, param in model.named_parameters() if "lora" in name]
lora_params_count = sum(param.numel() for param in lora_params)
if not lora_params:
raise RuntimeError("没有检测到可训练的 LoRA 参数,请检查 apply_lora 是否成功注入。")
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
scaler = torch.amp.GradScaler(device_type, enabled=(args.dtype == "float16"))
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
# ========== 6. 从ckp恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data["model"])
if ckp_data.get("optimizer") is not None:
optimizer.load_state_dict(ckp_data["optimizer"])
if ckp_data.get("scaler") is not None:
scaler.load_state_dict(ckp_data["scaler"])
start_epoch = ckp_data["epoch"]
start_step = ckp_data.get("step", 0)
# ========== 7. 可选编译模型 ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger("已启用 torch.compile")
# ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):
setup_seed(42 + epoch)
indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(indices, args.batch_size, skip)
loader = DataLoader(
train_ds,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=(device_type == "cuda"),
)
if skip > 0:
Logger(f"Epoch [{epoch + 1}/{args.epochs}]:跳过前 {start_step} 个 step,从 step {start_step + 1} 开始")
train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)