目录

1.DPO训练流程

1.1 初始化

策略模型πθ​的初始化

参考模型πref​的初始化

2.DPO训练迭代(推理→计算→反向传播)

2.1 模型推理

2.2 计算隐式奖励值

2.3 计算DPO损失函数

2.4 反向传播与参数更新

3. DPO训练程序


       在大语言模型的人类偏好对齐阶段,核心目标是让模型生成的内容符合人类的价值判断、使用习惯与需求偏好。而传统的RLHF需历经“监督微调(SFT)→奖励模型(RM)训练→强化学习迭代”三步,流程繁琐且存在训练不稳定、样本利用率低、奖励偏移等问题。而DPO的核心创新在于直接以偏好数据为监督信号,通过对比优质回答与劣质回答的分布差异,优化策略模型的参数,无需额外训练奖励模型,也无需复杂的强化学习交互迭代,大幅降低了对齐训练的复杂度与成本。DPO训练的核心是最小化策略模型与“最优参考分布”的偏离:以SFT模型为基础构建参考模型,将策略模型的输出分布向“偏好数据中优质回答的分布靠拢,同时抑制劣质回答的分布,最终让策略模型的输出更贴合用户偏好。

1.DPO训练流程

DPO的训练流程如下图所示:

1.1 初始化

       DPO训练涉及策略模型与参考模型两个核心模型,二者的初始化是训练的基础,直接影响训练的收敛速度与最终效果。DPO的模型初始化无需从零开始,而是基于已有模型进行迁移,这也是大语言模型微调的核心思路(迁移学习)。

策略模型πθ​的初始化

       策略模型是DPO训练中需要被优化的核心模型,其最终目标是生成符合用户偏好的内容。初始化方式为:

基础方案:直接复制监督微调(SFT)模型的参数。SFT模型是先在通用指令数据上微调后的模型,已具备基础的对话、推理能力,以此为基础初始化策略模型,可让模型快速收敛,避免从零训练的高成本;

进阶方案:若有性能更优的兼容模型(如经过多轮SFT的模型、行业定制模型),可直接复用该模型作为策略模型的初始参数。核心要求是:初始策略模型需具备基础的语言生成能力,且与参考模型的分布差异不宜过大(否则会增加训练难度)。

参考模型πref​的初始化

       参考模型是DPO训练中的“锚点”,其核心作用是提供一个固定的参考分布,用于衡量策略模型的优化方向。DPO训练的本质是让策略模型的输出分布向参考模型的优质分布靠拢,同时与参考模型的劣质分布拉开差距,因此参考模型的初始化需满足“稳定性、兼容性”要求:

基础方案:与策略模型一致,复制SFT模型的参数。此时策略模型与参考模型初始参数相同,训练初期二者的分布差异极小,便于模型逐步学习偏好;
进阶方案:选择一个比SFT模型更强的模型作为参考模型(如经过更充分SFT的模型、在特定领域优化过的模型)。但此时需重点关注两个模型的KL 散度(Kullback-Leibler Divergence)与训练数据分布的匹配性:

KL散度:衡量两个概率分布的差异,若策略模型与参考模型的KL散度过大,会导致训练初期损失函数波动过大,难以收敛;因此需通过微调策略模型参数,缩小二者的KL散度,保证训练稳定性;

数据分布匹配:参考模型的训练数据分布需与DPO的偏好数据分布尽量一致(如参考模型是对话领域模型,DPO数据也以对话为主),否则参考模型的“参考价值”会降低,导致策略模型学习到错误的偏好规律。

2.DPO训练迭代(推理→计算→反向传播)

       DPO的核心训练过程是“单次推理→计算损失→反向传播”的迭代循环,针对每一批次的偏好数据,模型需完成 “策略模型推理、参考模型推理、损失计算、参数更新” 四个步骤,这一过程是DPO训练的核心技术环节,也是其与传统监督微调的核心区别。其流程图如下所示:

2.1 模型推理

      针对预处理后的批次数据,需同时输入策略模型(πθ)与参考模型(πref),分别对 “优质回答序列” 与“劣质回答序列”进行推理,得到模型输出的Logits(未归一化的概率值),再通过Softmax与 LogSoftmax转换为动作概率(即模型生成每个Token的概率)。具体推理流程如下:

2.2 计算隐式奖励值

       DPO没有显式的奖励模型,而是通过策略模型与参考模型的概率比值计算隐式奖励值,量化“优质回答优于劣质回答”的程度,这是DPO对比优化的核心逻辑。隐式奖励的计算公式为:

其中:

2.3 计算DPO损失函数

      DPO的损失函数是对比损失,核心目标是最大化优质回答的隐式奖励,最小化劣质回答的隐式奖励,即让策略模型的输出分布向优质回答靠拢,同时远离劣质回答。DPO损失函数的公式为:

其中:

实际实现中,为提升计算效率,会对批次内的损失取均值,即:

2.4 反向传播与参数更新

       损失函数计算完成后,需通过反向传播计算策略模型参数的梯度,再通过优化器更新参数,完成一次训练迭代。这一步是模型学习偏好规律的核心环节,与传统监督微调的反向传播逻辑一致,但优化目标是DPO的对比损失而非交叉熵损失。

3. DPO训练程序

偏好数据集定义:

class DPOPreferenceDataset(Dataset):
    """
    DPO偏好数据集:存储 (prompt, win_response, lose_response) 三元组
    """
    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.data[idx]
        prompt = sample["prompt"]
        win_response = sample["win_response"]
        lose_response = sample["lose_response"]
        
        # Token化:prompt + 优质回答 / prompt + 劣质回答
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        win_tokens = self.tokenizer.encode(win_response, add_special_tokens=False)
        lose_tokens = self.tokenizer.encode(lose_response, add_special_tokens=False)
        
        # 拼接输入序列
        input_win = prompt_tokens + win_tokens + [self.tokenizer.eos_token_id]
        input_lose = prompt_tokens + lose_tokens + [self.tokenizer.eos_token_id]
        
        # 截断到最大长度
        input_win = input_win[:config.max_seq_len]
        input_lose = input_lose[:config.max_seq_len]
        
        # 生成attention mask(1表示有效token,0表示padding)
        mask_win = [1] * len(input_win)
        mask_lose = [1] * len(input_lose)
        
        # 填充到最大长度
        pad_len_win = config.max_seq_len - len(input_win)
        pad_len_lose = config.max_seq_len - len(input_lose)
        input_win += [self.tokenizer.pad_token_id] * pad_len_win
        mask_win += [0] * pad_len_win
        input_lose += [self.tokenizer.pad_token_id] * pad_len_lose
        mask_lose += [0] * pad_len_lose
        
        # 标签:仅回答部分参与损失计算(prompt部分mask掉)
        label_win = [-100] * len(prompt_tokens) + win_tokens + [self.tokenizer.eos_token_id]
        label_lose = [-100] * len(prompt_tokens) + lose_tokens + [self.tokenizer.eos_token_id]
        label_win = label_win[:config.max_seq_len] + [-100] * pad_len_win
        label_lose = label_lose[:config.max_seq_len] + [-100] * pad_len_lose
        
        return {
            "input_win": torch.tensor(input_win, dtype=torch.long),
            "mask_win": torch.tensor(mask_win, dtype=torch.long),
            "label_win": torch.tensor(label_win, dtype=torch.long),
            "input_lose": torch.tensor(input_lose, dtype=torch.long),
            "mask_lose": torch.tensor(mask_lose, dtype=torch.long),
            "label_lose": torch.tensor(label_lose, dtype=torch.long),
            "prompt_len": len(prompt_tokens)  # 记录prompt长度,用于Gather操作
        }

动作概率(logprob)计算函数:

def get_logprob(
    model: PreTrainedModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
    prompt_len: int
) -> torch.Tensor:
    """
    计算模型在给定输入下,标签序列的对数概率(logprob)
    :param model: 策略模型或参考模型
    :param input_ids: 输入token序列
    :param attention_mask: attention掩码
    :param labels: 标签序列(-100表示忽略的位置)
    :param prompt_len: prompt部分长度,用于定位回答部分
    :return: 每个样本的平均logprob(形状:[batch_size])
    """
    model.eval()  # 推理模式,不更新参数
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )
        logits = outputs.logits  # [batch_size, seq_len, vocab_size]
    
    # 计算log_softmax:稳定获取对数概率
    log_probs = torch.log_softmax(logits, dim=-1)
    
    # 提取标签对应的logprob(Gather操作)
    # 移位对齐:logits的第i个位置预测第i+1个token
    shift_log_probs = log_probs[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    
    # 只计算回答部分的logprob(prompt部分mask为-100)
    loss_mask = (shift_labels != -100).float()
    selected_log_probs = shift_log_probs.gather(
        dim=-1,
        index=shift_labels.unsqueeze(-1)
    ).squeeze(-1)
    
    # 按样本平均logprob
    per_sample_logprob = (selected_log_probs * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1).clamp(min=1e-8)
    return per_sample_logprob

DPO损失函数实现

class DPOLoss(nn.Module):
    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(
        self,
        policy_win_logprob: torch.Tensor,
        policy_lose_logprob: torch.Tensor,
        ref_win_logprob: torch.Tensor,
        ref_lose_logprob: torch.Tensor
    ) -> torch.Tensor:
        """
        计算DPO损失
        :param policy_win_logprob: 策略模型在优质回答上的logprob
        :param policy_lose_logprob: 策略模型在劣质回答上的logprob
        :param ref_win_logprob: 参考模型在优质回答上的logprob
        :param ref_lose_logprob: 参考模型在劣质回答上的logprob
        :return: DPO损失值
        """
        # 计算隐式奖励
        win_rewards = self.beta * (policy_win_logprob - ref_win_logprob)
        lose_rewards = self.beta * (policy_lose_logprob - ref_lose_logprob)
        
        # 计算DPO损失:-log(sigmoid(win_rewards - lose_rewards))
        loss = -torch.log(torch.sigmoid(win_rewards - lose_rewards)).mean()
        return loss

DPO训练器封装

class DPOTrainer:
    def __init__(
        self,
        policy_model: PreTrainedModel,
        ref_model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        config: DPOConfig,
        train_dataset: DPOPreferenceDataset
    ):
        self.policy_model = policy_model.to(config.device)
        self.ref_model = ref_model.to(config.device)
        self.ref_model.eval()  # 参考模型固定,不参与训练
        self.tokenizer = tokenizer
        self.config = config
        self.train_dataset = train_dataset
        self.loss_fn = DPOLoss(beta=config.beta)
        self.optimizer = optim.AdamW(self.policy_model.parameters(), lr=config.lr)
        self.dataloader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4
        )

    def train_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        单步训练:处理一个批次数据,计算损失并更新参数
        """
        # 移动数据到设备
        input_win = batch["input_win"].to(self.config.device)
        mask_win = batch["mask_win"].to(self.config.device)
        label_win = batch["label_win"].to(self.config.device)
        input_lose = batch["input_lose"].to(self.config.device)
        mask_lose = batch["mask_lose"].to(self.config.device)
        label_lose = batch["label_lose"].to(self.config.device)
        prompt_len = batch["prompt_len"]

        # 1. 计算策略模型的logprob
        policy_win_logprob = get_logprob(
            self.policy_model, input_win, mask_win, label_win, prompt_len
        )
        policy_lose_logprob = get_logprob(
            self.policy_model, input_lose, mask_lose, label_lose, prompt_len
        )

        # 2. 计算参考模型的logprob(固定参数,无梯度)
        ref_win_logprob = get_logprob(
            self.ref_model, input_win, mask_win, label_win, prompt_len
        )
        ref_lose_logprob = get_logprob(
            self.ref_model, input_lose, mask_lose, label_lose, prompt_len
        )

        # 3. 计算DPO损失
        loss = self.loss_fn(
            policy_win_logprob, policy_lose_logprob,
            ref_win_logprob, ref_lose_logprob
        )

        # 4. 反向传播更新策略模型参数
        self.optimizer.zero_grad()
        loss.backward()
        # 梯度裁剪,防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return loss.item()

    def train(self):
        """
        完整训练循环
        """
        self.policy_model.train()
        for epoch in range(self.config.num_epochs):
            total_loss = 0.0
            for batch_idx, batch in enumerate(self.dataloader):
                loss = self.train_step(batch)
                total_loss += loss

                if (batch_idx + 1) % 10 == 0:
                    print(f"Epoch [{epoch+1}/{self.config.num_epochs}], "
                          f"Batch [{batch_idx+1}/{len(self.dataloader)}], "
                          f"Loss: {loss:.4f}")
            
            avg_loss = total_loss / len(self.dataloader)
            print(f"Epoch [{epoch+1}/{self.config.num_epochs}], Average Loss: {avg_loss:.4f}")

        # 保存训练后的策略模型
        self.policy_model.save_pretrained("dpo_trained_model")
        self.tokenizer.save_pretrained("dpo_trained_model")
        print("训练完成,模型已保存至 dpo_trained_model 目录")

主程序入口(调用示例)

def main():
    # 1. 加载tokenizer和基础SFT模型
    tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token  # 设定pad token
    sft_model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path)

    # 2. 初始化策略模型和参考模型(均复制SFT模型)
    policy_model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path)
    ref_model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path)

    # 3. 构造示例偏好数据(实际使用时替换为真实数据集)
    demo_data = [
        {
            "prompt": "虎皮青椒用什么做?",
            "win_response": "青椒、大蒜、生抽、盐等食材,先将青椒煎至表皮起皱,再调味翻炒。",
            "lose_response": "老虎和辣椒,把老虎和辣椒一起炒就成了虎皮青椒。"
        },
        {
            "prompt": "Python中如何遍历字典?",
            "win_response": "可以用for key, value in dict.items()遍历键值对,或for key in dict.keys()遍历键。",
            "lose_response": "用for循环直接写for i in dict就可以遍历所有元素。"
        }
    ]

    # 4. 初始化数据集和训练器
    train_dataset = DPOPreferenceDataset(demo_data, tokenizer)
    trainer = DPOTrainer(
        policy_model=policy_model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        config=config,
        train_dataset=train_dataset
    )

    # 5. 启动训练
    trainer.train()

if __name__ == "__main__":
    main()

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐