三、动态规划与策略迭代
·
【强化学习系列(三)】基于模型的强化学习 (Model-Based RL)
一、基于模型 vs 无模型:本质区别
在强化学习中,智能体学习如何行动的核心思路有两条路径:
- 无模型 (Model-Free):智能体直接在环境中试错,通过与环境的交互来学习价值函数或策略。不知道环境长什么样。
- 基于模型 (Model-Based):智能体先学习一个环境模型(即状态转移和奖励函数),然后基于这个模型来进行规划或决策。知道环境长什么样。
用一张图来概括:
┌──────────────────────────────────────────────────────────────────────┐
│ 强化学习的两条路线 │
├──────────────────────────────────────────────────────────────────────┤
│ │
│ 【无模型 Model-Free】 【基于模型 Model-Based】 │
│ │
│ ┌─────────┐ action ┌────────┐ ┌─────────┐ action ┌────────┐│
│ │ 智能体 │ ───────► │ 真实环境 │ │ 智能体 │ ───────► │ 世界模型 ││
│ │ Agent │ │ Real │ │ Agent │ │ World ││
│ │ │ ◄─────── │ Env │ │ │ │ Model ││
│ └─────────┘ observe └────────┘ └─────────┘ ◄─────── └────────┘│
│ ▲ │ ▲ │
│ │ │ │ │
│ └─────────────────────┘ │ │
│ 真实交互 想象规划 │
└──────────────────────────────────────────────────────────────────────┘
核心对比表
| 对比维度 | 无模型强化学习 (Model-Free) | 基于模型强化学习 (Model-Based) |
|---|---|---|
| 核心思想 | 通过真实交互试错学习 | 先学环境模型,再基于模型规划 |
| 环境模型 | ❌ 不学习,视为黑盒 | ✅ 学习状态转移 $P(s’ |
| 样本效率 | 低(需要大量真实交互) | 高(可以在模型中"想象"交互) |
| 规划能力 | 有限,只能靠价值函数近似 | 强,可进行前向搜索、轨迹优化 |
| 模型误差 | 不存在(无模型) | 存在"模型误差"(model error)导致性能瓶颈 |
| 收敛稳定性 | 相对稳定,但方差大 | 存在误差累积问题,训练复杂 |
| 典型算法 | Q-Learning, DQN, Policy Gradient, PPO | Dyna-Q, MCTS, Dreamer, MuZero |
| 代表应用 | Atari游戏, 机器人控制 | AlphaGo, AlphaZero, 围棋/象棋/将棋 |
| 计算成本 | 交互成本高,推理成本低 | 交互成本低,但规划/规划计算成本高 |
| 泛化能力 | 泛化到未见状态较难 | 模型可解释性稍好,但泛化仍困难 |
直观类比
🎯 无模型 vs 基于模型的形象比喻
想象你要学会游泳:
- 无模型:直接跳进游泳池呛水,呛多了自然就会了(试错学习)
- 基于模型:先在岸上观看教学视频、想象水流、模拟动作要领,建立一个心理模型,然后再下水(模型辅助学习)
二、学习环境模型:状态转移与奖励函数
2.1 什么是环境模型?
环境模型是对环境动态特性的数学描述,包含两个核心组件:
┌─────────────────────────────────────────────────────────────┐
│ 环境模型 Environment Model │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 状态转移函数 (Transition Function) │
│ P(s'|s, a) = Pr(下一状态 | 当前状态, 采取的动作) │
│ │
│ 记忆口诀:【状态转移:给了首尾,定中间】 │
│ s(当前) + a(动作) ──► s'(下一状态) │
│ │
│ 2️⃣ 奖励函数 (Reward Function) │
│ R(s, a) 或 R(s, a, s') = 智能体获得的即时奖励 │
│ │
│ 记忆口诀:【奖励函数:做了啥事,给多少分】 │
│ │
│ 组合起来:环境模型 M = (P, R) │
│ │
└─────────────────────────────────────────────────────────────┘
2.2 如何学习环境模型?
学习环境模型本质上是一个监督学习问题:
┌──────────────────────────────────────────────────────────────┐
│ 学习环境模型 = 监督学习 │
├──────────────────────────────────────────────────────────────┤
│ │
│ 训练数据来自真实环境交互: │
│ │
│ 收集样本: (s, a, r, s') 四元组 │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────┐ │
│ │ 输入: (s, a) ──► 输出: (s', r) │ │
│ │ │ │
│ │ 回归问题: 预测下一个状态 s' │ │
│ │ 回归问题: 预测奖励 r │ │
│ └─────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 可以使用神经网络进行函数逼近 │
│ s' = f_θ(s, a) r = g_θ(s, a) │
│ │
└──────────────────────────────────────────────────────────────┘
2.3 模型学习的具体方法
方法一:表格型方法(适用于离散、小状态空间)
# 伪代码:基于计数的环境模型学习
class TableModel:
def __init__(self, n_states, n_actions):
# N(s, a, s') = 从(s,a)转移到s'的计数
self.N = np.zeros((n_states, n_actions, n_states))
# R(s, a) = 在(s,a)获得的平均奖励
self.R = np.zeros((n_states, n_actions))
# N(s, a) = 访问(s,a)的总计数
self.N_sa = np.zeros((n_states, n_actions))
def update(self, s, a, r, s_next):
self.N[s, a, s_next] += 1
self.R[s, a] += r # 累计,可后续取平均
self.N_sa[s, a] += 1
def get_transition_prob(self, s, a):
"""P(s'|s,a) = N(s,a,s') / N(s,a)"""
if self.N_sa[s, a] == 0:
return np.zeros(self.N.shape[2])
return self.N[s, a] / self.N_sa[s, a]
def get_reward(self, s, a):
"""R(s,a) = 累计奖励 / 访问次数"""
if self.N_sa[s, a] == 0:
return 0.0
return self.R[s, a] / self.N_sa[s, a]
方法二:神经网络方法(适用于大规模状态空间)
# 伪代码:基于神经网络的环境模型
import torch
import torch.nn as nn
class WorldModel(nn.Module):
"""
世界模型:输入(s,a),输出(s', r)
也称为 Transition Model 或 Dynamics Model
"""
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
# 状态转移网络: (s,a) -> s'
self.transition_net = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, state_dim) # 输出下一状态
)
# 奖励网络: (s,a) -> r
self.reward_net = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1) # 输出标量奖励
)
def forward(self, s, a):
"""
Args:
s: 当前状态 (batch_size, state_dim)
a: 动作 (batch_size, action_dim)
Returns:
s_next: 预测的下一状态
r: 预测的奖励
"""
sa = torch.cat([s, a], dim=-1) # 拼接
s_next = self.transition_net(sa)
r = self.reward_net(sa)
return s_next, r
def predict(self, s, a):
"""单步预测"""
with torch.no_grad():
s_tensor = torch.FloatTensor(s).unsqueeze(0)
a_tensor = torch.FloatTensor(a).unsqueeze(0)
s_next, r = self.forward(s_tensor, a_tensor)
return s_next.numpy().squeeze(0), r.numpy().squeeze(0)
2.4 模型学习的挑战
┌─────────────────────────────────────────────────────────────────┐
│ 模型学习的三大挑战 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 误差累积 (Error Accumulation) │
│ ┌──────────────────────────────────────────┐ │
│ │ 真实: s0 ──► s1 ──► s2 ──► s3 ──► ... │ │
│ │ 想象: s0 ──►ŝ1 ──►ŝ2 ──►ŝ3 ──► ... │ │
│ │ ↑ ↑ ↑ │ │
│ │ 误差1 累积误差2 更大误差3 │ │
│ └──────────────────────────────────────────┘ │
│ 记忆口诀:【模型误差像滚雪球,越滚越大】 │
│ │
│ 2️⃣ 复合误差 (Compounding Errors) │
│ 多步想象后误差呈指数级增长,需要-reward shaping 或 short horizon │
│ │
│ 3️⃣ 分布偏移 (Distribution Shift) │
│ 模型只在「见过的状态-动作对」上训练, │
│ 规划可能探索未见区域,模型外推能力差 │
│ │
└─────────────────────────────────────────────────────────────────┘
三、Monte Carlo Tree Search (MCTS):四步详解
3.1 MCTS是什么?
蒙特卡洛树搜索是一种基于模型的规划算法,核心思想是:
通过随机模拟来评估决策路径,用树搜索来高效地分配计算资源。
MCTS 不需要遍历整个搜索空间,而是通过迭代地"探索"和"利用"来找到最优路径。
3.2 MCTS四步详解(ASCII流程图)
┌─────────────────────────────────────────────────────────────────────────────┐
│ MCTS 算法四步循环 (每回合一次) │
│ │
│ ┌─────────────────────────────────────┐ │
│ │ 🔄 重复 N 次 │ │
│ └─────────────────────────────────────┘ │
│ │ │
│ ┌───────────────────┼───────────────────┐ │
│ ▼ ▼ ▼ │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 1️⃣ SELECTION │ │ 2️⃣ EXPANSION │ │3️⃣ SIMULATION│ │
│ │ 选择 │ │ 扩展 │ │ 模拟 │ │
│ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ │
│ │ │ │ │
│ │ ┌──────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────┐ │ │
│ │ │ 从叶子节点随机选择 │ │ │
│ │ │ 一个未访问的子节点扩展 │ │ │
│ │ └─────────────────────────────┘ │ │
│ │ │ │
│ │ ┌──────────────────────────────────┘ │
│ │ │ │
│ │ ▼ │
│ │ ┌─────────────────────────────┐ │
│ └──│ 从扩展节点开始, rollout │◄── 随机策略 │
│ │ 直到游戏结束 │ │
│ └─────────────┬─────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────┐ │
│ │ 4️⃣ BACKPROPAGATION │ │
│ │ 反向传播 │ │
│ │ 更新路径上所有节点的 │ │
│ │ Q值/访问次数 │ │
│ └─────────────────────────────┘ │
│ [完] │
└─────────────────────────────────────────────────────────────────────────────┘
3.3 每一步详细解析
步骤1:SELECTION — 选择(从根到叶子)
目标:从根节点开始,使用树策略 (Tree Policy) 选择一个"最有前途"的叶子节点
核心机制:UCB1 (Upper Confidence Bound 1) 上置信界公式
┌─────────────────────────────────────────┐
│ │
│ UCB1 = Q(s) + C × √(ln N(parent) │
│ ────────────) │
│ N(child) │
│ │
│ 第一项 Q(s) → 利用项(越高越好) │
│ 第二项 → 探索项(越高→鼓励探索) │
│ C → 探索常数(通常=√2) │
│ │
└─────────────────────────────────────────┘
选择过程可视化:
根节点(Root)
│
┌───────┼───────┐
│ │ │
[A] [B] [C] ← 选择 UCB 值最高的子节点
0.8 0.6 0.9 如果某节点还未被访问过,
│ │ │ UCB = ∞(强制探索)
[A1] [B1] [C1]
│
...
记忆口诀:【选节点看UCB,既看收益又看探索】
步骤2:EXPANSION — 扩展
目标:向选中的叶子节点添加一个或多个子节点(代表可能的下一状态)
扩展策略:
- 添加一个随机未访问的子节点(最常用)
- 添加所有未访问的子节点(批量扩展)
扩展可视化:
扩展前: 扩展后:
叶子节点L 叶子节点L
│ │
[A1] [A1]
/ │ \
▼ ▼ ▼
[A2][A3][A4] ← 新增节点(未访问,Q=0)
步骤3:SIMULATION — 模拟(Rollout / Playout)
目标:从扩展节点开始,使用默认策略(通常是随机策略)一直模拟到游戏结束
Rollout 策略(从简单到复杂):
① 随机 rollout:每个时间步随机选择动作
② 轻量级神经网络:使用简化的策略网络
③ 固定深度搜索 + 启发式评估
模拟可视化:
扩展节点 ──随机动作──► s1 ──随机动作──► s2 ──...──► Terminal
│ │ │
[r=0] [r=0] [r=+1]
│
游戏结束!
返回奖励 +1
步骤4:BACKPROPAGATION — 反向传播
目标:将模拟获得的回报反向传播,更新路径上所有节点的统计信息
更新规则:
N(node) += 1 # 访问次数 +1
Q(node) = Q(node) + (G - Q(node)) / N(node) # 更新Q值(增量均值)
其中 G = 从该节点到游戏结束获得的累计奖励
反向传播可视化:
模拟返回 G=+1
│
▼
┌─────────────────┐
│ Terminal State│
└────────┬────────┘
│ Backprop
▼
┌─────────────────┐
│ 节点 s2 │ N+=1, Q更新
└────────┬────────┘
│ Backprop
▼
┌─────────────────┐
│ 节点 s1 │ N+=1, Q更新
└────────┬────────┘
│ Backprop
▼
┌─────────────────┐
│ 节点 s0 (根) │ N+=1, Q更新
└─────────────────┘
3.4 MCTS完整算法流程
"""
Monte Carlo Tree Search 伪代码
"""
import math
import random
class MCTSNode:
def __init__(self, state, parent=None, action=None):
self.state = state # 节点对应的状态
self.parent = parent # 父节点
self.action = action # 从父节点到达此节点的动作
self.children = {} # 子节点字典 {action: child_node}
self.N = 0 # 访问次数
self.Q = 0.0 # 平均Q值(累计奖励均值)
def is_fully_expanded(self):
"""检查是否所有子节点都已扩展"""
return len(self.children) == len(self.state.get_legal_actions())
def ucb1(self, c=math.sqrt(2)):
"""UCB1公式:平衡探索与利用"""
if self.N == 0:
return float('inf') # 未访问节点优先级最高
if self.parent is None:
return float('inf') # 根节点
return self.Q + c * math.sqrt(math.log(self.parent.N) / self.N)
def mcts(root_state, num_simulations=1000, c=math.sqrt(2)):
"""
MCTS 主算法
Args:
root_state: 初始状态
num_simulations: 模拟次数
c: UCB探索常数
Returns:
选择根节点的最优子动作
"""
root = MCTSNode(state=root_state)
for _ in range(num_simulations):
node = root
path = [node] # 记录搜索路径,用于反向传播
# ========== 步骤1: SELECTION ==========
# 从根到叶子,选择UCB值最高的子节点
while node.is_fully_expanded() and node.children:
best_child = max(node.children.values(), key=lambda n: n.ucb1(c))
node = best_child
path.append(node)
# ========== 步骤2: EXPANSION ==========
# 如果节点未完全扩展,则扩展一个子节点
legal_actions = node.state.get_legal_actions()
unexpanded_actions = [a for a in legal_actions if a not in node.children]
if unexpanded_actions:
action = random.choice(unexpanded_actions)
next_state = node.state.take_action(action)
child_node = MCTSNode(state=next_state, parent=node, action=action)
node.children[action] = child_node
node = child_node
path.append(node)
# ========== 步骤3: SIMULATION ==========
# 从当前节点 rollout 到游戏结束
rollout_state = node.state.copy()
while not rollout_state.is_terminal():
action = random.choice(rollout_state.get_legal_actions())
rollout_state = rollout_state.take_action(action)
reward = rollout_state.get_reward() # 假设为1=赢,0=输
# ========== 步骤4: BACKPROPAGATION ==========
# 反向传播更新路径上所有节点
for n in path:
n.N += 1
n.Q += (reward - n.Q) / n.N # 增量更新Q值
# 返回根节点下访问次数最多的子动作
best_action = max(root.children.keys(),
key=lambda a: root.children[a].N)
return best_action
# ========== UCB1 公式详解 ==========
"""
UCB1 (Upper Confidence Bound 1) 公式:
UCB1 = Q̄(s,a) + C × √(ln(N(parent)) / N(s,a))
其中:
- Q̄(s,a): 子节点的平均奖励(利用项)
- C: 探索常数,通常取 √2 ≈ 1.414
- N(parent): 父节点的访问次数
- N(s,a): 当前节点的访问次数
数学性质:
① 当节点访问次数少时,第二项(探索项)大,鼓励探索
② 当节点访问次数多时,第一项(利用项)主导,选择高收益节点
③ UCB1 在理论上保证:遗憾值有上界
记忆口诀:【UCB三部分,收益加探索除以次数】
"""
3.5 MCTS vs 传统树搜索
| 对比维度 | 传统 Minimax 搜索 | MCTS (蒙特卡洛树搜索) |
|---|---|---|
| 搜索深度 | 完整深度(指数爆炸) | 迭代限制,可中断 |
| 评估函数 | 需要手工评估函数 | 通过随机模拟自动评估 |
| 分支因子 | 高时效率低 | 通过UCB自适应剪枝 |
| 适用场景 | 完美信息、确定性游戏 | 不完美信息、随机性游戏 |
| 计算预算 | 固定深度 | 固定迭代次数(任意停止) |
四、AlphaGo 中的 MCTS 应用
4.1 AlphaGo 系统架构
AlphaGo 是 MCTS 在围棋领域取得突破性成功的代表作,它将 MCTS 与深度学习完美结合:
┌─────────────────────────────────────────────────────────────────────────┐
│ AlphaGo 系统架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 策略网络 Policy Network │ │
│ │ 输入: 当前棋盘状态 s │ │
│ │ 输出: 每个动作 a 的概率分布 p(a|s) │ │
│ │ 训练: 监督学习(从人类棋谱)+ 强化学习(自我对弈) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 价值网络 Value Network │ │
│ │ 输入: 当前棋盘状态 s │ │
│ │ 输出: 该状态的胜率 v(s) ∈ [0, 1] │ │
│ │ 训练: 回归问题,从大量自我对弈样本中学习 │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ MCTS + 深度学习融合 │ │
│ │ │ │
│ │ ① SELECTION: 使用 PUCT 公式(UCB的围棋变体) │ │
│ │ PUCT = Q(s,a) + U(s,a) │ │
│ │ 其中 U(s,a) ∝ P(s,a) / (1 + N(s,a)) │ │
│ │ —— 策略网络的先验概率引导探索方向! │ │
│ │ │ │
│ │ ② EXPANSION: 扩展节点时,用策略网络计算先验概率 │ │
│ │ │ │
│ │ ③ SIMULATION: 用价值网络评估叶子节点,不再完全随机 rollout │ │
│ │ │ │
│ │ ④ BACKPROP: 用价值网络输出更新Q值 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
4.2 AlphaGo 的 MCTS 改进点
┌─────────────────────────────────────────────────────────────────────────┐
│ AlphaGo 对标准 MCTS 的三大改进 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 改进1️⃣:策略网络先验 (Prior from Policy Network) │
│ │
│ 标准 MCTS: UCB 探索项是 uniform 的(无信息) │
│ AlphaGo: U(s,a) ∝ P(a|s) / (1 + N(s,a)) │
│ ↑ 策略网络提供先验概率引导探索 │
│ │
│ 效果: 原本需要数万次模拟才能找到的好棋, │
│ 现在可能只需数百次——因为策略网络"懂"围棋! │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 改进2️⃣:价值网络评估替代随机 Rollout │
│ │
│ 标准 MCTS: 从叶子节点随机 rollout 到游戏结束(可能数千步) │
│ AlphaGo: 直接用价值网络 v(s') 评估叶子节点 │
│ │
│ 效果: 消除随机 rollout 的高方差,评估更稳定 │
│ │
│ ───────────────────────────────────────────────────────────────────── │
│ │
│ 改进3️⃣:PUCT 公式替代 UCB1 │
│ │
│ UCB1: Q + C×√(ln(N_parent)/N_child) │
│ PUCT: Q + P(a|s) × √(N_parent) / (1 + N_child) │
│ ↑ 策略先验概率 │
│ │
│ 记忆口诀:【AlphaGo有策略网络,PUCT取代UCB1】 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
4.3 AlphaGo 的训练流程
"""
AlphaGo 训练流程概述
"""
# ┌──────────────────────────────────────────────────────────────────┐
# │ 第一阶段:监督学习 │
# │ 用人类围棋高手的棋谱训练策略网络 │
# │ (类似模仿学习,但数据量巨大) │
# └──────────────────────────────────────────────────────────────────┘
#
# 人类棋谱 → (state, action) 对 → 训练策略网络 p(a|s)
# ↓
# 正确率可达 57%(远超随机的0.1%)
#
# ┌──────────────────────────────────────────────────────────────────┐
# │ 第二阶段:强化学习 │
# │ 用策略梯度方法,让策略网络自我对弈提升 │
# │ 对手是当前策略网络 vs 历史最强版本 │
# └──────────────────────────────────────────────────────────────────┘
#
# 当前策略网络 π → 与历史最强版本对弈 → 胜负奖励 → 策略优化
# ↓
# 进一步提升,超越监督学习版本
#
# ┌──────────────────────────────────────────────────────────────────┐
# │ 第三阶段:训练价值网络 │
# │ 用自我对弈的棋局数据,训练状态→胜率映射 │
# └──────────────────────────────────────────────────────────────────┘
#
# 大量棋谱 (state → 最终胜/负) → 训练 v(s) 回归网络
#
# ┌──────────────────────────────────────────────────────────────────┐
# │ 第四阶段:MCTS + 深度网络 │
# │ 在线规划时,策略网络+价值网络+MCTS 协同工作 │
# └──────────────────────────────────────────────────────────────────┘
4.4 AlphaGo vs AlphaGo Zero vs AlphaZero
| 版本 | 人类棋谱 | 特征 | 自我对弈 |
|---|---|---|---|
| AlphaGo | ✅ 需要 | 策略网络 + 价值网络 + MCTS | ✅ |
| AlphaGo Zero | ❌ 不需要 | 单一残差网络,同时输出策略和价值 | ✅ |
| AlphaZero | ❌ 不需要 | AlphaGo Zero架构,通用化(围棋/象棋/将棋) | ✅ |
五、Dreamer算法:世界模型 + 在想象中学习
5.1 Dreamer 的核心思想
Dreamer (Hafner et al., 2020) 是一个典型的基于模型的强化学习算法,核心理念是:
🎯 “在想象中学习” (Learning in Imagination)
智能体首先学习一个世界模型(World Model),然后在模型内部进行大量的强化学习学习,完全不需要与真实环境交互!
┌─────────────────────────────────────────────────────────────────────────┐
│ Dreamer: 在想象中学习 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 真实环境交互(少量) 世界模型内部(大量) │
│ │
│ ┌────────┐ ┌──────────────────────────┐ │
│ │ 真实 │ ──少量交互──► │ 世界模型 (World Model) │ │
│ │ 环境 │ 收集经验 │ ┌────────┐ ┌───────────┐ │ │
│ └────────┘ │ │ 动态模型 │ │ 奖励模型 │ │ │
│ │ │ p(s'|s,a)│ │ p(r|s,a) │ │ │
│ │ └────────┘ └───────────┘ │ │
│ └──────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────┐ │
│ │ 想象轨迹生成 │ │
│ │ s0 → a0 → r0 → s1 → ... │ │
│ │ 在模型中生成 10000+ 条 │ │
│ │ 想象轨迹 │ │
│ └──────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────┐ │
│ │ 演员-评论家学习 │ │
│ │ 使用想象轨迹训练策略和价值 │ │
│ └──────────────────────────┘ │
│ │
│ 记忆口诀:【Dreamer三步走:建模→想象→学习】 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
5.2 Dreamer 算法流程图
┌─────────────────────────────────────────────────────────────────────────────┐
│ Dreamer 算法完整流程 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 阶段1:与世界模型交互 │ │
│ │ │ │
│ │ 真实环境 │ │
│ │ ┌─────────┐ a0 ┌─────────┐ a1 ┌─────────┐ │ │
│ │ │ s0=观察 │ ───► │ s1=观察 │ ───► │ s2=观察 │ ───► ... │ │
│ │ └─────────┘ └─────────┘ └─────────┘ │ │
│ │ │ │ │ │
│ │ │ ε0 │ │ │
│ │ ▼ ▼ │ │
│ │ ┌─────────┐ ┌──────────┐ │ │
│ │ │ a0=动作 │ │经验缓冲池│ │ │
│ │ └────┬────┘ │ D │ │ │
│ │ │ └────┬───┘ │ │
│ │ └──────────────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 阶段2:训练世界模型 │ │
│ │ │ │
│ │ 从经验缓冲池 D 采样批量数据: (s_t, a_t, r_t, s_{t+1}) │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ 世界模型 (World Model) │ │ │
│ │ │ │ │ │
│ │ │ 编码器: s_t → z_t (潜在状态) │ │ │
│ │ │ ↓ │ │ │
│ │ │ 隐式动态模型: p(z_{t+1}|z_t, a_t) │ │ │
│ │ │ ↓ │ │ │
│ │ │ 解码器: z_t → ŝ_t (重建观察) │ │ │
│ │ │ ↓ │ │ │
│ │ │ 奖励预测: r̂_t = Reward(z_t, a_t) │ │ │
│ │ │ │ │ │
│ │ │ 损失函数: L = L重建(ŝ_t, s_t) + L奖励(r̂_t, r_t) │ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │
│ │ 梯度更新 │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 阶段3:在想象中学习策略 │ │
│ │ │ │
│ │ 使用学习好的世界模型,生成"想象轨迹"(完全在模型中进行) │ │
│ │ │ │
│ │ z_0 ~ p(z_0) 初始潜在状态分布 │ │
│ │ │ │ │
│ │ │ For t = 0 to H (想象 horizon): │ │
│ │ │ a_t ~ π_θ(a_t | z_t) ← 演员网络采样动作 │ │
│ │ │ r_t, z_{t+1} ~ p(z_{t+1}|z_t, a_t) ← 世界模型 │ │
│ │ │ λ_t = γ^t ← 计算折扣回报 │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ 演员-评论家 (Actor-Critic) 学习 │ │ │
│ │ │ │ │ │
│ │ │ 演员 π_θ: 最大化想象轨迹上的价值 V^π(z_t) │ │ │
│ │ │ 评论家 V_ψ: 拟合想象轨迹上的回报期望 │ │ │
│ │ │ │ │ │
│ │ │ 使用 TD 学习更新评论家,再用策略梯度更新演员 │ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ │ 循环回到阶段1,继续交互 │
│ └─────────────────────────────────────► │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
5.3 Dreamer 的关键组件
"""
Dreamer 算法的核心组件伪代码
"""
import torch
import torch.nn as nn
# ┌────────────────────────────────────────────────────────────────┐
# │ 组件1: 世界模型 (World Model) │
# └────────────────────────────────────────────────────────────────┘
class WorldModel(nn.Module):
"""
世界模型包含三个部分:
1. 编码器 (Encoder): 图像/高维观察 → 潜在向量 z
2. 动态模型 (Dynamics): 潜在状态 + 动作 → 下一潜在状态
3. 奖励预测器 (Reward Predictor)
"""
def __init__(self, obs_dim, action_dim, latent_dim=32, hidden_dim=256):
super().__init__()
self.latent_dim = latent_dim
# 编码器: o_t → z_t (变分编码)
self.encoder = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2) # 输出均值和方差
)
# 动态模型: (z_t, a_t) → z_{t+1}
self.dynamics_net = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2) # 输出均值和方差
)
# 奖励预测器: (z_t, a_t) → r_t
self.reward_net = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, obs, action):
"""世界模型前向传播"""
# 编码观察
z_mean, z_logvar = torch.chunk(self.encoder(obs), 2, dim=-1)
z = z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_logvar)
# 动态预测
z_a = torch.cat([z, action], dim=-1)
z_next_mean, z_next_logvar = torch.chunk(self.dynamics_net(z_a), 2, dim=-1)
# 奖励预测
r_pred = self.reward_net(z_a)
return z, z_next_mean, z_next_logvar, r_pred
# ┌────────────────────────────────────────────────────────────────┐
# │ 组件2: 演员-评论家 (Actor-Critic) │
# └────────────────────────────────────────────────────────────────┘
class ActorCritic(nn.Module):
"""
演员-评论家网络: 在潜在空间中进行学习
"""
def __init__(self, latent_dim, action_dim, hidden_dim=256, action_bound=1.0):
super().__init__()
self.action_bound = action_bound
# 评论家: 状态价值函数 V(z)
self.critic = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
# 演员: 策略函数 π(a|z)
self.actor = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim * 2) # 输出均值和方差
)
def forward(self, z):
v = self.critic(z)
mean, logvar = torch.chunk(self.actor(z), 2, dim=-1)
logvar = torch.clamp(logvar, -20, 2)
return v, mean, logvar
def sample(self, z):
"""从策略中采样动作"""
_, mean, logvar = self.forward(z)
std = torch.exp(0.5 * logvar)
action = torch.randn_like(mean) * std + mean
action = torch.tanh(action) * self.action_bound
return action
def dreamer_training_loop(env, world_model, actor_critic,
imagination_horizon=15, num_epochs=100):
"""
Dreamer 训练循环伪代码
核心思想: 在世界模型的"想象"中进行强化学习
"""
replay_buffer = ReplayBuffer(capacity=100000)
for epoch in range(num_epochs):
# ---- 阶段1: 收集真实经验 ----
obs = env.reset()
for step in range(1000): # 与环境交互收集数据
action = actor_critic.sample(
world_model.encoder(torch.FloatTensor(obs).unsqueeze(0))
).detach().numpy().squeeze(0)
next_obs, reward, done, _ = env.step(action)
replay_buffer.push(obs, action, reward, next_obs, done)
obs = next_obs if not done else env.reset()
# ---- 阶段2: 训练世界模型 ----
for _ in range(100):
batch = replay_buffer.sample(batch_size=256)
obs, actions, rewards, next_obs, dones = batch
# 训练世界模型: 重建观察 + 预测奖励
z, z_next_mean, z_next_logvar, r_pred = world_model(obs, actions)
# ... 计算损失并更新 world_model ...
# ---- 阶段3: 在想象中学习策略 ----
# 想象轨迹收集
imagined_observations = []
imagined_actions = []
imagined_rewards = []
imagined_values = []
z = torch.randn(batch_size, latent_dim) # 从先验采样初始状态
for t in range(imagination_horizon):
action = actor_critic.sample(z)
_, z_next_mean, z_next_logvar, r_pred = world_model(z, action)
z = z_next_mean + torch.randn_like(z_next_mean) * torch.exp(0.5 * z_next_logvar)
v, _, _ = actor_critic(z)
imagined_observations.append(z)
imagined_actions.append(action)
imagined_rewards.append(r_pred)
imagined_values.append(v)
# 计算回报并更新 actor-critic
# ... 使用 REINFORCE 或策略梯度更新 ...
return world_model, actor_critic
# ┌────────────────────────────────────────────────────────────────┐
# │ Dreamer vs Model-Free 对比 │
# └────────────────────────────────────────────────────────────────┘
"""
┌──────────────────────────────────────────────────────────────┐
│ Dreamer 的优势 │
├──────────────────────────────────────────────────────────────┤
│ ✅ 样本效率极高: 想象轨迹可以无限生成,不受真实交互限制 │
│ ✅ 规划能力: 可以在模型中预测未来,进行前向规划 │
│ ✅ 安全探索: 高风险动作可以在模型中先模拟 │
│ │
│ ⚠️ 核心挑战: 模型误差会累积,想象轨迹越长,误差越大 │
│ 解决思路: 限制 imagination horizon,或使用模型不确定性估计 │
└──────────────────────────────────────────────────────────────┘
"""
六、基于模型RL的优势与挑战
6.1 核心优势
┌─────────────────────────────────────────────────────────────────────────┐
│ 基于模型强化学习:四大核心优势 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 样本效率高 (Sample Efficiency) │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Model-Free: ████████████████████████████ ≈ 100万步 │ │
│ │ Model-Based: ███ ≈ 1万步 │ │
│ │ │ │
│ │ 在模型内部可以"想象"无数次交互,无需真实环境采样 │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ 2️⃣ 规划能力 (Planning Capability) │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ • 前向搜索: MCTS, 蒙特卡洛规划 │ │
│ │ • 轨迹优化: iLQR, Model Predictive Control (MPC) │ │
│ │ • 模型可以是"可解释的": 人类能理解模型的预测 │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ 3️⃣ 安全探索 (Safe Exploration) │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ 高风险行为可以先在模型中模拟: │ │
│ │ "如果我这么做,会发生什么?" → 模型回答! │ │
│ │ 不需要在真实环境中冒险尝试危险动作 │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ 4️⃣ 跨任务迁移 (Transfer Learning) │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ 世界模型捕捉的是环境的"物理规律" │ │
│ │ 奖励变了?只改奖励函数,模型可复用! │ │
│ │ 适合 sim-to-real 迁移 │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
6.2 核心挑战
┌─────────────────────────────────────────────────────────────────────────┐
│ 基于模型强化学习:四大核心挑战 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 模型误差累积 (Model Error Accumulation) ⭐⭐⭐ │
│ │
│ 问题: 多步预测后误差呈指数增长 │
│ │
│ s0(真实) ──► s1(真实) ──► s2(真实) ──► ... │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ŝ0=s0 ŝ1 ≠ s1 ŝ2 偏离更远 │
│ │ ↑ ↑ │
│ └──模型预测──┴───────────┘ │
│ │
│ 记忆口诀:【一步错,步步错,滚雪球效应】 │
│ │
│ 解决思路: │
│ - 限制规划步数(short horizon) │
│ - 使用集成模型(ensemble)估计不确定性 │
│ - Dyna-style: 模型+真实经验混合 │
│ │
│ 2️⃣ 分布偏移 (Distribution Shift) │
│ │
│ 问题: 模型只在「见过的状态-动作对」上训练 │
│ 规划器可能探索未见区域,模型外推能力差 │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 训练分布: ○ ○ ○ ○ ○ ○ ○ ○ │ │
│ │ 测试/规划: ╳ ╳ ╳ ← 模型不知道! │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 3️⃣ 复合误差 (Compounding Errors) │
│ │
│ 问题: 单步误差会被后续决策放大 │
│ 类比: 天气预报误差,今天误差1°C,7天后可能差10°C │
│ │
│ 4️⃣ 训练复杂度高 │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Model-Free: 单一目标(最大化价值) │ │
│ │ Model-Based: 多目标 │ │
│ │ - 拟合环境动态 (回归问题) │ │
│ │ - 预测奖励 (回归问题) │ │
│ │ - 优化策略 (强化学习问题) │ │
│ │ 三者需要同时优化,相互影响! │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
6.3 经典算法一览
| 算法 | 模型类型 | 规划方式 | 代表工作 |
|---|---|---|---|
| Dyna-Q | 表格 | Q-规划 | Sutton 1991 |
| 蒙特卡洛树搜索 | 无模型(在线) | 树搜索 | Kocsis & Szepesvari 2006 |
| UCT | 无模型(在线) | UCB树搜索 | Gelly et al. 2006 |
| AlphaGo | 深度网络 | MCTS+深度策略/价值网络 | Silver et al. 2016 |
| World Models | VAE+RNN | 在隐空间规划 | Ha & Schmidhuber 2018 |
| Dreamer | 变分模型 | 想象轨迹+Actor-Critic | Hafner et al. 2020 |
| MuZero | 隐式模型 | MCTS+无真实模型规划 | Schrittwieser et al. 2020 |
| Model-Based OPC | MPC | 模型预测控制 | Chua et al. 2018 |
七、记忆口诀
🎯 基于模型强化学习核心口诀
┌─────────────────────────────────────────────────────────────┐ │ │ │ 【学模型,定环境,P和R要分清】 │ │ P(s'|s,a) 状态转移,R(s,a) 奖励函数 │ │ │ │ 【MCTS四步走,SEL-EXP-SIM-BACK】 │ │ Selection选节点,Expansion扩分支 │ │ Simulation随机rollout,Backprop反向传 │ │ UCB公式记心头,Q加探索除以次数 │ │ │ │ 【AlphaGo三改进,策略先验+价值网】 │ │ PUCT取代UCB1,策略引导探索方向 │ │ 价值网络取代rollout,评估更稳不摇晃 │ │ │ │ 【Dreamer三步走,建模→想象→学习】 │ │ 世界模型学环境,想象轨迹无限生 │ │ 演员评论家来优化,样本效率第一名 │ │ │ │ 【模型误差像滚雪球,越滚越大别忘啦】 │ │ horizon要控制,不确定性要估计 │ │ │ └─────────────────────────────────────────────────────────────┘
八、延伸阅读与参考文献
| 论文 | 作者 | 年份 | 关键贡献 |
|---|---|---|---|
| Dyna-Q | Sutton | 1991 | Dyna架构:模型+无模型混合 |
| UCT | Kocsis & Szepesvari | 2006 | UCB用于MCTS的理论基础 |
| Monte Carlo Go | Gelly et al. | 2006 | MCTS应用于围棋 |
| Mastering the Game of Go | Silver et al. | 2016 | AlphaGo: MCTS+深度学习 |
| World Models | Ha & Schmidhuber | 2018 | 在隐空间学习世界模型 |
| Dreamer: Learning Behaviors | Hafner et al. | 2020 | 变分世界模型+想象学习 |
| Mastering Atari with discrete world models | Hafner et al. | 2023 | Dreamer v2/v3 Atari成就 |
| MuZero | Schrittwieser et al. | 2020 | 无需真实模型的AlphaZero |
结语
基于模型的强化学习是强化学习领域的重要分支,它通过学习环境模型来实现高效规划和样本高效学习。
核心要点回顾:
- 🌟 模型学习:监督学习 P(s’|s,a) 和 R(s,a)
- 🌟 MCTS:四步循环(选择/扩展/模拟/反传)+ UCB平衡探索利用
- 🌟 AlphaGo:策略网络+价值网络+MCTS的完美结合
- 🌟 Dreamer:世界模型+想象轨迹+Actor-Critic的端到端学习
- ⚠️ 关键挑战:模型误差累积、分布偏移、复合误差
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)