【高级强化学习:算法、优化与泛化】第二章 深度强化学习的高样本效率与模型化方法
目录
2.1 Off-Policy Actor-Critic与熵优化
2.1.1 TD3与SAC的工程细节
2.1.1.1 TD3双延迟更新:目标策略平滑(Target Policy Smoothing)通过向目标动作添加裁剪噪声(clip noise σ=0.2)减少过拟合
2.1.1.2 SAC自动温度调整:通过Lagrange乘子法优化温度系数α,实现最大熵目标的动态平衡(PyTorch优化器嵌套优化)
2.1.2 离线强化学习(Offline RL)部署
2.1.2.1 CQL(Conservative Q-Learning):在标准贝尔曼误差上添加CQL正则项,防止对未见状态-动作对的过估计
2.1.2.2 IQL(Implicit Q-Learning):以期望分位数回归(expectile regression)替代max操作,避免OOD动作查询,支持离线到在线微调
2.2 模型基础强化学习(Model-Based RL)
2.2.1 环境动力学模型学习
2.2.1.1 PETS概率性模型:基于混合密度网络(MDN)的集成模型(Ensemble)与轨迹采样(Trajectory Sampling)对抗模型误差
2.2.1.2 MBPO(Model-Based Policy Optimization):基于模型的虚拟rollout与真实数据按比例混合采样(real_ratio=0.05)的采样器实现
2.2.2 潜在空间规划(Latent Space Planning)
2.2.2.1 PlaNet/Dreamer架构:RSSM(Recurrent State-Space Model)的确定性路径与随机先验分解,KL散度损失加权(β=1.0)
2.2.2.2 交叉熵方法(CEM)在图像输入中的动作序列优化:通过迭代采样(精英比例top-k=10%)规划未来H步动作序列
2.3 分布式训练系统架构
2.3.1 IMPALA与SEED的工业级实现
2.3.1.1 演员-学习者(Actor-Learner)分离架构:RingBuffer实现异步数据传输,参数服务器(Parameter Server)的Ray分布式实现
2.3.1.2 V-trace偏置校正:重要性采样比率(ρ)截断(c=1.0)与Actor-Critic共享卷积编码器的梯度分离技术
2.3.2.1 Isaac Gym/Brax物理引擎:全GPU并行模拟(thousands of envs)与JAX JIT编译的端到端训练流水线
2.3.2.2 RLlib中的SampleCollector优化:通过压缩观测(LZ4/ObsCompression)与异步梯度计算减少CPU-GPU传输瓶颈
第二章 高样本效率与模型化方法
2.1 Off-Policy Actor-Critic与熵优化
2.1.1 TD3与SAC的工程细节
2.1.1.1 TD3双延迟更新:目标策略平滑(Target Policy Smoothing)通过向目标动作添加裁剪噪声(clip noise σ=0.2)减少过拟合
深度强化学习系统在样本效率方面的提升主要依赖于对价值函数估计偏差的抑制以及对环境交互数据的复用。双延迟深度确定性策略梯度算法通过维护两组独立的参数化动作价值估计,在策略优化阶段选取二者中较小者作为优化目标,从而有效缓解持续学习过程中的估计值膨胀现象。目标策略平滑机制在目标网络的动作选择环节引入服从截断正态分布的随机扰动,通过对动作空间施加有界噪声,平滑价值函数在动作维度上的变化率,降低策略对特定动作模式过度拟合的风险。
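为便于对照实现,下面给出目标值计算的符号化描述(其中σ、c、γ分别对应下方代码中的policy_noise=0.2、noise_clip=0.5、discount=0.99,仅作示意):

\tilde{a} = \mathrm{clip}\big(\mu_{\phi'}(s') + \epsilon,\; -a_{\max},\; a_{\max}\big), \qquad \epsilon \sim \mathrm{clip}\big(\mathcal{N}(0, \sigma^{2}),\; -c,\; c\big)

y = r + \gamma\,(1-d)\,\min_{i=1,2} Q_{\theta'_i}(s', \tilde{a})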
Python
"""
TD3 (Twin Delayed Deep Deterministic Policy Gradient) 实现脚本
涉及内容:双延迟更新、目标策略平滑、双评论家网络、软更新机制
使用方式:python td3.py --env HalfCheetah-v4 --seed 0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
import argparse
import os
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
super(Actor, self).__init__()
self.l1 = nn.Linear(state_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, action_dim)
self.max_action = max_action
def forward(self, state):
a = F.relu(self.l1(state))
a = F.relu(self.l2(a))
return self.max_action * torch.tanh(self.l3(a))
class Critic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(Critic, self).__init__()
# Twin Q architecture (Q1 and Q2)
self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, 1)
self.l4 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l5 = nn.Linear(hidden_dim, hidden_dim)
self.l6 = nn.Linear(hidden_dim, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
q2 = F.relu(self.l4(sa))
q2 = F.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2
def Q1(self, state, action):
sa = torch.cat([state, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
return q1
class ReplayBuffer:
def __init__(self, state_dim, action_dim, max_size=int(1e6)):
self.max_size = max_size
self.ptr = 0
self.size = 0
self.state = np.zeros((max_size, state_dim))
self.action = np.zeros((max_size, action_dim))
self.next_state = np.zeros((max_size, state_dim))
self.reward = np.zeros((max_size, 1))
self.not_done = np.zeros((max_size, 1))
def add(self, state, action, next_state, reward, done):
self.state[self.ptr] = state
self.action[self.ptr] = action
self.next_state[self.ptr] = next_state
self.reward[self.ptr] = reward
self.not_done[self.ptr] = 1. - done
self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample(self, batch_size):
ind = np.random.randint(0, self.size, size=batch_size)
return (
torch.FloatTensor(self.state[ind]),
torch.FloatTensor(self.action[ind]),
torch.FloatTensor(self.next_state[ind]),
torch.FloatTensor(self.reward[ind]),
torch.FloatTensor(self.not_done[ind])
)
class TD3:
def __init__(self, state_dim, action_dim, max_action, device,
discount=0.99, tau=0.005, policy_noise=0.2,
noise_clip=0.5, policy_freq=2):
self.device = device
self.actor = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
self.critic = Critic(state_dim, action_dim).to(device)
self.critic_target = Critic(state_dim, action_dim).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)
self.max_action = max_action
self.discount = discount
self.tau = tau
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.policy_freq = policy_freq
self.total_it = 0
def select_action(self, state, noise=0.1):
state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
action = self.actor(state).cpu().data.numpy().flatten()
if noise != 0:
action = (action + np.random.normal(0, noise, size=action.shape))
return np.clip(action, -self.max_action, self.max_action)
def update(self, replay_buffer, batch_size=256):
self.total_it += 1
state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
state = state.to(self.device)
action = action.to(self.device)
next_state = next_state.to(self.device)
reward = reward.to(self.device)
not_done = not_done.to(self.device)
with torch.no_grad():
# Target policy smoothing: add clipped noise to target action
noise = (torch.randn_like(action) * self.policy_noise).clamp(
-self.noise_clip, self.noise_clip
)
next_action = (self.actor_target(next_state) + noise).clamp(
-self.max_action, self.max_action
)
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + not_done * self.discount * target_Q
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Delayed policy updates
if self.total_it % self.policy_freq == 0:
actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Soft update
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def save(self, filename):
torch.save(self.critic.state_dict(), filename + "_critic")
torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
torch.save(self.actor.state_dict(), filename + "_actor")
torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="HalfCheetah-v4", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--start_timesteps", default=25000, type=int)
parser.add_argument("--eval_freq", default=5000, type=int)
parser.add_argument("--max_timesteps", default=1000000, type=int)
args = parser.parse_args()
file_name = f"TD3_{args.env}_{args.seed}"
print(f"Running TD3 on {args.env} with seed {args.seed}")
env = gym.make(args.env)
env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = TD3(state_dim, action_dim, max_action, device)
replay_buffer = ReplayBuffer(state_dim, action_dim)
state, _ = env.reset(seed=args.seed)
episode_reward = 0
episode_timesteps = 0
episode_num = 0
for t in range(args.max_timesteps):
episode_timesteps += 1
if t < args.start_timesteps:
action = env.action_space.sample()
else:
action = policy.select_action(np.array(state))
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
episode_reward += reward
replay_buffer.add(state, action, next_state, reward, float(done))
state = next_state
if t >= args.start_timesteps:
policy.update(replay_buffer)
if done:
print(f"Episode {episode_num+1} | Timesteps {t+1} | Reward {episode_reward:.2f}")
state, _ = env.reset()
episode_reward = 0
episode_timesteps = 0
episode_num += 1
if (t + 1) % args.eval_freq == 0:
            os.makedirs("./models", exist_ok=True)  # 确保模型保存目录存在
            policy.save(f"./models/{file_name}")
if __name__ == "__main__":
main()
2.1.1.2 SAC自动温度调整:通过Lagrange乘子法优化温度系数α,实现最大熵目标的动态平衡(PyTorch优化器嵌套优化)
软演员-评论家算法在最大化期望累积回报的同时引入策略熵正则化项,通过鼓励动作分布的随机性维持探索能力。温度系数控制熵奖励与任务奖励之间的权重比例,其自适应调整机制构建基于对偶优化的约束满足问题,利用Lagrange乘子法在策略迭代过程中动态求解最优温度参数,使策略熵维持在预设的目标值附近,避免手动调参导致的探索强度随训练阶段变化而失配。
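温度系数的对偶优化目标可写为如下形式(目标熵 \bar{\mathcal{H}} 在下方代码中取 -\dim(\mathcal{A}),对 \log\alpha 执行梯度下降即对应其中的 alpha_loss,仅作示意):

J(\alpha) = \mathbb{E}_{a_t \sim \pi_t}\big[-\alpha \log \pi_t(a_t \mid s_t) - \alpha \bar{\mathcal{H}}\big]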
Python
"""
SAC (Soft Actor-Critic) 实现脚本
涉及内容:自动温度调整、重参数化技巧、双Q网络、高斯策略
使用方式:python sac.py --env HalfCheetah-v4 --alpha auto
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from torch.distributions import Normal
import argparse
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(Actor, self).__init__()
self.l1 = nn.Linear(state_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.mean = nn.Linear(hidden_dim, action_dim)
self.log_std = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
a = F.relu(self.l1(state))
a = F.relu(self.l2(a))
mean = self.mean(a)
log_std = self.log_std(a)
log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
return mean, log_std
def sample(self, state):
mean, log_std = self.forward(state)
std = log_std.exp()
normal = Normal(mean, std)
x_t = normal.rsample() # Reparameterization trick
y_t = torch.tanh(x_t)
action = y_t
log_prob = normal.log_prob(x_t)
# Enforcing Action Bound
log_prob -= torch.log(1 - y_t.pow(2) + epsilon)
log_prob = log_prob.sum(1, keepdim=True)
mean = torch.tanh(mean)
return action, log_prob, mean
class Critic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(Critic, self).__init__()
# Twin Q networks
self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, 1)
self.l4 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l5 = nn.Linear(hidden_dim, hidden_dim)
self.l6 = nn.Linear(hidden_dim, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
q2 = F.relu(self.l4(sa))
q2 = F.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2
class ReplayBuffer:
def __init__(self, state_dim, action_dim, max_size=int(1e6)):
self.max_size = max_size
self.ptr = 0
self.size = 0
self.state = np.zeros((max_size, state_dim))
self.action = np.zeros((max_size, action_dim))
self.next_state = np.zeros((max_size, state_dim))
self.reward = np.zeros((max_size, 1))
self.not_done = np.zeros((max_size, 1))
def add(self, state, action, next_state, reward, done):
self.state[self.ptr] = state
self.action[self.ptr] = action
self.next_state[self.ptr] = next_state
self.reward[self.ptr] = reward
self.not_done[self.ptr] = 1. - done
self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample(self, batch_size):
ind = np.random.randint(0, self.size, size=batch_size)
return (
torch.FloatTensor(self.state[ind]),
torch.FloatTensor(self.action[ind]),
torch.FloatTensor(self.next_state[ind]),
torch.FloatTensor(self.reward[ind]),
torch.FloatTensor(self.not_done[ind])
)
class SAC:
def __init__(self, state_dim, action_dim, device,
discount=0.99, tau=0.005, alpha=0.2,
automatic_entropy_tuning=True, target_entropy=None):
self.device = device
self.discount = discount
self.tau = tau
self.actor = Actor(state_dim, action_dim).to(device)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
self.critic = Critic(state_dim, action_dim).to(device)
self.critic_target = Critic(state_dim, action_dim).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)
self.automatic_entropy_tuning = automatic_entropy_tuning
if self.automatic_entropy_tuning:
# Target entropy = -dim(A)
if target_entropy is None:
self.target_entropy = -action_dim
else:
self.target_entropy = target_entropy
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)
self.alpha = self.log_alpha.exp()
else:
self.alpha = alpha
def select_action(self, state, evaluate=False):
state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
if evaluate:
_, _, action = self.actor.sample(state)
else:
action, _, _ = self.actor.sample(state)
return action.detach().cpu().numpy()[0]
def update(self, replay_buffer, batch_size=256):
state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
state = state.to(self.device)
action = action.to(self.device)
next_state = next_state.to(self.device)
reward = reward.to(self.device)
not_done = not_done.to(self.device)
with torch.no_grad():
next_action, next_log_prob, _ = self.actor.sample(next_state)
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + not_done * self.discount * (target_Q - self.alpha * next_log_prob)
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
pi, log_pi, _ = self.actor.sample(state)
qf1_pi, qf2_pi = self.critic(state, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
if self.automatic_entropy_tuning:
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
self.alpha = self.log_alpha.exp()
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="HalfCheetah-v4")
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--alpha", default="auto", help="0.2 or auto")
args = parser.parse_args()
env = gym.make(args.env)
env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
automatic_entropy_tuning = (args.alpha == "auto")
alpha = 0.2 if not automatic_entropy_tuning else None
policy = SAC(state_dim, action_dim, device,
automatic_entropy_tuning=automatic_entropy_tuning,
alpha=alpha)
replay_buffer = ReplayBuffer(state_dim, action_dim)
state, _ = env.reset(seed=args.seed)
for t in range(1000000):
if t < 10000:
action = env.action_space.sample()
else:
action = policy.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.add(state, action, next_state, reward, float(done))
state = next_state
if t >= 10000:
policy.update(replay_buffer)
if done:
state, _ = env.reset()
if __name__ == "__main__":
main()
2.1.2 离线强化学习(Offline RL)部署
2.1.2.1 CQL(Conservative Q-Learning):在标准贝尔曼误差上添加CQL正则项,防止对未见状态-动作对的过估计
保守Q学习算法通过修改标准贝尔曼更新规则,在最小化时间差分误差的同时施加正则化约束,抑制对未在离线数据集中出现的状态-动作对的过高估计。该方法在动作价值函数的学习过程中引入保守性惩罚项,要求学习到的Q值低于真实值,从而在策略评估阶段避免对分布外动作的过度乐观预测,确保从静态数据集中提取的策略在实际环境部署时的稳定性与安全性。
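该思想对应的正则化目标可概括为如下形式(其中 \mu 为用于查询分布外动作的采样分布,权重 \alpha 对应下方代码中的 cql_weight,\hat{\mathcal{B}}^{\pi} 为经验贝尔曼算子,仅作示意):

\min_{Q}\;\; \alpha\Big(\mathbb{E}_{s\sim\mathcal{D},\,a\sim\mu}\big[Q(s,a)\big] - \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[Q(s,a)\big]\Big) + \tfrac{1}{2}\,\mathbb{E}_{(s,a,s')\sim\mathcal{D}}\Big[\big(Q(s,a) - \hat{\mathcal{B}}^{\pi}\bar{Q}(s,a)\big)^{2}\Big]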
Python
"""
CQL (Conservative Q-Learning) 实现脚本
涉及内容:保守正则项、CQL(H)变体、离线数据集加载、SAC基础架构修改
使用方式:python cql.py --dataset_path d4rl_hopper-medium-v2.npz --cql_weight 5.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import argparse
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(Actor, self).__init__()
self.l1 = nn.Linear(state_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.mean = nn.Linear(hidden_dim, action_dim)
self.log_std = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
a = F.relu(self.l1(state))
a = F.relu(self.l2(a))
mean = self.mean(a)
log_std = self.log_std(a)
log_std = torch.clamp(log_std, min=-20, max=2)
return mean, log_std
def sample(self, state):
mean, log_std = self.forward(state)
std = log_std.exp()
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()
action = torch.tanh(x_t)
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)  # 对动作维度求和,得到联合对数概率
        return action, log_prob, mean
class Critic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(Critic, self).__init__()
self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, 1)
self.l4 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l5 = nn.Linear(hidden_dim, hidden_dim)
self.l6 = nn.Linear(hidden_dim, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
q2 = F.relu(self.l4(sa))
q2 = F.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2
class OfflineDataset(Dataset):
def __init__(self, data_path):
data = np.load(data_path)
self.states = torch.FloatTensor(data['observations'])
self.actions = torch.FloatTensor(data['actions'])
self.next_states = torch.FloatTensor(data['next_observations'])
self.rewards = torch.FloatTensor(data['rewards']).unsqueeze(1)
self.dones = torch.FloatTensor(data['terminals']).unsqueeze(1)
self.size = len(self.states)
def __len__(self):
return self.size
def __getitem__(self, idx):
return (self.states[idx], self.actions[idx], self.next_states[idx],
self.rewards[idx], self.dones[idx])
class CQL:
def __init__(self, state_dim, action_dim, device, cql_weight=5.0,
target_action_gap=10, temp=1.0):
self.device = device
self.cql_weight = cql_weight
self.target_action_gap = target_action_gap
self.temp = temp
self.actor = Actor(state_dim, action_dim).to(device)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
self.critic = Critic(state_dim, action_dim).to(device)
self.critic_target = Critic(state_dim, action_dim).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)
self.discount = 0.99
self.tau = 0.005
    def cql_loss(self, q_values, actions, dataset_actions):
        # 离散动作版CQL(H)惩罚的参考实现;本脚本处理连续动作,对应惩罚在update()中内联计算,未调用此函数
        cql1_loss = torch.logsumexp(q_values / self.temp, dim=1, keepdim=True) * self.temp
        cql1_loss = cql1_loss - q_values.gather(1, actions.long())
        return cql1_loss.mean()
def update(self, batch):
state, action, next_state, reward, done = batch
state = state.to(self.device)
action = action.to(self.device)
next_state = next_state.to(self.device)
reward = reward.to(self.device)
done = done.to(self.device)
with torch.no_grad():
next_action, next_log_pi, _ = self.actor.sample(next_state)
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2) - 0.2 * next_log_pi  # 固定熵系数0.2
            target_Q = reward + (1 - done) * self.discount * target_Q
current_Q1, current_Q2 = self.critic(state, action)
bellman_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
# CQL regularization
# Sample random actions for conservative estimation
batch_size = state.shape[0]
random_actions = torch.FloatTensor(batch_size, action.shape[1]).uniform_(-1, 1).to(self.device)
random_q1, random_q2 = self.critic(state, random_actions)
# Sample actions from current policy
current_actions, current_log_pi, _ = self.actor.sample(state)
current_q1, current_q2 = self.critic(state, current_actions)
cat_q1 = torch.cat([random_q1, current_q1], dim=1)
cat_q2 = torch.cat([random_q2, current_q2], dim=1)
cql1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1).mean() * self.temp
cql2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1).mean() * self.temp
cql1_loss = cql1_loss - current_Q1.mean()
cql2_loss = cql2_loss - current_Q2.mean()
conservative_loss = cql1_loss + cql2_loss
total_critic_loss = bellman_loss + self.cql_weight * conservative_loss
self.critic_optimizer.zero_grad()
total_critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
self.critic_optimizer.step()
# Actor update
new_actions, log_pi, _ = self.actor.sample(state)
qf1_new, qf2_new = self.critic(state, new_actions)
min_qf = torch.min(qf1_new, qf2_new)
actor_loss = (0.2 * log_pi - min_qf).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Soft update
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", required=True)
parser.add_argument("--cql_weight", default=5.0, type=float)
parser.add_argument("--seed", default=0, type=int)
args = parser.parse_args()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = OfflineDataset(args.dataset_path)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
state_dim = dataset.states.shape[1]
action_dim = dataset.actions.shape[1]
policy = CQL(state_dim, action_dim, device, cql_weight=args.cql_weight)
for epoch in range(1000):
for batch in dataloader:
policy.update(batch)
if epoch % 10 == 0:
print(f"Epoch {epoch} completed")
if __name__ == "__main__":
main()
2.1.2.2 IQL(Implicit Q-Learning):以期望分位数回归(expectile regression)替代max操作,避免OOD动作查询,支持离线到在线微调
隐式Q学习算法摒弃了在动作价值估计中显式取最大值的操作,转而通过期望分位数回归(expectile regression)以较高的expectile水平逼近数据支撑范围内可达到的动作价值,从而隐式地近似max操作。该方法将策略评估与策略提取解耦:先在静态数据上学习状态价值与动作价值函数,再通过优势加权回归提取策略,避免了直接查询分布外动作价值导致的估计误差,为离线到在线的连续学习提供了稳定的桥梁。
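两步训练的核心目标可写为如下形式(\tau 对应下方代码中的 expectile,\beta 为优势缩放系数;下方实现以加权行为克隆的平方误差近似对数似然项,仅作示意):

L_V(\psi) = \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[L_2^{\tau}\big(Q_{\hat{\theta}}(s,a) - V_{\psi}(s)\big)\big], \qquad L_2^{\tau}(u) = \big|\tau - \mathbb{1}(u<0)\big|\,u^{2}

\max_{\phi}\;\; \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[\exp\!\big(\beta\,(Q_{\hat{\theta}}(s,a) - V_{\psi}(s))\big)\,\log \pi_{\phi}(a\mid s)\big]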
Python
"""
IQL (Implicit Q-Learning) 实现脚本
涉及内容:期望分位数回归、优势加权行为克隆、两步策略提取、离线到在线微调
使用方式:python iql.py --dataset_path d4rl_walker-expert-v2.npz --expectile 0.7 --beta 3.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import argparse
class Value(nn.Module):
def __init__(self, state_dim, hidden_dim=256):
super(Value, self).__init__()
self.l1 = nn.Linear(state_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, 1)
def forward(self, state):
v = F.relu(self.l1(state))
v = F.relu(self.l2(v))
v = self.l3(v)
return v
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(QNetwork, self).__init__()
self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
q = F.relu(self.l1(sa))
q = F.relu(self.l2(q))
q = self.l3(q)
return q
class GaussianPolicy(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(GaussianPolicy, self).__init__()
self.l1 = nn.Linear(state_dim, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.mean = nn.Linear(hidden_dim, action_dim)
self.log_std = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
a = F.relu(self.l1(state))
a = F.relu(self.l2(a))
mean = torch.tanh(self.mean(a))
log_std = self.log_std(a)
log_std = torch.clamp(log_std, -20, 2)
return mean, log_std
def sample(self, state):
mean, log_std = self.forward(state)
std = log_std.exp()
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()
action = torch.tanh(x_t)
return action, mean
class OfflineDataset(Dataset):
def __init__(self, data_path):
data = np.load(data_path)
self.states = torch.FloatTensor(data['observations'])
self.actions = torch.FloatTensor(data['actions'])
self.next_states = torch.FloatTensor(data['next_observations'])
self.rewards = torch.FloatTensor(data['rewards']).unsqueeze(1)
self.dones = torch.FloatTensor(data['terminals']).unsqueeze(1)
def __len__(self):
return len(self.states)
def __getitem__(self, idx):
return (self.states[idx], self.actions[idx], self.next_states[idx],
self.rewards[idx], self.dones[idx])
class IQL:
def __init__(self, state_dim, action_dim, device, expectile=0.7, beta=3.0, tau=0.005):
self.device = device
self.expectile = expectile
self.beta = beta
self.tau = tau
self.value_net = Value(state_dim).to(device)
self.v_optimizer = optim.Adam(self.value_net.parameters(), lr=3e-4)
self.q_net = QNetwork(state_dim, action_dim).to(device)
self.q_target = QNetwork(state_dim, action_dim).to(device)
self.q_target.load_state_dict(self.q_net.state_dict())
self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=3e-4)
self.policy = GaussianPolicy(state_dim, action_dim).to(device)
self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=3e-4)
    def expectile_loss(self, diff, expectile):
        # 非对称平方损失:diff>0时权重为expectile,否则为1-expectile
        weight = torch.where(diff > 0, torch.full_like(diff, expectile), torch.full_like(diff, 1 - expectile))
        return weight * (diff ** 2)
def update(self, batch):
state, action, next_state, reward, done = batch
state = state.to(self.device)
action = action.to(self.device)
next_state = next_state.to(self.device)
reward = reward.to(self.device)
done = done.to(self.device)
with torch.no_grad():
next_v = self.value_net(next_state)
target_q = reward + (1 - done) * 0.99 * next_v
current_q = self.q_net(state, action)
q_loss = F.mse_loss(current_q, target_q)
self.q_optimizer.zero_grad()
q_loss.backward()
self.q_optimizer.step()
# Update value network with expectile regression
with torch.no_grad():
q_pred = self.q_target(state, action)
v_pred = self.value_net(state)
v_loss = self.expectile_loss(q_pred - v_pred, self.expectile).mean()
self.v_optimizer.zero_grad()
v_loss.backward()
self.v_optimizer.step()
        # Policy extraction via advantage-weighted regression (AWR)
        with torch.no_grad():
            v = self.value_net(state)
            q = self.q_net(state, action)
            adv = q - v
            exp_adv = torch.exp(adv * self.beta)
            exp_adv = torch.clamp(exp_adv, max=100.0)
        action_pred, mean = self.policy.sample(state)
        # 以优势指数权重加权的行为克隆损失:最小化 exp_adv * ||π(s) - a||^2
        policy_loss = (exp_adv * ((action_pred - action) ** 2).sum(dim=1, keepdim=True)).mean()
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
# Soft update target Q
for param, target_param in zip(self.q_net.parameters(), self.q_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", required=True)
parser.add_argument("--expectile", default=0.7, type=float)
parser.add_argument("--beta", default=3.0, type=float)
parser.add_argument("--seed", default=0, type=int)
args = parser.parse_args()
torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = OfflineDataset(args.dataset_path)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
state_dim = dataset.states.shape[1]
action_dim = dataset.actions.shape[1]
policy = IQL(state_dim, action_dim, device, expectile=args.expectile, beta=args.beta)
for epoch in range(1000):
for batch in dataloader:
policy.update(batch)
if epoch % 10 == 0:
print(f"Epoch {epoch}")
if __name__ == "__main__":
main()
2.2 模型基础强化学习(Model-Based RL)
2.2.1 环境动力学模型学习
2.2.1.1 PETS概率性模型:基于混合密度网络(MDN)的集成模型(Ensemble)与轨迹采样(Trajectory Sampling)对抗模型误差
概率集成轨迹采样方法(PETS)通过构建多个独立初始化的动力学模型预测未来状态转移的不确定性分布,每个模型参数化一个混合密度网络以捕捉环境随机性。集成方法通过各成员预测之间的分歧量化认知不确定性,在规划阶段以轨迹采样(粒子传播)的方式将该不确定性注入多步前瞻预测,利用模型预测误差的统计特性指导探索行为,防止规划器过度利用单一模型的系统性偏差。
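单个概率模型通常按高斯负对数似然训练,集成间的分歧再经轨迹采样(粒子传播)注入规划过程(仅作示意):

\mathcal{L}(\theta) = \sum_{(s,a,s')\in\mathcal{D}} \big[\mu_{\theta}(s,a) - s'\big]^{\top} \Sigma_{\theta}^{-1}(s,a)\,\big[\mu_{\theta}(s,a) - s'\big] + \log\det\Sigma_{\theta}(s,a)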
Python
"""
PETS (Probabilistic Ensemble with Trajectory Sampling) 实现脚本
涉及内容:概率集成模型、高斯混合网络、MPC规划、粒子传播、模型不确定性估计
使用方式:python pets.py --env Pendulum-v1 --n_ensemble 5 --plan_horizon 20
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from typing import Tuple
import argparse
class GaussianMixtureNetwork(nn.Module):
def __init__(self, input_dim, output_dim, n_gaussians=1, hidden_dim=200):
super().__init__()
self.n_gaussians = n_gaussians
self.output_dim = output_dim
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, hidden_dim)
# Output: mean, log_std for each Gaussian
self.mean = nn.Linear(hidden_dim, output_dim * n_gaussians)
self.log_std = nn.Linear(hidden_dim, output_dim * n_gaussians)
self.logits = nn.Linear(hidden_dim, n_gaussians) # Mixture weights
def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
mean = self.mean(x).view(-1, self.n_gaussians, self.output_dim)
log_std = torch.clamp(self.log_std(x), min=-2, max=2)
std = torch.exp(log_std).view(-1, self.n_gaussians, self.output_dim)
logits = self.logits(x)
return mean, std, logits
def sample(self, x) -> torch.Tensor:
mean, std, logits = self.forward(x)
        # 按混合权重对应的Categorical分布采样所属高斯分量
probs = F.softmax(logits, dim=-1)
component = torch.multinomial(probs, 1).squeeze(-1)
batch_idx = torch.arange(x.shape[0])
selected_mean = mean[batch_idx, component]
selected_std = std[batch_idx, component]
sample = selected_mean + selected_std * torch.randn_like(selected_mean)
return sample
class EnsembleDynamicsModel(nn.Module):
def __init__(self, state_dim, action_dim, n_ensemble=5):
super().__init__()
self.n_ensemble = n_ensemble
self.state_dim = state_dim
self.action_dim = action_dim
# Create ensemble of models
self.models = nn.ModuleList([
GaussianMixtureNetwork(state_dim + action_dim, state_dim + 1) # next_state + reward
for _ in range(n_ensemble)
])
def forward(self, state, action, model_idx=None):
x = torch.cat([state, action], dim=-1)
if model_idx is not None:
return self.models[model_idx](x)
# Return predictions from all models
means, stds, logits = [], [], []
for model in self.models:
m, s, l = model(x)
means.append(m)
stds.append(s)
logits.append(l)
return torch.stack(means), torch.stack(stds), torch.stack(logits)
def predict(self, state, action):
"""Predict next state and reward with uncertainty"""
x = torch.cat([state, action], dim=-1)
predictions = []
for model in self.models:
mean, std, _ = model(x)
# Sample from the Gaussian
sample = mean[:, 0] + std[:, 0] * torch.randn_like(mean[:, 0])
predictions.append(sample)
predictions = torch.stack(predictions) # [n_ensemble, batch, state_dim+1]
mean_pred = predictions.mean(dim=0)
std_pred = predictions.std(dim=0)
return mean_pred, std_pred
class MPCPlanner:
def __init__(self, model, action_dim, horizon=20, n_samples=1000, top_k=100):
self.model = model
self.action_dim = action_dim
self.horizon = horizon
self.n_samples = n_samples
self.top_k = top_k
self.action_mean = np.zeros(action_dim)
self.action_std = np.ones(action_dim)
    def plan(self, state):
        # 交叉熵方法(CEM):对整段H步动作序列的分布进行迭代优化
        state_batch = torch.FloatTensor(state).unsqueeze(0).repeat(self.n_samples, 1)
        mean = torch.zeros(self.horizon, self.action_dim)
        std = torch.ones(self.horizon, self.action_dim)
        for _ in range(5):  # CEM iterations
            # 采样候选动作序列 [n_samples, horizon, action_dim]
            actions = mean.unsqueeze(0) + std.unsqueeze(0) * torch.randn(
                self.n_samples, self.horizon, self.action_dim)
            # 在学习到的集成模型中rollout并累计折扣回报
            returns = torch.zeros(self.n_samples)
            current_state = state_batch.clone()
            for t in range(self.horizon):
                with torch.no_grad():
                    pred, _ = self.model.predict(current_state, actions[:, t])
                current_state = current_state + pred[:, :-1]  # 预测的状态增量
                returns += (0.95 ** t) * pred[:, -1]          # 预测的奖励
            # 选取精英序列并重新拟合分布参数
            elite_indices = torch.topk(returns, self.top_k).indices
            elite_actions = actions[elite_indices]
            mean = elite_actions.mean(dim=0)
            std = elite_actions.std(dim=0) + 0.01
        return mean[0].numpy()  # 执行优化后序列的第一个动作
class PETSAgent:
def __init__(self, state_dim, action_dim, n_ensemble=5, horizon=20):
self.model = EnsembleDynamicsModel(state_dim, action_dim, n_ensemble)
self.planner = MPCPlanner(self.model, action_dim, horizon)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
def train(self, buffer, batch_size=256, epochs=100):
for epoch in range(epochs):
# Sample from buffer (assumed format: [state, action, next_state, reward])
indices = np.random.choice(len(buffer), batch_size)
states = torch.FloatTensor(np.array([buffer[i][0] for i in indices]))
actions = torch.FloatTensor(np.array([buffer[i][1] for i in indices]))
next_states = torch.FloatTensor(np.array([buffer[i][2] for i in indices]))
rewards = torch.FloatTensor(np.array([buffer[i][3] for i in indices])).unsqueeze(1)
targets = torch.cat([next_states - states, rewards], dim=1)
loss = 0
for i, model in enumerate(self.model.models):
mean, std, logits = model(torch.cat([states, actions], dim=1))
# Negative log-likelihood
log_prob = -0.5 * (((targets.unsqueeze(1) - mean) / std) ** 2 +
torch.log(std ** 2))
loss += -log_prob.mean()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def act(self, state):
return self.planner.plan(state)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="Pendulum-v1")
parser.add_argument("--n_ensemble", default=5, type=int)
parser.add_argument("--plan_horizon", default=20, type=int)
args = parser.parse_args()
env = gym.make(args.env)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
agent = PETSAgent(state_dim, action_dim, args.n_ensemble, args.plan_horizon)
# Initial random data collection
buffer = []
state, _ = env.reset()
for _ in range(1000):
action = env.action_space.sample()
next_state, reward, terminated, truncated, _ = env.step(action)
buffer.append([state, action, next_state, reward])
state = next_state if not (terminated or truncated) else env.reset()[0]
# Train model
agent.train(buffer)
# MPC control loop
state, _ = env.reset()
for step in range(1000):
action = agent.act(state)
next_state, reward, terminated, truncated, _ = env.step(action)
buffer.append([state, action, next_state, reward])
state = next_state
if terminated or truncated:
state, _ = env.reset()
if __name__ == "__main__":
main()
2.2.1.2 MBPO(Model-Based Policy Optimization):基于模型的虚拟rollout与真实数据按比例混合采样(real_ratio=0.05,即约95%训练样本来自模型生成数据)的采样器实现
基于模型的策略优化算法通过训练深度神经网络近似环境转移动力学,利用学习到的模型生成虚拟轨迹以扩展训练数据集。该方法采用短视域模型预测与真实数据混合训练的策略,通过控制虚拟轨迹与真实交互的比例,在保持样本效率的同时抑制模型复合误差的累积。集成模型的使用进一步通过预测方差过滤不可靠的合成样本,确保策略优化过程中使用的状态转移估计满足精度要求。
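下面的极简片段说明按real_ratio拆分混合批次的算术关系(batch_size=256与下方实现中的默认设置一致,数值仅作示意):
Python
# 按 real_ratio 计算单个混合批次中真实样本与模型生成样本的数量
batch_size, real_ratio = 256, 0.05
n_real = int(batch_size * real_ratio)   # 12 条来自真实经验池(约5%)
n_model = batch_size - n_real           # 244 条来自模型生成的虚拟经验池(约95%)
print(n_real, n_model)                  # 12 244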
Python
"""
MBPO (Model-Based Policy Optimization) 实现脚本
涉及内容:集成动力学模型、模型生成rollout、虚拟与真实数据混合、SAC智能体集成
使用方式:python mbpo.py --env HalfCheetah-v4 --real_ratio 0.05 --rollout_length 1
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from collections import deque
import argparse
class EnsembleModel(nn.Module):
def __init__(self, state_dim, action_dim, hidden_size=200, num_networks=7):
super().__init__()
self.num_networks = num_networks
self.state_dim = state_dim
self.action_dim = action_dim
self.ensemble = nn.ModuleList([
nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, state_dim + 1) # next_state_delta + reward
) for _ in range(num_networks)
])
self.max_logvar = nn.Parameter(torch.ones(1, state_dim + 1) / 2.0)
self.min_logvar = nn.Parameter(-torch.ones(1, state_dim + 1) * 10)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
outputs = [model(x) for model in self.ensemble]
outputs = torch.stack(outputs) # [num_networks, batch, state_dim+1]
mean = outputs.mean(dim=0)
log_var = torch.clamp(outputs.var(dim=0), min=1e-8).log()
return mean, log_var
def loss(self, state, action, next_state, reward):
target = torch.cat([next_state - state, reward.unsqueeze(-1)], dim=-1)
mean, log_var = self.forward(state, action)
inv_var = torch.exp(-log_var)
loss = ((mean - target.detach()) ** 2) * inv_var + log_var
return loss.mean()
class ModelBuffer:
def __init__(self, capacity=int(1e6)):
self.buffer = deque(maxlen=capacity)
def add(self, state, action, next_state, reward, done):
self.buffer.append([state, action, next_state, reward, done])
def sample(self, batch_size):
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
batch = [self.buffer[i] for i in indices]
states = torch.FloatTensor(np.array([x[0] for x in batch]))
actions = torch.FloatTensor(np.array([x[1] for x in batch]))
next_states = torch.FloatTensor(np.array([x[2] for x in batch]))
rewards = torch.FloatTensor(np.array([x[3] for x in batch]))
dones = torch.FloatTensor(np.array([x[4] for x in batch]))
return states, actions, next_states, rewards, dones
def __len__(self):
return len(self.buffer)
class SAC:
# Simplified SAC implementation for MBPO integration
def __init__(self, state_dim, action_dim, device):
self.device = device
self.actor = nn.Sequential(
nn.Linear(state_dim, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, action_dim), nn.Tanh()
).to(device)
self.critic1 = nn.Sequential(
nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, 1)
).to(device)
self.critic2 = nn.Sequential(
nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, 1)
).to(device)
self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
self.critic_opt = torch.optim.Adam(
list(self.critic1.parameters()) + list(self.critic2.parameters()), lr=3e-4
)
def select_action(self, state, noise=0.1):
state = torch.FloatTensor(state).to(self.device)
action = self.actor(state).detach().cpu().numpy()
if noise > 0:
action += np.random.normal(0, noise, size=action.shape)
return np.clip(action, -1, 1)
def update(self, buffer, batch_size=256, real_ratio=0.05):
# Sample from both real and model buffer
n_real = int(batch_size * real_ratio)
n_model = batch_size - n_real
if len(buffer) < n_model:
return
# Assume buffer provides both real and model data
# In practice, MBPO maintains separate buffers
states, actions, next_states, rewards, dones = buffer.sample(batch_size)
states = states.to(self.device)
actions = actions.to(self.device)
next_states = next_states.to(self.device)
rewards = rewards.to(self.device)
dones = dones.to(self.device)
# Standard SAC update
with torch.no_grad():
next_actions = self.actor(next_states)
target_q = torch.min(
self.critic1(torch.cat([next_states, next_actions], 1)),
self.critic2(torch.cat([next_states, next_actions], 1))
)
target_q = rewards.unsqueeze(1) + 0.99 * (1 - dones.unsqueeze(1)) * target_q
current_q1 = self.critic1(torch.cat([states, actions], 1))
current_q2 = self.critic2(torch.cat([states, actions], 1))
critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
self.critic_opt.zero_grad()
critic_loss.backward()
self.critic_opt.step()
actor_loss = -self.critic1(torch.cat([states, self.actor(states)], 1)).mean()
self.actor_opt.zero_grad()
actor_loss.backward()
self.actor_opt.step()
class MBPO:
def __init__(self, state_dim, action_dim, device,
rollout_length=1, rollout_batch_size=50000,
real_ratio=0.05, ensemble_size=7):
self.device = device
self.model = EnsembleModel(state_dim, action_dim, num_networks=ensemble_size).to(device)
self.model_opt = torch.optim.Adam(self.model.parameters(), lr=1e-3)
self.policy = SAC(state_dim, action_dim, device)
self.real_buffer = ModelBuffer()
self.model_buffer = ModelBuffer()
self.rollout_length = rollout_length
self.rollout_batch_size = rollout_batch_size
self.real_ratio = real_ratio
self.ensemble_size = ensemble_size
def train_model(self, batch_size=256, epochs=100):
for epoch in range(epochs):
if len(self.real_buffer) < batch_size:
continue
states, actions, next_states, rewards, _ = self.real_buffer.sample(batch_size)
states = states.to(self.device)
actions = actions.to(self.device)
next_states = next_states.to(self.device)
rewards = rewards.to(self.device)
loss = self.model.loss(states, actions, next_states, rewards)
self.model_opt.zero_grad()
loss.backward()
self.model_opt.step()
def generate_rollouts(self):
# Sample initial states from real buffer
init_states, _, _, _, _ = self.real_buffer.sample(self.rollout_batch_size)
init_states = init_states.to(self.device)
states = init_states
for _ in range(self.rollout_length):
actions = torch.FloatTensor(self.policy.select_action(states.cpu().numpy(), noise=0.2)).to(self.device)
with torch.no_grad():
mean, log_var = self.model(states, actions)
std = torch.exp(0.5 * log_var)
next_state_delta = mean[:, :-1] + std[:, :-1] * torch.randn_like(mean[:, :-1])
rewards = mean[:, -1] + std[:, -1] * torch.randn_like(mean[:, -1])
next_states = states + next_state_delta
# Add to model buffer
for i in range(len(states)):
self.model_buffer.add(
states[i].cpu().numpy(),
actions[i].cpu().numpy(),
next_states[i].cpu().numpy(),
rewards[i].cpu().numpy(),
False
)
states = next_states
def update_policy(self):
# Combine real and model buffer
total_size = min(len(self.real_buffer) + len(self.model_buffer), 256)
n_real = int(total_size * self.real_ratio)
n_model = total_size - n_real
if len(self.real_buffer) < n_real or len(self.model_buffer) < n_model:
return
real_samples = self.real_buffer.sample(n_real)
model_samples = self.model_buffer.sample(n_model)
# Merge samples
states = torch.cat([real_samples[0], model_samples[0]], 0).to(self.device)
actions = torch.cat([real_samples[1], model_samples[1]], 0).to(self.device)
next_states = torch.cat([real_samples[2], model_samples[2]], 0).to(self.device)
rewards = torch.cat([real_samples[3], model_samples[3]], 0).to(self.device)
dones = torch.cat([real_samples[4], model_samples[4]], 0).to(self.device)
# Update SAC with mixed batch
        class MixedBuffer:
            def __init__(self, s, a, ns, r, d):
                self.batch = (s, a, ns, r, d)
            def __len__(self):
                # SAC.update中会调用len(buffer),因此需提供批次大小
                return self.batch[0].shape[0]
            def sample(self, size):
                return self.batch
mixed = MixedBuffer(states, actions, next_states, rewards, dones)
self.policy.update(mixed, batch_size=len(states), real_ratio=1.0) # Already mixed
def run(self, env, total_steps=300000):
state, _ = env.reset()
episode_step = 0
for step in range(total_steps):
action = self.policy.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
self.real_buffer.add(state, action, next_state, reward, float(done))
state = next_state
episode_step += 1
if done:
state, _ = env.reset()
episode_step = 0
# Model training every 250 steps
if step > 5000 and step % 250 == 0:
self.train_model(epochs=5)
self.generate_rollouts()
if step > 5000:
self.update_policy()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="HalfCheetah-v4")
parser.add_argument("--real_ratio", default=0.05, type=float)
parser.add_argument("--rollout_length", default=1, type=int)
args = parser.parse_args()
env = gym.make(args.env)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = MBPO(state_dim, action_dim, device,
rollout_length=args.rollout_length,
real_ratio=args.real_ratio)
agent.run(env)
if __name__ == "__main__":
main()
2.2.2 潜在空间规划(Latent Space Planning)
2.2.2.1 PlaNet/Dreamer架构:RSSM(Recurrent State-Space Model)的确定性路径与随机先验分解,KL散度损失加权(β=1.0)
循环状态空间模型将环境动态分解为确定性的循环路径与带先验分布的随机潜变量两部分,在高维观测空间中学习紧凑的潜在表征。该架构利用循环神经网络捕获时间依赖性,同时通过变分推断平衡潜在状态分布与观测重建之间的关系;表示学习、动力学预测与奖励预测统一在单一变分目标下,而策略与价值函数则在学到的潜在空间中通过想象rollout进行优化,使智能体无需回到原始高维图像空间即可完成规划。
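世界模型部分的变分目标可写为重建项、奖励预测项与KL正则项之和(\beta 对应标题中的KL散度加权,取1.0,仅作示意):

\mathcal{L}_{\text{model}} = \mathbb{E}_{q}\Big[\sum_{t}\big(-\ln p(o_t \mid h_t, s_t) - \ln p(r_t \mid h_t, s_t) + \beta\,\mathrm{KL}\big(q(s_t \mid h_t, o_t)\,\|\,p(s_t \mid h_t)\big)\big)\Big]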
Python
"""
PlaNet/Dreamer 实现脚本
涉及内容:RSSM循环状态空间模型、变分推断、潜在空间规划、图像重建、潜在空间中的Actor-Critic
使用方式:python dreamer.py --env Walker2d-v4 --image_size 64 --latent_dim 30
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from typing import Dict, List, Tuple
import argparse
class RSSM(nn.Module):
    def __init__(self, stochastic_dim=30, deterministic_dim=200, action_dim=1, embed_dim=200):
        super().__init__()
        self.stochastic_dim = stochastic_dim
        self.deterministic_dim = deterministic_dim
        # Recurrent cell:输入为上一时刻随机状态与动作的拼接
        self.gru = nn.GRUCell(stochastic_dim + action_dim, deterministic_dim)
        # Prior p(s_t | h_t)
        self.fc_prior = nn.Linear(deterministic_dim, stochastic_dim * 2)
        # Posterior q(s_t | h_t, embed(o_t))
        self.fc_posterior = nn.Linear(deterministic_dim + embed_dim, stochastic_dim * 2)
    def init_state(self, batch_size, device):
        h = torch.zeros(batch_size, self.deterministic_dim).to(device)
        s = torch.zeros(batch_size, self.stochastic_dim).to(device)
        return h, s
    def imagine_step(self, h_prev, s_prev, action):
        # Prior:仅沿确定性路径预测下一随机状态
        x = torch.cat([s_prev, action], dim=-1)
        h = self.gru(x, h_prev)
        prior_mean, prior_std_logit = torch.chunk(self.fc_prior(h), 2, dim=-1)
        prior_std = F.softplus(prior_std_logit) + 0.1
        s = prior_mean + prior_std * torch.randn_like(prior_mean)
        return h, s, prior_mean, prior_std
    def observe_step(self, h_prev, s_prev, action, embedding):
        # Posterior:在确定性路径基础上融合观测嵌入
        x = torch.cat([s_prev, action], dim=-1)
        h = self.gru(x, h_prev)
        posterior_input = torch.cat([h, embedding], dim=-1)
        post_mean, post_std_logit = torch.chunk(self.fc_posterior(posterior_input), 2, dim=-1)
        post_std = F.softplus(post_std_logit) + 0.1
        s = post_mean + post_std * torch.randn_like(post_mean)
        return h, s, post_mean, post_std
class ObservationEncoder(nn.Module):
def __init__(self, image_channels=3, depth=32, latent_dim=200):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(image_channels, 1*depth, 4, stride=2),
nn.ReLU(),
nn.Conv2d(1*depth, 2*depth, 4, stride=2),
nn.ReLU(),
nn.Conv2d(2*depth, 4*depth, 4, stride=2),
nn.ReLU(),
nn.Conv2d(4*depth, 8*depth, 4, stride=2),
nn.ReLU(),
nn.Flatten()
)
self.out_dim = 8*depth * 2 * 2 # Assuming 64x64 input -> 2x2
self.fc = nn.Linear(self.out_dim, latent_dim)
def forward(self, x):
x = self.net(x)
return self.fc(x)
class ObservationDecoder(nn.Module):
    def __init__(self, state_dim, depth=32, image_channels=3):
        super().__init__()
        self.depth = depth
        self.fc = nn.Linear(state_dim, 32 * depth)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(32 * depth, 4 * depth, 5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(4 * depth, 2 * depth, 5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * depth, 1 * depth, 6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(1 * depth, image_channels, 6, stride=2),
        )
    def forward(self, state):
        # 以1x1空间种子逐层反卷积,重建回64x64图像
        x = self.fc(state).view(-1, 32 * self.depth, 1, 1)
        return self.net(x)
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=200):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Tanh()
)
def forward(self, state):
return self.net(state)
class ValueModel(nn.Module):
def __init__(self, state_dim, hidden_dim=200):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state):
return self.net(state)
class Dreamer:
def __init__(self, image_channels, action_dim, device,
stochastic_dim=30, deterministic_dim=200):
self.device = device
self.action_dim = action_dim
self.encoder = ObservationEncoder(image_channels, latent_dim=200).to(device)
        self.rssm = RSSM(stochastic_dim, deterministic_dim, action_dim=action_dim, embed_dim=200).to(device)
self.decoder = ObservationDecoder(stochastic_dim + deterministic_dim).to(device)
self.reward_model = nn.Sequential(
nn.Linear(stochastic_dim + deterministic_dim, 200),
nn.ReLU(),
nn.Linear(200, 1)
).to(device)
self.actor = Actor(stochastic_dim + deterministic_dim, action_dim).to(device)
self.value_model = ValueModel(stochastic_dim + deterministic_dim).to(device)
self.models = [self.encoder, self.rssm, self.decoder, self.reward_model]
self.behavior = [self.actor, self.value_model]
self.model_opt = optim.Adam(
list(self.encoder.parameters()) +
list(self.rssm.parameters()) +
list(self.decoder.parameters()) +
list(self.reward_model.parameters()),
lr=1e-3
)
self.actor_opt = optim.Adam(self.actor.parameters(), lr=8e-5)
self.value_opt = optim.Adam(self.value_model.parameters(), lr=8e-5)
def preprocess(self, obs):
# Normalize images to [-0.5, 0.5]
return obs / 255.0 - 0.5
def update_world_model(self, observations, actions, rewards):
# observations: [batch, seq, channels, height, width]
# actions: [batch, seq, action_dim]
batch_size, seq_len = observations.shape[:2]
# Embed observations
flat_obs = observations.reshape(-1, *observations.shape[2:])
embeddings = self.encoder(flat_obs).view(batch_size, seq_len, -1)
# Initialize state
h, s = self.rssm.init_state(batch_size, self.device)
kl_loss = 0
recon_loss = 0
reward_loss = 0
states = []
for t in range(seq_len):
action = actions[:, t]
emb = embeddings[:, t]
# Observe step
h, s, post_mean, post_std = self.rssm.observe_step(h, s, action, emb)
# Compute KL divergence between posterior and prior
prior_mean, prior_std_logit = torch.chunk(self.rssm.fc_prior(h), 2, dim=-1)
prior_std = F.softplus(prior_std_logit) + 0.1
kl = torch.distributions.kl_divergence(
torch.distributions.Normal(post_mean, post_std),
torch.distributions.Normal(prior_mean, prior_std)
).sum(dim=-1)
kl_loss += kl.mean()
# Reconstruction
state = torch.cat([h, s], dim=-1)
recon = self.decoder(state)
recon_loss += F.mse_loss(recon, observations[:, t])
# Reward prediction
pred_reward = self.reward_model(state)
reward_loss += F.mse_loss(pred_reward, rewards[:, t].unsqueeze(-1))
states.append(state)
        total_loss = recon_loss + reward_loss + 1.0 * kl_loss  # KL项加权,β=1.0
self.model_opt.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model_opt.param_groups[0]['params'], 100.0)
self.model_opt.step()
return torch.stack(states, dim=1)
def imagine_trajectories(self, initial_states, horizon=15):
# Roll out policy in imagination
states = [initial_states]
rewards = []
values = []
h = initial_states[:, :self.rssm.deterministic_dim]
s = initial_states[:, self.rssm.deterministic_dim:]
for _ in range(horizon):
state = torch.cat([h, s], dim=-1)
action = self.actor(state)
# Imagine next state
            h, s, _, _ = self.rssm.imagine_step(h, s, action)
next_state = torch.cat([h, s], dim=-1)
pred_reward = self.reward_model(next_state)
pred_value = self.value_model(next_state)
states.append(next_state)
rewards.append(pred_reward)
values.append(pred_value)
return torch.stack(states, dim=1), torch.stack(rewards, dim=1), torch.stack(values, dim=1)
def update_behavior(self, initial_states):
states, rewards, values = self.imagine_trajectories(initial_states)
# Lambda-return calculation
lambda_ = 0.95
discounts = torch.ones_like(rewards) * 0.99
# Compute returns
returns = []
last_value = values[:, -1]
for t in reversed(range(rewards.shape[1])):
if t == rewards.shape[1] - 1:
ret = rewards[:, t] + discounts[:, t] * last_value
else:
ret = rewards[:, t] + discounts[:, t] * (lambda_ * returns[0] + (1 - lambda_) * values[:, t+1])
returns.insert(0, ret)
returns = torch.stack(returns, dim=1)
# Actor loss: maximize expected return
actor_loss = -returns.mean()
# Value loss
        # 价值网络回归λ-回报;对想象状态detach,避免重复回传已释放的世界模型/策略计算图
        value_pred = self.value_model(states[:, :-1].detach().reshape(-1, states.shape[-1])).view(states.shape[0], states.shape[1] - 1, 1)
        value_loss = F.mse_loss(value_pred, returns.detach())
self.actor_opt.zero_grad()
actor_loss.backward()
self.actor_opt.step()
self.value_opt.zero_grad()
value_loss.backward()
self.value_opt.step()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="Walker2d-v4")
parser.add_argument("--image_size", default=64, type=int)
parser.add_argument("--latent_dim", default=30, type=int)
parser.add_argument("--batch_size", default=50, type=int)
parser.add_argument("--seq_len", default=50, type=int)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Note: For image-based control, use DMC or custom wrapper
# This example uses state-based as placeholder for image logic
env = gym.make(args.env)
action_dim = env.action_space.shape[0]
agent = Dreamer(3, action_dim, device, stochastic_dim=args.latent_dim)
# Simulated experience buffer (in practice, collect from environment)
buffer = []
state, _ = env.reset()
for _ in range(1000):
action = env.action_space.sample()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
buffer.append((state, action, next_state, reward, done))
state = next_state if not done else env.reset()[0]
# Training loop
for epoch in range(1000):
# Sample sequence batch
batch_idx = np.random.choice(len(buffer)-args.seq_len, args.batch_size)
obs_seq = torch.randn(args.batch_size, args.seq_len, 3, 64, 64) # Placeholder for images
act_seq = torch.randn(args.batch_size, args.seq_len, action_dim)
rew_seq = torch.randn(args.batch_size, args.seq_len)
states = agent.update_world_model(obs_seq.to(device), act_seq.to(device), rew_seq.to(device))
agent.update_behavior(states[:, 0])
if __name__ == "__main__":
main()
2.2.2.2 交叉熵方法(CEM)在图像输入中的动作序列优化:通过迭代采样(精英比例top-k=10%)规划未来H步动作序列
交叉熵方法在潜在空间规划中通过迭代优化动作序列分布,将最优控制问题转化为概率推断任务。该方法从当前动作分布中采样多条候选轨迹,评估其在习得的环境模型中的累积回报,随后根据精英样本的统计特性更新分布参数。在图像输入环境下,视觉编码器先将高维观测压缩至紧凑潜在表示,规划器再在该低维空间中以采样式前向滚动(shooting)进行有限步长的前瞻评估,选取累计回报最高的动作序列并执行其首个动作。
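每轮迭代的分布更新规则如下(N为候选序列条数,精英集合E取回报最高的\lceil\rho N\rceil条,\rho对应下方代码中的elite_ratio=0.1,仅作示意):

\mu^{(k+1)} = \frac{1}{|E|}\sum_{a_{1:H}\in E} a_{1:H}, \qquad \sigma^{(k+1)} = \mathrm{std}_{a_{1:H}\in E}\big(a_{1:H}\big) + \epsilon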
Python
"""
CEM (Cross Entropy Method) 在图像输入中的动作序列优化
涉及内容:图像编码器、动作序列优化、迭代分布更新、精英采样、模型预测控制
使用方式:python cem_image.py --env Pendulum-v1 --horizon 20 --n_samples 1000
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import argparse
class ImageEncoder(nn.Module):
def __init__(self, latent_dim=50):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 32, 3, stride=2), nn.ReLU(),
nn.Conv2d(32, 64, 3, stride=2), nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2), nn.ReLU(),
nn.Flatten()
)
self.fc = nn.Linear(128 * 7 * 7, latent_dim) # Assuming 64x64 input
def forward(self, x):
x = self.conv(x)
return self.fc(x)
class SimpleDynamics(nn.Module):
def __init__(self, latent_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim + action_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, latent_dim + 1) # next_latent + reward
)
def forward(self, latent, action):
x = torch.cat([latent, action], dim=-1)
out = self.net(x)
next_latent = out[:, :-1]
reward = out[:, -1]
return next_latent, reward
class CEMPlanner:
def __init__(self, encoder, dynamics, action_dim, horizon=20,
n_samples=1000, elite_ratio=0.1, iterations=5):
self.encoder = encoder
self.dynamics = dynamics
self.action_dim = action_dim
self.horizon = horizon
self.n_samples = n_samples
self.elite_ratio = elite_ratio
self.iterations = iterations
    def plan(self, image_obs, goal_latent=None):
        # Encode current observation
        with torch.no_grad():
            current_latent = self.encoder(image_obs.unsqueeze(0))
        device = current_latent.device
        # Initialize action distribution(与编码器保持在同一设备上)
        mean = torch.zeros(self.horizon, self.action_dim, device=device)
        std = torch.ones(self.horizon, self.action_dim, device=device)
        for _ in range(self.iterations):
            # Sample action sequences [n_samples, horizon, action_dim]
            actions = torch.normal(mean.unsqueeze(0).repeat(self.n_samples, 1, 1),
                                   std.unsqueeze(0).repeat(self.n_samples, 1, 1))
            actions = torch.clamp(actions, -1, 1)
            # Evaluate sequences
            returns = self.evaluate_sequences(current_latent.repeat(self.n_samples, 1), actions, goal_latent)
            # Select elites
            n_elites = int(self.n_samples * self.elite_ratio)
            elite_indices = torch.topk(returns, n_elites).indices
            elite_actions = actions[elite_indices]
            # Update distribution
            mean = elite_actions.mean(dim=0)
            std = elite_actions.std(dim=0) + 0.01
        return mean[0].cpu().numpy()  # Return first action
    def evaluate_sequences(self, initial_latents, action_sequences, goal_latent=None):
        # initial_latents已由调用方扩展为 [n_samples, latent_dim]
        batch_size = action_sequences.shape[0]
        latents = initial_latents
        total_rewards = torch.zeros(batch_size, device=latents.device)
        for t in range(self.horizon):
            actions = action_sequences[:, t]
            with torch.no_grad():
                next_latents, rewards = self.dynamics(latents, actions)
            if goal_latent is not None:
                # Distance to goal
                rewards = -torch.norm(next_latents - goal_latent, dim=1)
            total_rewards += (0.99 ** t) * rewards
            latents = next_latents
        return total_rewards
class CEMAgent:
def __init__(self, state_shape, action_dim, device):
self.device = device
self.latent_dim = 50
self.encoder = ImageEncoder(self.latent_dim).to(device)
self.dynamics = SimpleDynamics(self.latent_dim, action_dim).to(device)
self.planner = CEMPlanner(self.encoder, self.dynamics, action_dim)
self.optimizer = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.dynamics.parameters()),
lr=1e-3
)
    def train(self, buffer, epochs=100):
        # buffer为字典:{'images', 'actions', 'next_images', 'rewards'},与main()中的收集格式一致
        n = len(buffer['images'])
        for epoch in range(epochs):
            idx = np.random.choice(n, 32)
            images = torch.FloatTensor(np.array([buffer['images'][i] for i in idx])).to(self.device)
            actions = torch.FloatTensor(np.array([buffer['actions'][i] for i in idx])).to(self.device)
            next_images = torch.FloatTensor(np.array([buffer['next_images'][i] for i in idx])).to(self.device)
            rewards = torch.FloatTensor(np.array([buffer['rewards'][i] for i in idx])).to(self.device)
# Encode
latents = self.encoder(images)
next_latents_pred, pred_rewards = self.dynamics(latents, actions)
# Reconstruction loss for next image (simplified)
with torch.no_grad():
next_latents_target = self.encoder(next_images)
loss = F.mse_loss(next_latents_pred, next_latents_target) + \
F.mse_loss(pred_rewards, rewards)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def act(self, image):
image_tensor = torch.FloatTensor(image).unsqueeze(0).to(self.device)
action = self.planner.plan(image_tensor)
return action
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="CartPole-v1")
parser.add_argument("--horizon", default=20, type=int)
parser.add_argument("--n_samples", default=1000, type=int)
parser.add_argument("--elite_ratio", default=0.1, type=float)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make(args.env)
# Assume image observation space
    action_dim = env.action_space.shape[0]  # 连续动作空间维度
    agent = CEMAgent((3, 64, 64), action_dim, device)
    # 将命令行超参数透传给规划器
    agent.planner.horizon, agent.planner.n_samples, agent.planner.elite_ratio = args.horizon, args.n_samples, args.elite_ratio
# Training phase with random exploration
buffer = {'images': [], 'actions': [], 'next_images': [], 'rewards': []}
obs, _ = env.reset()
for step in range(5000):
action = env.action_space.sample()
next_obs, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
# Store (simplified, assume obs is already image)
buffer['images'].append(obs)
buffer['actions'].append(action)
buffer['next_images'].append(next_obs)
buffer['rewards'].append(reward)
obs = next_obs if not done else env.reset()[0]
# Train model
agent.train(buffer)
# Deploy with CEM planning
obs, _ = env.reset()
for step in range(1000):
action = agent.act(obs)
next_obs, reward, terminated, truncated, _ = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
else:
obs = next_obs
if __name__ == "__main__":
main()
2.3 分布式训练系统架构
2.3.1 IMPALA与SEED的工业级实现
-
2.3.1.1 演员-学习者(Actor-Learner)分离架构:RingBuffer实现异步数据传输,参数服务器(Parameter Server)的Ray分布式实现
重要性加权演员-学习者(IMPALA)架构通过解耦环境交互与策略优化过程,利用分布式采样节点并行生成经验序列,经由环形缓冲区将轨迹异步传输给学习节点。参数服务器维护全局网络权重的集中副本:采样节点周期性拉取可能已经滞后的参数副本继续交互,学习节点则以无锁(Hogwild!风格)方式异步推送梯度;对参数滞后的容忍换来了更高的硬件吞吐量,由此引入的策略偏差则交由下一小节的V-trace进行校正。该架构在异构计算集群上实现了采样与训练的计算重叠,显著缓解了中央处理器与图形处理器之间的通信瓶颈。数据流向可先参考下面的极简队列示意,再阅读完整脚本。
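在阅读完整脚本之前,可以先通过一个极简草图把握"采样-学习解耦"的数据流:采样任务只负责把轨迹推入共享队列,学习端只负责从队列消费,两侧互不等待。下面的草图用 Ray 自带的 ray.util.queue.Queue(由 actor 承载的分布式队列)代替手写的 RingBuffer,轨迹内容、数量与超参数均为示例假设,仅用于说明数据流向,并非完整的 IMPALA 实现。
Python
"""
采样-学习解耦的极简示意 (使用 ray.util.queue.Queue 代替手写RingBuffer; 数据均为随机占位)
"""
import numpy as np
import ray
from ray.util.queue import Queue
@ray.remote
def rollout_worker(q, worker_id, num_trajectories=5):
    # 采样端: 只管生产轨迹并推入共享队列, 不关心谁在消费
    for _ in range(num_trajectories):
        trajectory = {"worker": worker_id,
                      "obs": np.random.randn(20, 4).astype(np.float32),
                      "rew": np.random.randn(20).astype(np.float32)}
        q.put(trajectory)  # 队列满时阻塞, 对采样端形成天然背压
    return worker_id
def learner_loop(q, num_batches=10):
    # 学习端: 持续从队列取轨迹, 与采样完全异步
    for step in range(num_batches):
        traj = q.get(block=True)
        print(f"learner step {step}: trajectory from worker {traj['worker']}, length {len(traj['rew'])}")
if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    q = Queue(maxsize=64)  # 共享队列充当简化版的异步缓冲区
    futures = [rollout_worker.remote(q, i) for i in range(2)]
    learner_loop(q)  # 2个worker各5条轨迹, 学习端共消费10条
    ray.get(futures)
    ray.shutdown()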
Python
"""
IMPALA/SEED 分布式架构实现脚本
涉及内容:Actor-Learner分离、RingBuffer异步传输、Ray分布式参数服务器、异步梯度更新
使用方式:python impala_seed.py --num_actors 10 --learner_gpus 1
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import ray
from collections import deque
import threading
import queue
import argparse
ray.init(ignore_reinit_error=True)
@ray.remote
class ParameterServer:
def __init__(self, state_dim, action_dim, hidden_dim=256):
self.policy = nn.Sequential(
nn.Linear(state_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, action_dim), nn.Softmax(dim=-1)
)
self.value = nn.Sequential(
nn.Linear(state_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
        self.version = 0
    def get_parameters(self):
        # 每次返回最新的参数快照, 避免使用初始化时缓存的过期state_dict
        return {
            'policy': self.policy.state_dict(),
            'value': self.value.state_dict()
        }, self.version
def update_parameters(self, gradients, version):
# Hogwild! style update
with torch.no_grad():
for name, param in self.policy.named_parameters():
if name in gradients['policy']:
param.data -= 0.001 * gradients['policy'][name]
for name, param in self.value.named_parameters():
if name in gradients['value']:
param.data -= 0.001 * gradients['value'][name]
self.version += 1
return self.version
@ray.remote(max_concurrency=4)  # 作为Ray actor使用; 并发度>1使write/read可互相唤醒, 避免单线程阻塞死锁
class RingBuffer:
def __init__(self, capacity=1000):
self.capacity = capacity
self.buffer = [None] * capacity
self.write_idx = 0
self.read_idx = 0
self.lock = threading.Lock()
self.not_empty = threading.Condition(self.lock)
self.not_full = threading.Condition(self.lock)
def write(self, data):
with self.not_full:
while (self.write_idx + 1) % self.capacity == self.read_idx:
self.not_full.wait()
self.buffer[self.write_idx] = data
self.write_idx = (self.write_idx + 1) % self.capacity
self.not_empty.notify()
def read(self):
with self.not_empty:
while self.write_idx == self.read_idx:
self.not_empty.wait()
data = self.buffer[self.read_idx]
self.read_idx = (self.read_idx + 1) % self.capacity
self.not_full.notify()
return data
def read_batch(self, batch_size):
batch = []
for _ in range(batch_size):
batch.append(self.read())
return batch
@ray.remote
class Actor:
def __init__(self, env_name, ps_handle, actor_id):
self.env = gym.make(env_name)
self.ps = ps_handle
self.actor_id = actor_id
self.device = torch.device("cpu")
self.policy = None
self.version = 0
def run(self, buffer_handle, num_steps=100000):
        buffer = buffer_handle  # 传入的是RingBuffer actor句柄, 直接调用其remote方法即可
local_buffer = []
for step in range(num_steps):
# Fetch latest parameters every N steps
if step % 400 == 0:
params, version = ray.get(self.ps.get_parameters.remote())
if self.policy is None:
state_dim = self.env.observation_space.shape[0]
action_dim = self.env.action_space.n
                    self.policy = nn.Sequential(  # 结构须与ParameterServer.policy一致, 否则load_state_dict失败
                        nn.Linear(state_dim, 256), nn.ReLU(),
                        nn.Linear(256, 256), nn.ReLU(),
                        nn.Linear(256, action_dim), nn.Softmax(dim=-1)
                    )
self.policy.load_state_dict(params['policy'])
self.version = version
state, _ = self.env.reset()
done = False
episode_data = []
while not done:
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0)
action_probs = self.policy(state_tensor)
action = torch.multinomial(action_probs, 1).item()
next_state, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
episode_data.append((state, action, reward, next_state, done))
state = next_state
if len(episode_data) >= 20: # Send trajectory
ray.get(buffer.write.remote(episode_data))
episode_data = []
return self.actor_id
class Learner:
def __init__(self, state_dim, action_dim, ps_handle, device):
self.device = device
self.ps = ps_handle
self.policy = nn.Sequential(
nn.Linear(state_dim, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, action_dim), nn.Softmax(dim=-1)
).to(device)
self.value = nn.Sequential(
nn.Linear(state_dim, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, 1)
).to(device)
self.optimizer = torch.optim.Adam(
list(self.policy.parameters()) + list(self.value.parameters()),
lr=1e-3
)
self.gamma = 0.99
self.clip_rho = 1.0
def compute_gradients(self, batch):
states = torch.FloatTensor(np.array([t[0] for traj in batch for t in traj])).to(self.device)
actions = torch.LongTensor(np.array([t[1] for traj in batch for t in traj])).to(self.device)
rewards = torch.FloatTensor(np.array([t[2] for traj in batch for t in traj])).to(self.device)
next_states = torch.FloatTensor(np.array([t[3] for traj in batch for t in traj])).to(self.device)
dones = torch.FloatTensor(np.array([t[4] for traj in batch for t in traj])).to(self.device)
# Compute action probabilities
action_probs = self.policy(states)
log_probs = torch.log(action_probs.gather(1, actions.unsqueeze(1)) + 1e-10)
with torch.no_grad():
values = self.value(states).squeeze()
next_values = self.value(next_states).squeeze()
td_targets = rewards + self.gamma * next_values * (1 - dones)
advantages = td_targets - values
# Policy gradient
policy_loss = -(log_probs.squeeze() * advantages).mean()
# Value loss
value_loss = F.mse_loss(self.value(states).squeeze(), td_targets)
loss = policy_loss + 0.5 * value_loss
self.optimizer.zero_grad()
loss.backward()
        # Extract gradients (搬到CPU, 便于跨进程序列化传输给参数服务器)
        grads = {
            'policy': {name: param.grad.detach().cpu() for name, param in self.policy.named_parameters()},
            'value': {name: param.grad.detach().cpu() for name, param in self.value.named_parameters()}
        }
return grads
def train(self, buffer_handle, num_updates=10000):
        buffer = buffer_handle  # RingBuffer actor句柄
for i in range(num_updates):
batch = ray.get(buffer.read_batch.remote(32))
gradients = self.compute_gradients(batch)
new_version = ray.get(self.ps.update_parameters.remote(gradients, i))
# Sync local parameters
params, _ = ray.get(self.ps.get_parameters.remote())
self.policy.load_state_dict(params['policy'])
self.value.load_state_dict(params['value'])
if i % 100 == 0:
print(f"Learner update {i}, version {new_version}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num_actors", default=4, type=int)
parser.add_argument("--env", default="CartPole-v1")
parser.add_argument("--learner_gpus", default=0, type=int)
args = parser.parse_args()
env = gym.make(args.env)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ps_handle = ParameterServer.remote(state_dim, action_dim)
buffer_handle = RingBuffer.remote(capacity=10000)
device = torch.device("cuda" if args.learner_gpus > 0 and torch.cuda.is_available() else "cpu")
learner = Learner(state_dim, action_dim, ps_handle, device)
actors = [Actor.remote(args.env, ps_handle, i) for i in range(args.num_actors)]
# Start actors
actor_futures = [actor.run.remote(buffer_handle) for actor in actors]
# Start learner
learner.train(buffer_handle)
ray.get(actor_futures)
if __name__ == "__main__":
main()
-
2.3.1.2 V-trace偏置校正:重要性采样比率(ρ)截断(c=1.0)与Actor-Critic共享卷积编码器的梯度分离技术
V-trace偏置校正算法通过截断重要性采样比率,在分布式演员-学习者架构中纠正由策略滞后(policy lag)引起的价值估计偏差。该方法对从行为策略到目标策略的逐动作重要性权重分别施加阈值 ρ̄ 与 c̄ 的上限截断:截断抑制了方差爆炸,但并非无偏——ρ̄ 决定价值函数收敛到的不动点(介于行为策略与目标策略的价值函数之间,ρ̄→∞ 时退化为普通的无偏重要性采样),c̄ 只影响收敛速度。在共享卷积编码器的场景下,可借助停止梯度算子控制哪一路损失更新共享表征:例如下述实现中价值头接收 detach 后的特征,使编码器仅由策略梯度驱动,避免价值回归的梯度干扰表征学习。
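为便于与下述代码对照,这里按 IMPALA 原文(Espeholt et al., 2018)的记号整理 V-trace 目标的递归定义,其中 μ 为行为策略、π 为目标策略,ρ̄、c̄ 分别对应代码中的 clip_rho 与 clip_c:
LaTeX
v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{\,t-s} \Big( \prod_{i=s}^{t-1} c_i \Big) \delta_t V,
\qquad \delta_t V = \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big),
\qquad \rho_t = \min\!\Big( \bar{\rho},\ \frac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)} \Big),
\quad c_i = \min\!\Big( \bar{c},\ \frac{\pi(a_i \mid x_i)}{\mu(a_i \mid x_i)} \Big)
策略梯度使用的优势项为 ρ_s ( r_s + γ v_{s+1} − V(x_s) ),即先用下一步的 V-trace 目标 v_{s+1} 自举,再按截断比率 ρ_s 加权;这也是下述代码中 pg_advantages 的计算方式。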
Python
"""
V-trace实现脚本
涉及内容:重要性采样比率截断、V-trace目标计算、Retrace-style校正、策略与价值梯度分离
使用方式:python vtrace.py --batch_size 64 --trajectory_length 20
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.shared_encoder = nn.Sequential(
nn.Linear(state_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
)
self.policy_head = nn.Linear(hidden_dim, action_dim)
self.value_head = nn.Linear(hidden_dim, 1)
def forward(self, state):
features = self.shared_encoder(state)
logits = self.policy_head(features)
        value = self.value_head(features.detach())  # Gradient isolation: 价值损失不回传到共享编码器, 表征仅由策略梯度驱动
return F.softmax(logits, dim=-1), value
def get_features(self, state):
return self.shared_encoder(state)
class VTrace:
def __init__(self, clip_rho=1.0, clip_c=1.0, rho_bar=1.0):
self.clip_rho = clip_rho
self.clip_c = clip_c
self.rho_bar = rho_bar
def compute_vtrace_target(self, behaviour_policy_probs, target_policy_probs,
actions, rewards, values, next_values, dones, gamma=0.99):
"""
behaviour_policy_probs: [T, B, num_actions] or [T, B]
target_policy_probs: [T, B, num_actions] or [T, B]
actions: [T, B]
rewards: [T, B]
values: [T, B]
next_values: [T, B]
"""
T, B = actions.shape
        # Calculate importance sampling ratios
        # 行为/目标策略概率可能分别以 [T,B] (已是所选动作概率) 或 [T,B,A] (完整分布) 传入, 需各自处理
        if behaviour_policy_probs.dim() == 3:
            behaviour_probs = behaviour_policy_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        else:
            behaviour_probs = behaviour_policy_probs
        if target_policy_probs.dim() == 3:
            target_probs = target_policy_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        else:
            target_probs = target_policy_probs
rho = target_probs / (behaviour_probs + 1e-10)
# Clipping for variance reduction
clipped_rho = torch.clamp(rho, max=self.clip_rho)
clipped_c = torch.clamp(rho, max=self.clip_c)
# Calculate td errors
td_errors = clipped_rho * (rewards + gamma * next_values * (1 - dones) - values)
# Backward recursion for v-trace targets
vs = torch.zeros_like(values)
        vs[-1] = values[-1] + td_errors[-1]  # v_{T-1} = V(x_{T-1}) + delta_{T-1}, 末步以 V(x_T) 自举
for t in reversed(range(T - 1)):
            vs[t] = values[t] + td_errors[t] + gamma * (1 - dones[t]) * clipped_c[t] * (vs[t+1] - next_values[t])
        # Policy gradient advantages: rho_t * (r_t + gamma * v_{t+1} - V(x_t)), 末步以 V(x_T) 自举
        vs_next = torch.cat([vs[1:], next_values[-1:]], dim=0)
        pg_advantages = clipped_rho * (rewards + gamma * vs_next * (1 - dones) - values)
return vs, pg_advantages, clipped_rho
class VTraceAgent:
def __init__(self, state_dim, action_dim, device, clip_rho=1.0, clip_c=1.0):
self.device = device
self.policy = PolicyNetwork(state_dim, action_dim).to(device)
self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=3e-4)
self.vtrace = VTrace(clip_rho=clip_rho, clip_c=clip_c)
def update(self, trajectories):
# trajectories: dict with states, actions, rewards, behaviour_probs, target_probs, dones
states = torch.FloatTensor(trajectories['states']).to(self.device)
actions = torch.LongTensor(trajectories['actions']).to(self.device)
rewards = torch.FloatTensor(trajectories['rewards']).to(self.device)
behaviour_probs = torch.FloatTensor(trajectories['behaviour_probs']).to(self.device)
dones = torch.FloatTensor(trajectories['dones']).to(self.device)
# Forward pass
target_probs, values = self.policy(states)
        _, next_values = self.policy(torch.cat([states[1:], states[-1:]], 0))  # 粗略的下一状态价值估计: 末步重复最后一帧
# Compute V-trace targets
vs, pg_advantages, rhos = self.vtrace.compute_vtrace_target(
behaviour_probs, target_probs, actions, rewards, values.squeeze(),
next_values.squeeze(), dones
)
# Policy loss
log_probs = torch.log(target_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1) + 1e-10)
policy_loss = -(log_probs * pg_advantages.detach()).mean()
# Value loss
value_loss = F.mse_loss(values.squeeze(), vs.detach())
# Entropy bonus
entropy = -(target_probs * torch.log(target_probs + 1e-10)).sum(dim=-1).mean()
total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 40.0)
self.optimizer.step()
return {
'policy_loss': policy_loss.item(),
'value_loss': value_loss.item(),
'mean_rho': rhos.mean().item(),
'mean_adv': pg_advantages.mean().item()
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--state_dim", default=4, type=int)
parser.add_argument("--action_dim", default=2, type=int)
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--trajectory_length", default=20, type=int)
parser.add_argument("--clip_rho", default=1.0, type=float)
parser.add_argument("--clip_c", default=1.0, type=float)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = VTraceAgent(args.state_dim, args.action_dim, device, args.clip_rho, args.clip_c)
# Simulate trajectories
for step in range(1000):
traj = {
'states': np.random.randn(args.trajectory_length, args.batch_size, args.state_dim),
'actions': np.random.randint(0, args.action_dim, (args.trajectory_length, args.batch_size)),
'rewards': np.random.randn(args.trajectory_length, args.batch_size),
'behaviour_probs': np.random.uniform(0.3, 0.7, (args.trajectory_length, args.batch_size)),
'dones': np.random.randint(0, 2, (args.trajectory_length, args.batch_size)).astype(float)
}
metrics = agent.update(traj)
if step % 100 == 0:
print(f"Step {step}: {metrics}")
if __name__ == "__main__":
main()
2.3.2 GPU端到端加速
-
2.3.2.1 Isaac Gym/Brax物理引擎:全GPU并行模拟( thousands of envs )与JAX JIT编译的端到端训练流水线
Isaac Gym与Brax物理引擎利用图形处理器的大规模并行计算能力,实现数千个环境实例的同时模拟,消除中央处理器与图形处理器之间的数据传输延迟。通过将物理模拟、状态计算与策略推理全部置于图形处理器端执行,并结合即时编译技术对计算图进行优化,该方法实现了端到端的训练流水线,在维持物理仿真精度的同时达到百万级帧每秒的采样吞吐量。
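在进入完整训练脚本之前,先用一个极简示意说明"全GPU并行"的编程模型:单环境的 step 函数经 jax.vmap 在批次维度上向量化,再经 jax.jit 编译为单个 XLA 计算图,数千个环境便成为一次数组运算。以下动力学为占位实现,num_envs 等数值均为示例,并非 Brax 的真实物理接口:
Python
"""
vmap/jit 并行模拟最小示意 (占位动力学, 非真实Brax API; num_envs等数值为示例)
"""
import jax
import jax.numpy as jnp
def single_env_step(state, action):
    # 单个环境的一步转移: state/action 均为一维向量 (示例性的线性动力学)
    next_state = state + 0.05 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward
# vmap 在前导批次维上并行所有环境, jit 将整段计算编译为单个 XLA 核函数
batched_step = jax.jit(jax.vmap(single_env_step))
if __name__ == "__main__":
    num_envs, dim = 4096, 8
    key = jax.random.PRNGKey(0)
    states = jax.random.normal(key, (num_envs, dim))
    actions = jnp.zeros((num_envs, dim))
    next_states, rewards = batched_step(states, actions)  # 一次调用即完成4096个环境的并行步进
    print(next_states.shape, rewards.shape)  # (4096, 8) (4096,)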
Python
"""
Isaac Gym/Brax GPU端到端训练实现脚本
涉及内容:JAX JIT编译、GPU并行环境、PPO算法、向量环境批处理、端到端训练
使用方式:python brax_train.py --env humanoid --num_envs 4096 --learning_rate 3e-4
"""
import jax
import jax.numpy as jnp
from jax import random, jit, grad, vmap
import flax.linen as nn
from flax.training import train_state
import optax
import numpy as np
import argparse
class ActorCritic(nn.Module):
action_dim: int
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(256)(x)
x = nn.relu(x)
logits = nn.Dense(self.action_dim)(x)
value = nn.Dense(1)(x)
return logits, value.squeeze()
class BraxEnv:
def __init__(self, env_name, num_envs):
self.env_name = env_name
self.num_envs = num_envs
# In real implementation, use brax.envs.create
self.state_dim = 87 # Humanoid
self.action_dim = 21
def reset(self, rng):
# Parallel reset
return jax.random.normal(rng, (self.num_envs, self.state_dim))
def step(self, state, action):
# Parallel step (simplified physics)
# In real Brax: return env.step(state, action)
        next_state = state + 0.01 * action[:, None]  # 占位动力学: 逐环境广播动作 (非真实物理)
reward = -jnp.sum(next_state**2, axis=-1) # Dummy reward
done = jnp.zeros(self.num_envs)
return next_state, reward, done
@jit
def select_action(params, state, rng):
logits, value = ActorCritic(action_dim=21).apply({'params': params}, state)
rng, key = random.split(rng)
action = random.categorical(key, logits)
log_prob = jax.nn.log_softmax(logits)[jnp.arange(len(action)), action]
return action, log_prob, value, logits, rng
@jit
def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    # rewards/values/dones 形状均为 [T, num_envs], 在环境维度上逐元素向量化计算GAE
    def _gae_step(carry, transition):
        gae, next_value = carry
        reward, done, value = transition
        delta = reward + gamma * next_value * (1 - done) - value
        gae = delta + gamma * gae_lambda * (1 - done) * gae
        return (gae, value), gae
    bootstrap = jnp.zeros_like(values[-1])  # 末步价值以0自举 (简化处理)
    _, advantages = jax.lax.scan(
        _gae_step,
        (jnp.zeros_like(values[-1]), bootstrap),
        (rewards, dones, values),
        reverse=True
    )
    return advantages
def ppo_loss(params, batch, clip_epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
states, actions, old_log_probs, advantages, returns = batch
logits, values = ActorCritic(action_dim=21).apply({'params': params}, states)
log_probs = jax.nn.log_softmax(logits)
log_probs_actions = log_probs[jnp.arange(len(actions)), actions]
# Policy loss
ratio = jnp.exp(log_probs_actions - old_log_probs)
clipped_ratio = jnp.clip(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
policy_loss = -jnp.mean(jnp.minimum(ratio * advantages, clipped_ratio * advantages))
# Value loss
value_loss = jnp.mean((values - returns) ** 2)
# Entropy
entropy = -jnp.mean(jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1))
return policy_loss + value_coef * value_loss - entropy_coef * entropy
@jit
def train_step(state, batch):
loss, grads = jax.value_and_grad(ppo_loss)(state.params, batch)
state = state.apply_gradients(grads=grads)
return state, loss
class GPUParallelTrainer:
def __init__(self, env_name, num_envs=4096, learning_rate=3e-4):
self.env = BraxEnv(env_name, num_envs)
self.num_envs = num_envs
# Initialize network
rng = random.PRNGKey(0)
dummy_input = jnp.zeros((1, self.env.state_dim))
net = ActorCritic(action_dim=self.env.action_dim)
params = net.init(rng, dummy_input)['params']
self.state = train_state.TrainState.create(
apply_fn=net.apply,
params=params,
tx=optax.adam(learning_rate)
)
self.rng = rng
def rollout(self, horizon=128):
# Vectorized rollout across all envs
self.rng, step_rng = random.split(self.rng)
states = self.env.reset(step_rng)
trajectories = {
'states': [], 'actions': [], 'rewards': [],
'log_probs': [], 'values': [], 'dones': []
}
for _ in range(horizon):
self.rng, act_rng = random.split(self.rng)
actions, log_probs, values, _, _ = select_action(
self.state.params, states, act_rng
)
next_states, rewards, dones = self.env.step(states, actions)
trajectories['states'].append(states)
trajectories['actions'].append(actions)
trajectories['rewards'].append(rewards)
trajectories['log_probs'].append(log_probs)
trajectories['values'].append(values)
trajectories['dones'].append(dones)
states = next_states
# Convert to arrays
for key in trajectories:
trajectories[key] = jnp.stack(trajectories[key], axis=0)
# Compute advantages
advantages = compute_gae(
trajectories['rewards'],
trajectories['values'],
trajectories['dones']
)
returns = advantages + trajectories['values']
# Flatten batch dimensions
batch = (
trajectories['states'].reshape(-1, self.env.state_dim),
trajectories['actions'].reshape(-1),
trajectories['log_probs'].reshape(-1),
advantages.reshape(-1),
returns.reshape(-1)
)
return batch
def train(self, num_iterations=1000, batch_size=8192):
for iteration in range(num_iterations):
# Collect data
batch = self.rollout(horizon=128)
# Mini-batch updates
n_samples = batch[0].shape[0]
indices = np.random.permutation(n_samples)
for start in range(0, n_samples, batch_size):
end = start + batch_size
mb_indices = indices[start:end]
mb = tuple(x[mb_indices] for x in batch)
self.state, loss = train_step(self.state, mb)
if iteration % 10 == 0:
print(f"Iteration {iteration}, Loss: {loss:.4f}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="humanoid")
parser.add_argument("--num_envs", default=4096, type=int)
parser.add_argument("--learning_rate", default=3e-4, type=float)
args = parser.parse_args()
trainer = GPUParallelTrainer(args.env, args.num_envs, args.learning_rate)
trainer.train()
if __name__ == "__main__":
main()
-
2.3.2.2 RLlib中的SampleCollector优化:通过压缩观测(LZ4/ObsCompression)与异步梯度计算减少CPU-GPU传输瓶颈
RLlib中的样本收集器通过压缩观测数据与异步梯度计算优化分布式训练流水线。LZ4压缩算法在数据传输前对高维观测进行实时压缩,降低网络带宽占用;异步梯度计算允许中央处理器在图形处理器执行前向与反向传播的同时准备下一批次数据,通过计算与通信的重叠消除流水线气泡,最大化硬件利用率。
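为直观感受压缩观测对传输量的影响,这里先给出一个独立的小例子:对一帧典型的 84×84×4 uint8 叠帧观测做 LZ4 压缩与解压,并比较字节数。帧尺寸与数值分布均为示例假设,实际压缩率取决于观测内容:
Python
"""
LZ4 观测压缩示意 (帧尺寸与数值为示例假设)
"""
import numpy as np
import lz4.frame
obs = np.zeros((84, 84, 4), dtype=np.uint8)  # 典型的Atari叠帧观测, 大面积为零时压缩率很高
obs[10:40, 10:40, :] = np.random.randint(0, 255, (30, 30, 4), dtype=np.uint8)
raw_bytes = obs.tobytes()
compressed = lz4.frame.compress(raw_bytes)
restored = np.frombuffer(lz4.frame.decompress(compressed), dtype=obs.dtype).reshape(obs.shape)
assert np.array_equal(obs, restored)  # LZ4为无损压缩, 解压后逐字节一致
print(f"raw={len(raw_bytes)}B, compressed={len(compressed)}B, "
      f"ratio={len(raw_bytes) / len(compressed):.1f}x")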
Python
"""
RLlib SampleCollector优化实现脚本
涉及内容:LZ4观测压缩、异步梯度计算、样本压缩与解压、流水线优化
使用方式:python rllib_optimized.py --compress_observations --async_gradients
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import lz4.frame
import threading
import queue
import time
import argparse
from typing import Dict, List, Any
class LZ4Compressor:
@staticmethod
def compress(observation: np.ndarray) -> bytes:
return lz4.frame.compress(observation.tobytes())
@staticmethod
def decompress(compressed: bytes, shape: tuple, dtype: np.dtype) -> np.ndarray:
decompressed = lz4.frame.decompress(compressed)
return np.frombuffer(decompressed, dtype=dtype).reshape(shape)
class AsyncGradientComputer:
def __init__(self, model, device):
self.model = model
self.device = device
self.queue = queue.Queue(maxsize=2)
self.result_queue = queue.Queue()
self.thread = threading.Thread(target=self._compute_loop, daemon=True)
self.running = False
def start(self):
self.running = True
self.thread.start()
def stop(self):
self.running = False
self.thread.join()
def _compute_loop(self):
while self.running:
try:
batch = self.queue.get(timeout=1.0)
# Move to GPU and compute
states = batch['states'].to(self.device)
actions = batch['actions'].to(self.device)
returns = batch['returns'].to(self.device)
advantages = batch['advantages'].to(self.device)
loss = self._compute_loss(states, actions, returns, advantages)
loss.backward()
# Extract gradients
grads = {name: param.grad.clone() for name, param in self.model.named_parameters()}
self.result_queue.put(grads)
# Zero grads for next iteration
self.model.zero_grad()
except queue.Empty:
continue
    def _compute_loss(self, states, actions, returns, advantages):
        output = self.model(states)
        logits, values = output[:, :-1], output[:, -1]  # 模型最后一维拼接了logits与value
        log_probs = F.log_softmax(logits, dim=-1)
        log_probs_actions = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        policy_loss = -(log_probs_actions * advantages).mean()
        value_loss = F.mse_loss(values, returns)
        return policy_loss + 0.5 * value_loss
def submit_batch(self, batch):
self.queue.put(batch, block=True)
def get_gradients(self):
return self.result_queue.get(block=True)
class OptimizedSampleCollector:
def __init__(self, state_shape, action_dim, use_compression=True, use_async=True, device='cuda'):
self.state_shape = state_shape
self.action_dim = action_dim
self.use_compression = use_compression
self.device = device
self.model = nn.Sequential(
nn.Linear(np.prod(state_shape), 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, action_dim + 1) # logits + value
).to(device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)
if use_async:
self.async_computer = AsyncGradientComputer(self.model, device)
self.async_computer.start()
else:
self.async_computer = None
self.compressor = LZ4Compressor() if use_compression else None
self.buffer = []
def collect_sample(self, env, policy):
# Collect with optional compression
state, _ = env.reset()
trajectory = []
done = False
while not done:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
with torch.no_grad():
output = self.model(state_tensor)
logits = output[:, :-1]
value = output[:, -1]
action_probs = F.softmax(logits, dim=-1)
action = torch.multinomial(action_probs, 1).item()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
if self.use_compression:
compressed_state = self.compressor.compress(state)
compressed_next = self.compressor.compress(next_state)
trajectory.append({
'state': compressed_state,
'action': action,
'reward': reward,
'next_state': compressed_next,
'done': done,
'state_shape': state.shape,
'state_dtype': state.dtype
})
else:
trajectory.append({
'state': state,
'action': action,
'reward': reward,
'next_state': next_state,
'done': done
})
state = next_state
return trajectory
def process_batch(self, trajectories: List[List[Dict]]):
# Decompress if needed and prepare batch
all_states = []
all_actions = []
all_returns = []
all_advantages = []
        gamma = 0.99  # 此处简化为折扣回报 (未使用GAE-lambda)
for traj in trajectories:
states = []
for step in traj:
if self.use_compression:
state = self.compressor.decompress(
step['state'], step['state_shape'], step['state_dtype']
)
else:
state = step['state']
states.append(state)
states = np.array(states)
rewards = np.array([step['reward'] for step in traj])
actions = np.array([step['action'] for step in traj])
dones = np.array([step['done'] for step in traj])
# Compute returns and advantages
returns = np.zeros_like(rewards)
advantages = np.zeros_like(rewards)
            # 简化处理: 用折扣回报近似回报目标 (而非完整的GAE)
returns[-1] = rewards[-1]
for t in reversed(range(len(rewards) - 1)):
returns[t] = rewards[t] + gamma * returns[t+1] * (1 - dones[t])
# Normalize advantages
advantages = returns - returns.mean()
if returns.std() > 0:
advantages = advantages / (returns.std() + 1e-8)
all_states.append(states)
all_actions.append(actions)
all_returns.append(returns)
all_advantages.append(advantages)
batch = {
'states': torch.FloatTensor(np.concatenate(all_states)),
'actions': torch.LongTensor(np.concatenate(all_actions)),
'returns': torch.FloatTensor(np.concatenate(all_returns)),
'advantages': torch.FloatTensor(np.concatenate(all_advantages))
}
return batch
def train_step(self, batch):
if self.async_computer:
# Async path: submit and immediately get previous gradients
self.async_computer.submit_batch(batch)
try:
grads = self.async_computer.result_queue.get_nowait()
# Apply gradients
for name, param in self.model.named_parameters():
if name in grads:
param.grad = grads[name]
self.optimizer.step()
self.optimizer.zero_grad()
except queue.Empty:
pass
else:
# Sync path
states = batch['states'].to(self.device)
actions = batch['actions'].to(self.device)
returns = batch['returns'].to(self.device)
advantages = batch['advantages'].to(self.device)
output = self.model(states)
logits = output[:, :-1]
values = output[:, -1]
log_probs = F.log_softmax(logits, dim=-1)
log_probs_actions = log_probs.gather(1, actions.unsqueeze(1)).squeeze()
policy_loss = -(log_probs_actions * advantages).mean()
value_loss = F.mse_loss(values, returns)
loss = policy_loss + 0.5 * value_loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--compress_observations", action="store_true")
parser.add_argument("--async_gradients", action="store_true")
parser.add_argument("--num_envs", default=4, type=int)
parser.add_argument("--batch_size", default=64, type=int)
args = parser.parse_args()
import gymnasium as gym
collector = OptimizedSampleCollector(
state_shape=(4,),
action_dim=2,
use_compression=args.compress_observations,
use_async=args.async_gradients,
device='cuda' if torch.cuda.is_available() else 'cpu'
)
envs = [gym.make("CartPole-v1") for _ in range(args.num_envs)]
# Collection loop
for iteration in range(100):
trajectories = []
for env in envs:
traj = collector.collect_sample(env, collector.model)
trajectories.append(traj)
batch = collector.process_batch(trajectories)
# Split into mini-batches
n = len(batch['states'])
indices = torch.randperm(n)
for start in range(0, n, args.batch_size):
end = min(start + args.batch_size, n)
mb_idx = indices[start:end]
mb = {k: v[mb_idx] for k, v in batch.items()}
collector.train_step(mb)
if iteration % 10 == 0:
print(f"Iteration {iteration} completed")
if args.async_gradients:
collector.async_computer.stop()
if __name__ == "__main__":
main()