【强化学习系列(三)】基于模型的强化学习 (Model-Based RL)


一、基于模型 vs 无模型:本质区别

在强化学习中,智能体学习如何行动的核心思路有两条路径:

  • 无模型 (Model-Free):智能体直接在环境中试错,通过与环境的交互来学习价值函数或策略。不知道环境长什么样
  • 基于模型 (Model-Based):智能体先学习一个环境模型(即状态转移和奖励函数),然后基于这个模型来进行规划或决策。知道环境长什么样

用一张图来概括:

┌──────────────────────────────────────────────────────────────────────┐
│                        强化学习的两条路线                               │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│   【无模型 Model-Free】              【基于模型 Model-Based】          │
│                                                                      │
│   ┌─────────┐  action   ┌────────┐    ┌─────────┐  action   ┌────────┐│
│   │  智能体  │ ───────► │ 真实环境 │    │  智能体  │ ───────► │  世界模型 ││
│   │ Agent   │           │  Real   │    │ Agent   │           │  World  ││
│   │         │ ◄─────── │  Env    │    │         │           │  Model  ││
│   └─────────┘  observe └────────┘    └─────────┘ ◄─────── └────────┘│
│                            ▲                     │              ▲    │
│                            │                     │              │    │
│                            └─────────────────────┘              │    │
│                                 真实交互                      想象规划  │
└──────────────────────────────────────────────────────────────────────┘

核心对比表

对比维度 无模型强化学习 (Model-Free) 基于模型强化学习 (Model-Based)
核心思想 通过真实交互试错学习 先学环境模型,再基于模型规划
环境模型 ❌ 不学习,视为黑盒 ✅ 学习状态转移 $P(s’
样本效率 低(需要大量真实交互) 高(可以在模型中"想象"交互)
规划能力 有限,只能靠价值函数近似 强,可进行前向搜索、轨迹优化
模型误差 不存在(无模型) 存在"模型误差"(model error)导致性能瓶颈
收敛稳定性 相对稳定,但方差大 存在误差累积问题,训练复杂
典型算法 Q-Learning, DQN, Policy Gradient, PPO Dyna-Q, MCTS, Dreamer, MuZero
代表应用 Atari游戏, 机器人控制 AlphaGo, AlphaZero, 围棋/象棋/将棋
计算成本 交互成本高,推理成本低 交互成本低,但规划/规划计算成本高
泛化能力 泛化到未见状态较难 模型可解释性稍好,但泛化仍困难

直观类比

🎯 无模型 vs 基于模型的形象比喻

想象你要学会游泳:

  • 无模型:直接跳进游泳池呛水,呛多了自然就会了(试错学习
  • 基于模型:先在岸上观看教学视频、想象水流、模拟动作要领,建立一个心理模型,然后再下水(模型辅助学习

二、学习环境模型:状态转移与奖励函数

2.1 什么是环境模型?

环境模型是对环境动态特性的数学描述,包含两个核心组件:

┌─────────────────────────────────────────────────────────────┐
│                    环境模型 Environment Model                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   1️⃣ 状态转移函数 (Transition Function)                       │
│       P(s'|s, a) = Pr(下一状态 | 当前状态, 采取的动作)          │
│                                                             │
│       记忆口诀:【状态转移:给了首尾,定中间】                   │
│       s(当前) + a(动作) ──► s'(下一状态)                      │
│                                                             │
│   2️⃣ 奖励函数 (Reward Function)                              │
│       R(s, a) 或 R(s, a, s') = 智能体获得的即时奖励             │
│                                                             │
│       记忆口诀:【奖励函数:做了啥事,给多少分】                  │
│                                                             │
│   组合起来:环境模型 M = (P, R)                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2.2 如何学习环境模型?

学习环境模型本质上是一个监督学习问题:

┌──────────────────────────────────────────────────────────────┐
│                   学习环境模型 = 监督学习                        │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│   训练数据来自真实环境交互:                                     │
│                                                              │
│   收集样本: (s, a, r, s')  四元组                             │
│            │   │   │    │                                   │
│            ▼   ▼   ▼    ▼                                   │
│   ┌─────────────────────────────────────────┐               │
│   │  输入: (s, a)  ──►  输出: (s', r)         │               │
│   │                                         │               │
│   │  回归问题: 预测下一个状态 s'              │               │
│   │  回归问题: 预测奖励 r                    │               │
│   └─────────────────────────────────────────┘               │
│                      │                                        │
│                      ▼                                        │
│   可以使用神经网络进行函数逼近                                   │
│   s' = f_θ(s, a)     r = g_θ(s, a)                          │
│                                                              │
└──────────────────────────────────────────────────────────────┘

2.3 模型学习的具体方法

方法一:表格型方法(适用于离散、小状态空间)
# 伪代码:基于计数的环境模型学习
class TableModel:
    def __init__(self, n_states, n_actions):
        # N(s, a, s') = 从(s,a)转移到s'的计数
        self.N = np.zeros((n_states, n_actions, n_states))
        # R(s, a) = 在(s,a)获得的平均奖励
        self.R = np.zeros((n_states, n_actions))
        # N(s, a) = 访问(s,a)的总计数
        self.N_sa = np.zeros((n_states, n_actions))
    
    def update(self, s, a, r, s_next):
        self.N[s, a, s_next] += 1
        self.R[s, a] += r  # 累计,可后续取平均
        self.N_sa[s, a] += 1
    
    def get_transition_prob(self, s, a):
        """P(s'|s,a) = N(s,a,s') / N(s,a)"""
        if self.N_sa[s, a] == 0:
            return np.zeros(self.N.shape[2])
        return self.N[s, a] / self.N_sa[s, a]
    
    def get_reward(self, s, a):
        """R(s,a) = 累计奖励 / 访问次数"""
        if self.N_sa[s, a] == 0:
            return 0.0
        return self.R[s, a] / self.N_sa[s, a]
方法二:神经网络方法(适用于大规模状态空间)
# 伪代码:基于神经网络的环境模型
import torch
import torch.nn as nn

class WorldModel(nn.Module):
    """
    世界模型:输入(s,a),输出(s', r)
    也称为 Transition Model 或 Dynamics Model
    """
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        
        # 状态转移网络: (s,a) -> s'
        self.transition_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim)  # 输出下一状态
        )
        
        # 奖励网络: (s,a) -> r
        self.reward_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # 输出标量奖励
        )
    
    def forward(self, s, a):
        """
        Args:
            s: 当前状态 (batch_size, state_dim)
            a: 动作 (batch_size, action_dim)
        Returns:
            s_next: 预测的下一状态
            r: 预测的奖励
        """
        sa = torch.cat([s, a], dim=-1)  # 拼接
        s_next = self.transition_net(sa)
        r = self.reward_net(sa)
        return s_next, r
    
    def predict(self, s, a):
        """单步预测"""
        with torch.no_grad():
            s_tensor = torch.FloatTensor(s).unsqueeze(0)
            a_tensor = torch.FloatTensor(a).unsqueeze(0)
            s_next, r = self.forward(s_tensor, a_tensor)
            return s_next.numpy().squeeze(0), r.numpy().squeeze(0)

2.4 模型学习的挑战

┌─────────────────────────────────────────────────────────────────┐
│                     模型学习的三大挑战                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  1️⃣ 误差累积 (Error Accumulation)                                │
│     ┌──────────────────────────────────────────┐                │
│     │ 真实: s0 ──► s1 ──► s2 ──► s3 ──► ...   │                │
│     │ 想象: s0 ──►ŝ1 ──►ŝ2 ──►ŝ3 ──► ...       │                │
│     │           ↑       ↑       ↑              │                │
│     │         误差1   累积误差2  更大误差3      │                │
│     └──────────────────────────────────────────┘                │
│     记忆口诀:【模型误差像滚雪球,越滚越大】                         │
│                                                                  │
│  2️⃣ 复合误差 (Compounding Errors)                               │
│     多步想象后误差呈指数级增长,需要-reward shaping 或 short horizon │
│                                                                  │
│  3️⃣ 分布偏移 (Distribution Shift)                               │
│     模型只在「见过的状态-动作对」上训练,                            │
│     规划可能探索未见区域,模型外推能力差                            │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

三、Monte Carlo Tree Search (MCTS):四步详解

3.1 MCTS是什么?

蒙特卡洛树搜索是一种基于模型的规划算法,核心思想是:

通过随机模拟来评估决策路径,用树搜索来高效地分配计算资源。

MCTS 不需要遍历整个搜索空间,而是通过迭代地"探索"和"利用"来找到最优路径。

3.2 MCTS四步详解(ASCII流程图)

┌─────────────────────────────────────────────────────────────────────────────┐
│                    MCTS 算法四步循环 (每回合一次)                              │
│                                                                             │
│                        ┌─────────────────────────────────────┐              │
│                        │          🔄  重复 N 次              │              │
│                        └─────────────────────────────────────┘              │
│                                         │                                    │
│                    ┌───────────────────┼───────────────────┐                │
│                    ▼                   ▼                   ▼                │
│            ┌───────────────┐   ┌───────────────┐   ┌───────────────┐        │
│            │  1️⃣ SELECTION │   │ 2️⃣ EXPANSION │   │3️⃣ SIMULATION│        │
│            │   选择        │   │   扩展        │   │   模拟        │        │
│            └───────┬───────┘   └───────┬───────┘   └───────┬───────┘        │
│                    │                   │                   │                │
│                    │    ┌──────────────┘                   │                │
│                    │    │                                  │                │
│                    │    ▼                                  │                │
│                    │  ┌─────────────────────────────┐      │                │
│                    │  │   从叶子节点随机选择         │      │                │
│                    │  │   一个未访问的子节点扩展      │      │                │
│                    │  └─────────────────────────────┘      │                │
│                    │                                      │                │
│                    │    ┌──────────────────────────────────┘                │
│                    │    │                                                  │
│                    │    ▼                                                  │
│                    │  ┌─────────────────────────────┐                       │
│                    └──│  从扩展节点开始, rollout    │◄── 随机策略            │
│                       │  直到游戏结束               │                       │
│                       └─────────────┬─────────────┘                       │
│                                     │                                       │
│                                     ▼                                       │
│                       ┌─────────────────────────────┐                       │
│                       │ 4️⃣ BACKPROPAGATION        │                       │
│                       │    反向传播                 │                       │
│                       │    更新路径上所有节点的     │                       │
│                       │    Q值/访问次数            │                       │
│                       └─────────────────────────────┘                       │
│                                                                     [完]    │
└─────────────────────────────────────────────────────────────────────────────┘

3.3 每一步详细解析

步骤1:SELECTION — 选择(从根到叶子)
目标:从根节点开始,使用树策略 (Tree Policy) 选择一个"最有前途"的叶子节点

核心机制:UCB1 (Upper Confidence Bound 1) 上置信界公式

         ┌─────────────────────────────────────────┐
         │                                         │
         │    UCB1 = Q(s) + C × √(ln N(parent)    │
         │                         ────────────)   │
         │                            N(child)    │
         │                                         │
         │    第一项 Q(s)   → 利用项(越高越好)      │
         │    第二项        → 探索项(越高→鼓励探索) │
         │    C             → 探索常数(通常=√2)   │
         │                                         │
         └─────────────────────────────────────────┘

选择过程可视化:
         
         根节点(Root)
            │
    ┌───────┼───────┐
    │       │       │
   [A]    [B]     [C]      ← 选择 UCB 值最高的子节点
   0.8    0.6    0.9       如果某节点还未被访问过,
    │      │      │        UCB = ∞(强制探索)
   [A1]   [B1]    [C1]
                    │
                   ...

记忆口诀:【选节点看UCB,既看收益又看探索】
步骤2:EXPANSION — 扩展
目标:向选中的叶子节点添加一个或多个子节点(代表可能的下一状态)

扩展策略:
  - 添加一个随机未访问的子节点(最常用)
  - 添加所有未访问的子节点(批量扩展)

扩展可视化:

  扩展前:                   扩展后:

       叶子节点L                  叶子节点L
          │                         │
       [A1]                    [A1]
                              /  │  \
                             ▼   ▼   ▼
                           [A2][A3][A4]  ← 新增节点(未访问,Q=0)
步骤3:SIMULATION — 模拟(Rollout / Playout)
目标:从扩展节点开始,使用默认策略(通常是随机策略)一直模拟到游戏结束

Rollout 策略(从简单到复杂):
  ① 随机 rollout:每个时间步随机选择动作
  ② 轻量级神经网络:使用简化的策略网络
  ③ 固定深度搜索 + 启发式评估

模拟可视化:

  扩展节点 ──随机动作──► s1 ──随机动作──► s2 ──...──► Terminal
       │                │                │
      [r=0]           [r=0]            [r=+1]
                                         │
                                      游戏结束!
                                    返回奖励 +1
步骤4:BACKPROPAGATION — 反向传播
目标:将模拟获得的回报反向传播,更新路径上所有节点的统计信息

更新规则:
  N(node)  += 1                              # 访问次数 +1
  Q(node)  =  Q(node) + (G - Q(node)) / N(node)  # 更新Q值(增量均值)
  
  其中 G = 从该节点到游戏结束获得的累计奖励

反向传播可视化:

       模拟返回 G=+1
            │
            ▼
  ┌─────────────────┐
  │   Terminal State│
  └────────┬────────┘
           │ Backprop
           ▼
  ┌─────────────────┐
  │   节点 s2        │  N+=1, Q更新
  └────────┬────────┘
           │ Backprop
           ▼
  ┌─────────────────┐
  │   节点 s1        │  N+=1, Q更新
  └────────┬────────┘
           │ Backprop
           ▼
  ┌─────────────────┐
  │   节点 s0 (根)   │  N+=1, Q更新
  └─────────────────┘

3.4 MCTS完整算法流程

"""
Monte Carlo Tree Search 伪代码
"""
import math
import random

class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state          # 节点对应的状态
        self.parent = parent        # 父节点
        self.action = action        # 从父节点到达此节点的动作
        self.children = {}          # 子节点字典 {action: child_node}
        self.N = 0                  # 访问次数
        self.Q = 0.0                # 平均Q值(累计奖励均值)
    
    def is_fully_expanded(self):
        """检查是否所有子节点都已扩展"""
        return len(self.children) == len(self.state.get_legal_actions())
    
    def ucb1(self, c=math.sqrt(2)):
        """UCB1公式:平衡探索与利用"""
        if self.N == 0:
            return float('inf')  # 未访问节点优先级最高
        if self.parent is None:
            return float('inf')  # 根节点
        return self.Q + c * math.sqrt(math.log(self.parent.N) / self.N)


def mcts(root_state, num_simulations=1000, c=math.sqrt(2)):
    """
    MCTS 主算法
    
    Args:
        root_state: 初始状态
        num_simulations: 模拟次数
        c: UCB探索常数
    
    Returns:
        选择根节点的最优子动作
    """
    root = MCTSNode(state=root_state)
    
    for _ in range(num_simulations):
        node = root
        path = [node]  # 记录搜索路径,用于反向传播
        
        # ========== 步骤1: SELECTION ==========
        # 从根到叶子,选择UCB值最高的子节点
        while node.is_fully_expanded() and node.children:
            best_child = max(node.children.values(), key=lambda n: n.ucb1(c))
            node = best_child
            path.append(node)
        
        # ========== 步骤2: EXPANSION ==========
        # 如果节点未完全扩展,则扩展一个子节点
        legal_actions = node.state.get_legal_actions()
        unexpanded_actions = [a for a in legal_actions if a not in node.children]
        
        if unexpanded_actions:
            action = random.choice(unexpanded_actions)
            next_state = node.state.take_action(action)
            child_node = MCTSNode(state=next_state, parent=node, action=action)
            node.children[action] = child_node
            node = child_node
            path.append(node)
        
        # ========== 步骤3: SIMULATION ==========
        # 从当前节点 rollout 到游戏结束
        rollout_state = node.state.copy()
        while not rollout_state.is_terminal():
            action = random.choice(rollout_state.get_legal_actions())
            rollout_state = rollout_state.take_action(action)
        
        reward = rollout_state.get_reward()  # 假设为1=赢,0=输
        
        # ========== 步骤4: BACKPROPAGATION ==========
        # 反向传播更新路径上所有节点
        for n in path:
            n.N += 1
            n.Q += (reward - n.Q) / n.N  # 增量更新Q值
    
    # 返回根节点下访问次数最多的子动作
    best_action = max(root.children.keys(), 
                      key=lambda a: root.children[a].N)
    return best_action


# ========== UCB1 公式详解 ==========
"""
UCB1 (Upper Confidence Bound 1) 公式:

    UCB1 = Q̄(s,a) + C × √(ln(N(parent)) / N(s,a))

其中:
  - Q̄(s,a): 子节点的平均奖励(利用项)
  - C: 探索常数,通常取 √2 ≈ 1.414
  - N(parent): 父节点的访问次数
  - N(s,a): 当前节点的访问次数

数学性质:
  ① 当节点访问次数少时,第二项(探索项)大,鼓励探索
  ② 当节点访问次数多时,第一项(利用项)主导,选择高收益节点
  ③ UCB1 在理论上保证:遗憾值有上界

记忆口诀:【UCB三部分,收益加探索除以次数】
"""

3.5 MCTS vs 传统树搜索

对比维度 传统 Minimax 搜索 MCTS (蒙特卡洛树搜索)
搜索深度 完整深度(指数爆炸) 迭代限制,可中断
评估函数 需要手工评估函数 通过随机模拟自动评估
分支因子 高时效率低 通过UCB自适应剪枝
适用场景 完美信息、确定性游戏 不完美信息、随机性游戏
计算预算 固定深度 固定迭代次数(任意停止)

四、AlphaGo 中的 MCTS 应用

4.1 AlphaGo 系统架构

AlphaGo 是 MCTS 在围棋领域取得突破性成功的代表作,它将 MCTS 与深度学习完美结合:

┌─────────────────────────────────────────────────────────────────────────┐
│                        AlphaGo 系统架构                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   ┌─────────────────────────────────────────────────────────────────┐    │
│   │                    策略网络 Policy Network                      │    │
│   │   输入: 当前棋盘状态 s                                           │    │
│   │   输出: 每个动作 a 的概率分布 p(a|s)                               │    │
│   │   训练: 监督学习(从人类棋谱)+ 强化学习(自我对弈)                  │    │
│   └─────────────────────────────────────────────────────────────────┘    │
│                              │                                           │
│                              ▼                                           │
│   ┌─────────────────────────────────────────────────────────────────┐    │
│   │                    价值网络 Value Network                        │    │
│   │   输入: 当前棋盘状态 s                                           │    │
│   │   输出: 该状态的胜率 v(s) ∈ [0, 1]                               │    │
│   │   训练: 回归问题,从大量自我对弈样本中学习                           │    │
│   └─────────────────────────────────────────────────────────────────┘    │
│                              │                                           │
│                              ▼                                           │
│   ┌─────────────────────────────────────────────────────────────────┐    │
│   │                    MCTS + 深度学习融合                             │    │
│   │                                                                  │    │
│   │   ① SELECTION: 使用 PUCT 公式(UCB的围棋变体)                    │    │
│   │      PUCT = Q(s,a) + U(s,a)                                      │    │
│   │      其中 U(s,a) ∝ P(s,a) / (1 + N(s,a))                         │    │
│   │      —— 策略网络的先验概率引导探索方向!                          │    │
│   │                                                                  │    │
│   │   ② EXPANSION: 扩展节点时,用策略网络计算先验概率                   │    │
│   │                                                                  │    │
│   │   ③ SIMULATION: 用价值网络评估叶子节点,不再完全随机 rollout        │    │
│   │                                                                  │    │
│   │   ④ BACKPROP: 用价值网络输出更新Q值                               │    │
│   │                                                                  │    │
│   └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

4.2 AlphaGo 的 MCTS 改进点

┌─────────────────────────────────────────────────────────────────────────┐
│                   AlphaGo 对标准 MCTS 的三大改进                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  改进1️⃣:策略网络先验 (Prior from Policy Network)                         │
│                                                                          │
│    标准 MCTS: UCB 探索项是 uniform 的(无信息)                            │
│    AlphaGo:    U(s,a) ∝ P(a|s) / (1 + N(s,a))                           │
│                ↑ 策略网络提供先验概率引导探索                              │
│                                                                          │
│    效果: 原本需要数万次模拟才能找到的好棋,                                 │
│          现在可能只需数百次——因为策略网络"懂"围棋!                        │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  改进2️⃣:价值网络评估替代随机 Rollout                                     │
│                                                                          │
│    标准 MCTS: 从叶子节点随机 rollout 到游戏结束(可能数千步)                │
│    AlphaGo:    直接用价值网络 v(s') 评估叶子节点                          │
│                                                                          │
│    效果: 消除随机 rollout 的高方差,评估更稳定                              │
│                                                                          │
│  ─────────────────────────────────────────────────────────────────────  │
│                                                                          │
│  改进3️⃣:PUCT 公式替代 UCB1                                              │
│                                                                          │
│    UCB1:   Q + C×√(ln(N_parent)/N_child)                               │
│    PUCT:   Q + P(a|s) × √(N_parent) / (1 + N_child)                    │
│             ↑ 策略先验概率                                                │
│                                                                          │
│    记忆口诀:【AlphaGo有策略网络,PUCT取代UCB1】                          │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

4.3 AlphaGo 的训练流程

"""
AlphaGo 训练流程概述
"""

# ┌──────────────────────────────────────────────────────────────────┐
# │                    第一阶段:监督学习                              │
# │         用人类围棋高手的棋谱训练策略网络                            │
# │         (类似模仿学习,但数据量巨大)                                │
# └──────────────────────────────────────────────────────────────────┘
#
#   人类棋谱 → (state, action) 对 → 训练策略网络 p(a|s)
#                                         ↓
#                              正确率可达 57%(远超随机的0.1%)
#
# ┌──────────────────────────────────────────────────────────────────┐
# │                    第二阶段:强化学习                              │
# │         用策略梯度方法,让策略网络自我对弈提升                       │
# │         对手是当前策略网络 vs 历史最强版本                          │
# └──────────────────────────────────────────────────────────────────┘
#
#   当前策略网络 π → 与历史最强版本对弈 → 胜负奖励 → 策略优化
#                                         ↓
#                              进一步提升,超越监督学习版本
#
# ┌──────────────────────────────────────────────────────────────────┐
# │                    第三阶段:训练价值网络                           │
# │         用自我对弈的棋局数据,训练状态→胜率映射                      │
# └──────────────────────────────────────────────────────────────────┘
#
#   大量棋谱 (state → 最终胜/负) → 训练 v(s) 回归网络
#
# ┌──────────────────────────────────────────────────────────────────┐
# │                    第四阶段:MCTS + 深度网络                       │
# │         在线规划时,策略网络+价值网络+MCTS 协同工作                  │
# └──────────────────────────────────────────────────────────────────┘

4.4 AlphaGo vs AlphaGo Zero vs AlphaZero

版本 人类棋谱 特征 自我对弈
AlphaGo ✅ 需要 策略网络 + 价值网络 + MCTS
AlphaGo Zero ❌ 不需要 单一残差网络,同时输出策略和价值
AlphaZero ❌ 不需要 AlphaGo Zero架构,通用化(围棋/象棋/将棋)

五、Dreamer算法:世界模型 + 在想象中学习

5.1 Dreamer 的核心思想

Dreamer (Hafner et al., 2020) 是一个典型的基于模型的强化学习算法,核心理念是:

🎯 “在想象中学习” (Learning in Imagination)

智能体首先学习一个世界模型(World Model),然后在模型内部进行大量的强化学习学习,完全不需要与真实环境交互!

┌─────────────────────────────────────────────────────────────────────────┐
│                    Dreamer: 在想象中学习                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   真实环境交互(少量)              世界模型内部(大量)                    │
│                                                                          │
│   ┌────────┐                    ┌──────────────────────────┐           │
│   │ 真实   │ ──少量交互──► │     世界模型 (World Model)   │           │
│   │ 环境   │   收集经验     │  ┌────────┐  ┌───────────┐ │           │
│   └────────┘                │  │ 动态模型 │  │ 奖励模型   │ │           │
│                             │  │ p(s'|s,a)│  │ p(r|s,a)   │ │           │
│                             │  └────────┘  └───────────┘ │           │
│                             └──────────────────────────┘           │
│                                          │                             │
│                                          ▼                             │
│                             ┌──────────────────────────┐             │
│                             │    想象轨迹生成           │             │
│                             │  s0 → a0 → r0 → s1 → ...  │             │
│                             │    在模型中生成 10000+ 条  │             │
│                             │    想象轨迹                │             │
│                             └──────────────────────────┘             │
│                                          │                             │
│                                          ▼                             │
│                             ┌──────────────────────────┐             │
│                             │    演员-评论家学习         │             │
│                             │  使用想象轨迹训练策略和价值  │             │
│                             └──────────────────────────┘             │
│                                                                          │
│   记忆口诀:【Dreamer三步走:建模→想象→学习】                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

5.2 Dreamer 算法流程图

┌─────────────────────────────────────────────────────────────────────────────┐
│                           Dreamer 算法完整流程                                │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌──────────────────────────────────────────────────────────────────────┐  │
│  │                     阶段1:与世界模型交互                               │  │
│  │                                                                      │  │
│  │     真实环境                                                      │  │
│  │   ┌─────────┐  a0  ┌─────────┐  a1  ┌─────────┐                     │  │
│  │   │ s0=观察 │ ───► │ s1=观察 │ ───► │ s2=观察 │ ───► ...            │  │
│  │   └─────────┘      └─────────┘      └─────────┘                     │  │
│  │        │                                                     │      │  │
│  │        │ ε0                                                   │      │  │
│  │        ▼                                                     ▼      │  │
│  │   ┌─────────┐                                           ┌──────────┐ │  │
│  │   │ a0=动作 │                                           │经验缓冲池│ │  │
│  │   └────┬────┘                                           │   D      │ │  │
│  │        │                                                   └────┬───┘ │  │
│  │        └──────────────────────────────────────────────────────────┘   │  │
│  └──────────────────────────────────────────────────────────────────────┘  │
│                                     │                                        │
│                                     ▼                                        │
│  ┌──────────────────────────────────────────────────────────────────────┐  │
│  │                     阶段2:训练世界模型                                │  │
│  │                                                                      │  │
│  │   从经验缓冲池 D 采样批量数据: (s_t, a_t, r_t, s_{t+1})             │  │
│  │                                                                      │  │
│  │   ┌─────────────────────────────────────────────────────┐           │  │
│  │   │               世界模型 (World Model)                 │           │  │
│  │   │                                                      │           │  │
│  │   │   编码器: s_t → z_t (潜在状态)                        │           │  │
│  │   │     ↓                                                │           │  │
│  │   │   隐式动态模型: p(z_{t+1}|z_t, a_t)                   │           │  │
│  │   │     ↓                                                │           │  │
│  │   │   解码器: z_t → ŝ_t (重建观察)                        │           │  │
│  │   │     ↓                                                │           │  │
│  │   │   奖励预测: r̂_t = Reward(z_t, a_t)                   │           │  │
│  │   │                                                      │           │  │
│  │   │   损失函数: L = L重建(ŝ_t, s_t) + L奖励(r̂_t, r_t)   │           │  │
│  │   └─────────────────────────────────────────────────────┘           │  │
│  │                              ↓                                       │  │
│  │                         梯度更新                                     │  │
│  └──────────────────────────────────────────────────────────────────────┘  │
│                                     │                                        │
│                                     ▼                                        │
│  ┌──────────────────────────────────────────────────────────────────────┐  │
│  │                     阶段3:在想象中学习策略                             │  │
│  │                                                                      │  │
│  │   使用学习好的世界模型,生成"想象轨迹"(完全在模型中进行)             │  │
│  │                                                                      │  │
│  │   z_0 ~ p(z_0)            初始潜在状态分布                             │  │
│  │       │                                                           │  │
│  │       │ For t = 0 to H (想象 horizon):                            │  │
│  │       │   a_t ~ π_θ(a_t | z_t)  ← 演员网络采样动作                   │  │
│  │       │   r_t, z_{t+1} ~ p(z_{t+1}|z_t, a_t)  ← 世界模型            │  │
│  │       │   λ_t = γ^t  ← 计算折扣回报                                 │  │
│  │       │                                                           │  │
│  │       ▼                                                           │  │
│  │   ┌─────────────────────────────────────────────────────┐          │  │
│  │   │         演员-评论家 (Actor-Critic) 学习               │          │  │
│  │   │                                                      │          │  │
│  │   │   演员 π_θ: 最大化想象轨迹上的价值 V^π(z_t)           │          │  │
│  │   │   评论家 V_ψ: 拟合想象轨迹上的回报期望                  │          │  │
│  │   │                                                      │          │  │
│  │   │   使用 TD 学习更新评论家,再用策略梯度更新演员           │          │  │
│  │   └─────────────────────────────────────────────────────┘          │  │
│  └──────────────────────────────────────────────────────────────────────┘  │
│                                     │                                        │
│                                     │  循环回到阶段1,继续交互                 │
│                                     └─────────────────────────────────────►  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

5.3 Dreamer 的关键组件

"""
Dreamer 算法的核心组件伪代码
"""

import torch
import torch.nn as nn

# ┌────────────────────────────────────────────────────────────────┐
# │              组件1: 世界模型 (World Model)                      │
# └────────────────────────────────────────────────────────────────┘

class WorldModel(nn.Module):
    """
    世界模型包含三个部分:
    1. 编码器 (Encoder): 图像/高维观察 → 潜在向量 z
    2. 动态模型 (Dynamics): 潜在状态 + 动作 → 下一潜在状态
    3. 奖励预测器 (Reward Predictor)
    """
    def __init__(self, obs_dim, action_dim, latent_dim=32, hidden_dim=256):
        super().__init__()
        self.latent_dim = latent_dim
        
        # 编码器: o_t → z_t (变分编码)
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # 输出均值和方差
        )
        
        # 动态模型: (z_t, a_t) → z_{t+1}
        self.dynamics_net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # 输出均值和方差
        )
        
        # 奖励预测器: (z_t, a_t) → r_t
        self.reward_net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, obs, action):
        """世界模型前向传播"""
        # 编码观察
        z_mean, z_logvar = torch.chunk(self.encoder(obs), 2, dim=-1)
        z = z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_logvar)
        
        # 动态预测
        z_a = torch.cat([z, action], dim=-1)
        z_next_mean, z_next_logvar = torch.chunk(self.dynamics_net(z_a), 2, dim=-1)
        
        # 奖励预测
        r_pred = self.reward_net(z_a)
        
        return z, z_next_mean, z_next_logvar, r_pred

# ┌────────────────────────────────────────────────────────────────┐
# │              组件2: 演员-评论家 (Actor-Critic)                  │
# └────────────────────────────────────────────────────────────────┘

class ActorCritic(nn.Module):
    """
    演员-评论家网络: 在潜在空间中进行学习
    """
    def __init__(self, latent_dim, action_dim, hidden_dim=256, action_bound=1.0):
        super().__init__()
        self.action_bound = action_bound
        
        # 评论家: 状态价值函数 V(z)
        self.critic = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # 演员: 策略函数 π(a|z)
        self.actor = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim * 2)  # 输出均值和方差
        )
    
    def forward(self, z):
        v = self.critic(z)
        mean, logvar = torch.chunk(self.actor(z), 2, dim=-1)
        logvar = torch.clamp(logvar, -20, 2)
        return v, mean, logvar
    
    def sample(self, z):
        """从策略中采样动作"""
        _, mean, logvar = self.forward(z)
        std = torch.exp(0.5 * logvar)
        action = torch.randn_like(mean) * std + mean
        action = torch.tanh(action) * self.action_bound
        return action


def dreamer_training_loop(env, world_model, actor_critic, 
                           imagination_horizon=15, num_epochs=100):
    """
    Dreamer 训练循环伪代码
    
    核心思想: 在世界模型的"想象"中进行强化学习
    """
    replay_buffer = ReplayBuffer(capacity=100000)
    
    for epoch in range(num_epochs):
        # ---- 阶段1: 收集真实经验 ----
        obs = env.reset()
        for step in range(1000):  # 与环境交互收集数据
            action = actor_critic.sample(
                world_model.encoder(torch.FloatTensor(obs).unsqueeze(0))
            ).detach().numpy().squeeze(0)
            next_obs, reward, done, _ = env.step(action)
            replay_buffer.push(obs, action, reward, next_obs, done)
            obs = next_obs if not done else env.reset()
        
        # ---- 阶段2: 训练世界模型 ----
        for _ in range(100):
            batch = replay_buffer.sample(batch_size=256)
            obs, actions, rewards, next_obs, dones = batch
            
            # 训练世界模型: 重建观察 + 预测奖励
            z, z_next_mean, z_next_logvar, r_pred = world_model(obs, actions)
            # ... 计算损失并更新 world_model ...
        
        # ---- 阶段3: 在想象中学习策略 ----
        # 想象轨迹收集
        imagined_observations = []
        imagined_actions = []
        imagined_rewards = []
        imagined_values = []
        
        z = torch.randn(batch_size, latent_dim)  # 从先验采样初始状态
        for t in range(imagination_horizon):
            action = actor_critic.sample(z)
            _, z_next_mean, z_next_logvar, r_pred = world_model(z, action)
            z = z_next_mean + torch.randn_like(z_next_mean) * torch.exp(0.5 * z_next_logvar)
            v, _, _ = actor_critic(z)
            
            imagined_observations.append(z)
            imagined_actions.append(action)
            imagined_rewards.append(r_pred)
            imagined_values.append(v)
        
        # 计算回报并更新 actor-critic
        # ... 使用 REINFORCE 或策略梯度更新 ...
    
    return world_model, actor_critic


# ┌────────────────────────────────────────────────────────────────┐
# │              Dreamer vs Model-Free 对比                        │
# └────────────────────────────────────────────────────────────────┘

"""
┌──────────────────────────────────────────────────────────────┐
│                    Dreamer 的优势                              │
├──────────────────────────────────────────────────────────────┤
│  ✅ 样本效率极高: 想象轨迹可以无限生成,不受真实交互限制         │
│  ✅ 规划能力: 可以在模型中预测未来,进行前向规划                  │
│  ✅ 安全探索: 高风险动作可以在模型中先模拟                       │
│                                                              │
│  ⚠️ 核心挑战: 模型误差会累积,想象轨迹越长,误差越大            │
│  解决思路: 限制 imagination horizon,或使用模型不确定性估计      │
└──────────────────────────────────────────────────────────────┘
"""

六、基于模型RL的优势与挑战

6.1 核心优势

┌─────────────────────────────────────────────────────────────────────────┐
│                    基于模型强化学习:四大核心优势                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1️⃣ 样本效率高 (Sample Efficiency)                                      │
│     ┌──────────────────────────────────────────────────────┐            │
│     │                                                       │            │
│     │   Model-Free:  ████████████████████████████ ≈ 100万步   │            │
│     │   Model-Based: ███ ≈ 1万步                             │            │
│     │                                                       │            │
│     │   在模型内部可以"想象"无数次交互,无需真实环境采样         │            │
│     └──────────────────────────────────────────────────────┘            │
│                                                                          │
│  2️⃣ 规划能力 (Planning Capability)                                      │
│     ┌──────────────────────────────────────────────────────┐            │
│     │   • 前向搜索: MCTS, 蒙特卡洛规划                       │            │
│     │   • 轨迹优化: iLQR, Model Predictive Control (MPC)   │            │
│     │   • 模型可以是"可解释的": 人类能理解模型的预测          │            │
│     └──────────────────────────────────────────────────────┘            │
│                                                                          │
│  3️⃣ 安全探索 (Safe Exploration)                                         │
│     ┌──────────────────────────────────────────────────────┐            │
│     │   高风险行为可以先在模型中模拟:                        │            │
│     │   "如果我这么做,会发生什么?"  → 模型回答!             │            │
│     │   不需要在真实环境中冒险尝试危险动作                     │            │
│     └──────────────────────────────────────────────────────┘            │
│                                                                          │
│  4️⃣ 跨任务迁移 (Transfer Learning)                                       │
│     ┌──────────────────────────────────────────────────────┐            │
│     │   世界模型捕捉的是环境的"物理规律"                       │            │
│     │   奖励变了?只改奖励函数,模型可复用!                   │            │
│     │   适合 sim-to-real 迁移                                 │            │
│     └──────────────────────────────────────────────────────┘            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

6.2 核心挑战

┌─────────────────────────────────────────────────────────────────────────┐
│                    基于模型强化学习:四大核心挑战                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1️⃣ 模型误差累积 (Model Error Accumulation) ⭐⭐⭐                        │
│                                                                          │
│     问题: 多步预测后误差呈指数增长                                        │
│                                                                          │
│     s0(真实) ──► s1(真实) ──► s2(真实) ──► ...                        │
│        │           │           │                                        │
│        ▼           ▼           ▼                                        │
│     ŝ0=s0      ŝ1 ≠ s1     ŝ2 偏离更远                                  │
│       │           ↑           ↑                                         │
│       └──模型预测──┴───────────┘                                         │
│                                                                          │
│     记忆口诀:【一步错,步步错,滚雪球效应】                                 │
│                                                                          │
│     解决思路:                                                            │
│       - 限制规划步数(short horizon)                                   │
│       - 使用集成模型(ensemble)估计不确定性                               │
│       - Dyna-style: 模型+真实经验混合                                     │
│                                                                          │
│  2️⃣ 分布偏移 (Distribution Shift)                                        │
│                                                                          │
│     问题: 模型只在「见过的状态-动作对」上训练                              │
│           规划器可能探索未见区域,模型外推能力差                            │
│                                                                          │
│     ┌─────────────────────────────────────────────────────┐             │
│     │   训练分布:      ○ ○ ○ ○ ○ ○ ○ ○                    │             │
│     │   测试/规划:              ╳ ╳ ╳ ← 模型不知道!       │             │
│     └─────────────────────────────────────────────────────┘             │
│                                                                          │
│  3️⃣ 复合误差 (Compounding Errors)                                        │
│                                                                          │
│     问题: 单步误差会被后续决策放大                                        │
│     类比: 天气预报误差,今天误差1°C,7天后可能差10°C                       │
│                                                                          │
│  4️⃣ 训练复杂度高                                                        │
│                                                                          │
│     ┌─────────────────────────────────────────────────────┐             │
│     │   Model-Free: 单一目标(最大化价值)                 │             │
│     │   Model-Based: 多目标                               │             │
│     │     - 拟合环境动态 (回归问题)                        │             │
│     │     - 预测奖励 (回归问题)                            │             │
│     │     - 优化策略 (强化学习问题)                        │             │
│     │     三者需要同时优化,相互影响!                       │             │
│     └─────────────────────────────────────────────────────┘             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

6.3 经典算法一览

算法 模型类型 规划方式 代表工作
Dyna-Q 表格 Q-规划 Sutton 1991
蒙特卡洛树搜索 无模型(在线) 树搜索 Kocsis & Szepesvari 2006
UCT 无模型(在线) UCB树搜索 Gelly et al. 2006
AlphaGo 深度网络 MCTS+深度策略/价值网络 Silver et al. 2016
World Models VAE+RNN 在隐空间规划 Ha & Schmidhuber 2018
Dreamer 变分模型 想象轨迹+Actor-Critic Hafner et al. 2020
MuZero 隐式模型 MCTS+无真实模型规划 Schrittwieser et al. 2020
Model-Based OPC MPC 模型预测控制 Chua et al. 2018

七、记忆口诀

🎯 基于模型强化学习核心口诀

┌─────────────────────────────────────────────────────────────┐
│                                                              │
│   【学模型,定环境,P和R要分清】                                │
│   P(s'|s,a) 状态转移,R(s,a) 奖励函数                         │
│                                                              │
│   【MCTS四步走,SEL-EXP-SIM-BACK】                            │
│   Selection选节点,Expansion扩分支                            │
│   Simulation随机rollout,Backprop反向传                      │
│   UCB公式记心头,Q加探索除以次数                              │
│                                                              │
│   【AlphaGo三改进,策略先验+价值网】                           │
│   PUCT取代UCB1,策略引导探索方向                              │
│   价值网络取代rollout,评估更稳不摇晃                         │
│                                                              │
│   【Dreamer三步走,建模→想象→学习】                            │
│   世界模型学环境,想象轨迹无限生                              │
│   演员评论家来优化,样本效率第一名                             │
│                                                              │
│   【模型误差像滚雪球,越滚越大别忘啦】                         │
│   horizon要控制,不确定性要估计                                │
│                                                              │
└─────────────────────────────────────────────────────────────┘

八、延伸阅读与参考文献

论文 作者 年份 关键贡献
Dyna-Q Sutton 1991 Dyna架构:模型+无模型混合
UCT Kocsis & Szepesvari 2006 UCB用于MCTS的理论基础
Monte Carlo Go Gelly et al. 2006 MCTS应用于围棋
Mastering the Game of Go Silver et al. 2016 AlphaGo: MCTS+深度学习
World Models Ha & Schmidhuber 2018 在隐空间学习世界模型
Dreamer: Learning Behaviors Hafner et al. 2020 变分世界模型+想象学习
Mastering Atari with discrete world models Hafner et al. 2023 Dreamer v2/v3 Atari成就
MuZero Schrittwieser et al. 2020 无需真实模型的AlphaZero

结语

基于模型的强化学习是强化学习领域的重要分支,它通过学习环境模型来实现高效规划样本高效学习

核心要点回顾

  • 🌟 模型学习:监督学习 P(s’|s,a) 和 R(s,a)
  • 🌟 MCTS:四步循环(选择/扩展/模拟/反传)+ UCB平衡探索利用
  • 🌟 AlphaGo:策略网络+价值网络+MCTS的完美结合
  • 🌟 Dreamer:世界模型+想象轨迹+Actor-Critic的端到端学习
  • ⚠️ 关键挑战:模型误差累积、分布偏移、复合误差
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐