目录

第二章 高样本效率与模型化方法

2.1 Off-Policy Actor-Critic与熵优化

2.1.1 TD3与SAC的工程细节

2.1.1.1 TD3双延迟更新:

2.1.1.2 SAC自动温度调整:

2.1.2 离线强化学习(Offline RL)部署

2.1.2.1 CQL(Conservative Q-Learning)

2.1.2.2 IQL(Implicit Q-Learning):以期望分位数回归(expectile regression)替代max操作,避免OOD动作查询,并给出离线到在线微调代码

2.2 模型基础强化学习(Model-Based RL)

2.2.1 环境动力学模型学习

2.2.1.1 PETS概率性模型:基于高斯混合网络(MDN)的集成模型(Ensemble)实现DON(Distillation via Online Network)对抗模型误差

2.2.1.2 MBPO(Model-Based Policy Optimization):基于模型的虚拟 rollouts 与真实数据混合比例(real-to-sim ratio=0.95)的采样器实现

2.2.2 潜在空间规划(Latent Space Planning)

2.2.2.1 PlaNet/Dreamer架构:RSSM(Recurrent State-Space Model)的确定性路径与随机先验分解,KL散度损失加权(β=1.0)

2.2.2.2 交叉熵方法(CEM)在图像输入中的动作序列优化:通过迭代采样(精英比例top-k=10%)规划未来H步动作序列

2.3 分布式训练系统架构

2.3.1 IMPALA与SEED的工业级实现

2.3.1.1 演员-学习者(Actor-Learner)分离架构:RingBuffer实现异步数据传输,参数服务器(Parameter Server)的Ray分布式实现

2.3.1.2 V-trace偏置校正:重要性采样比率(ρ)截断(c=1.0)与Actor-Critic共享卷积编码器的梯度分离技术

2.3.2 GPU端到端加速

2.3.2.1 Isaac Gym/Brax物理引擎:全GPU并行模拟(数千个并行环境)与JAX JIT编译的端到端训练流水线

2.3.2.2 RLlib中的SampleCollector优化:通过压缩观测(LZ4/ObsCompression)与异步梯度计算减少CPU-GPU传输瓶颈


第二章 高样本效率与模型化方法

2.1 Off-Policy Actor-Critic与熵优化

2.1.1 TD3与SAC的工程细节
  • 2.1.1.1 TD3双延迟更新:
  • 目标策略平滑(Target Policy Smoothing)通过向目标动作添加经裁剪的高斯噪声(标准差σ=0.2,裁剪范围±0.5)减少价值函数对尖锐峰值的过拟合

深度强化学习系统在样本效率方面的提升主要依赖于对价值函数估计偏差的抑制以及对环境交互数据的复用。双延迟深度确定性策略梯度算法(TD3)通过维护两组独立的参数化动作价值估计,在计算时间差分目标时选取二者中较小者,从而有效缓解训练过程中的价值过估计现象;同时策略网络与目标网络的更新频率低于评论家网络(延迟更新),以降低价值误差对策略的影响。目标策略平滑机制在目标网络的动作选择环节引入经裁剪的高斯随机扰动,通过对动作空间施加有界噪声,平滑价值函数在动作维度上的变化率,降低策略对特定动作模式过度拟合的风险。

Python

"""
TD3 (Twin Delayed Deep Deterministic Policy Gradient) 实现脚本
涉及内容:双延迟更新、目标策略平滑、双评论家网络、软更新机制
使用方式:python td3.py --env HalfCheetah-v4 --seed 0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
import argparse

class Actor(nn.Module):
    """Deterministic policy network that maps a state to a bounded action.

    Two ReLU hidden layers feed a tanh output that is scaled by
    ``max_action``, so every action component lies in
    ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay l1/l2/l3.
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return the scaled, tanh-squashed action for a batch of states."""
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed

class Critic(nn.Module):
    """Twin Q-network: two independent MLP estimators over (state, action)."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        in_dim = state_dim + action_dim

        # First estimator (Q1) uses layers l1-l3.
        self.l1 = nn.Linear(in_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

        # Second estimator (Q2) uses layers l4-l6.
        self.l4 = nn.Linear(in_dim, hidden_dim)
        self.l5 = nn.Linear(hidden_dim, hidden_dim)
        self.l6 = nn.Linear(hidden_dim, 1)

    @staticmethod
    def _head(sa, first, second, out):
        """Run one three-layer ReLU MLP over the joined state-action input."""
        return out(F.relu(second(F.relu(first(sa)))))

    def forward(self, state, action):
        """Return ``(q1, q2)``, the outputs of both estimators."""
        sa = torch.cat((state, action), dim=1)
        q1 = self._head(sa, self.l1, self.l2, self.l3)
        q2 = self._head(sa, self.l4, self.l5, self.l6)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first estimator's output (used for the actor loss)."""
        sa = torch.cat((state, action), dim=1)
        return self._head(sa, self.l1, self.l2, self.l3)

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored in NumPy arrays.

    Once ``max_size`` transitions have been added, new ones overwrite the
    oldest. Storage is float32 because ``sample`` always returns float32
    tensors; keeping float64 copies would only double the memory footprint.

    Args:
        state_dim: Length of a flat state vector.
        action_dim: Length of a flat action vector.
        max_size: Maximum number of transitions kept.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0   # index of the next slot to write
        self.size = 0  # number of valid transitions currently stored
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.reward = np.zeros((max_size, 1), dtype=np.float32)
        self.not_done = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, action, next_state, reward, done):
        """Store one transition; ``done`` is saved as the mask ``1 - done``."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Return a uniformly sampled batch (with replacement) as CPU tensors.

        Returns:
            Tuple ``(state, action, next_state, reward, not_done)`` of
            float32 tensors whose first dimension is ``batch_size``.

        Raises:
            ValueError: If the buffer holds no transitions yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        ind = np.random.randint(0, self.size, size=batch_size)
        fields = (self.state, self.action, self.next_state,
                  self.reward, self.not_done)
        # Fancy indexing copies, so the tensors do not alias the buffer.
        return tuple(torch.from_numpy(field[ind]) for field in fields)

class TD3:
    """Twin Delayed DDPG agent.

    Holds an actor, a twin critic, their target copies and optimizers, and
    implements clipped double-Q targets, target policy smoothing and delayed
    actor/target updates.

    Args:
        state_dim: Length of a flat state vector.
        action_dim: Length of a flat action vector.
        max_action: Symmetric bound applied to every action component.
        device: Torch device the networks live on.
        discount: Discount factor used in the TD target.
        tau: Polyak coefficient for the target-network soft update.
        policy_noise: Std of the Gaussian noise added to target actions.
        noise_clip: Absolute bound applied to that noise.
        policy_freq: Actor/targets are updated every ``policy_freq`` calls
            to ``update``.
    """

    def __init__(self, state_dim, action_dim, max_action, device,
                 discount=0.99, tau=0.005, policy_noise=0.2,
                 noise_clip=0.5, policy_freq=2):
        self.device = device
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0  # number of update() calls so far

    def select_action(self, state, noise=0.1):
        """Return a NumPy action for one state, optionally with exploration noise.

        Args:
            state: NumPy array holding a single state.
            noise: Std of the Gaussian exploration noise (absolute units, not
                scaled by ``max_action``); ``0`` disables it.
        """
        state = torch.as_tensor(
            state.reshape(1, -1), dtype=torch.float32, device=self.device
        )
        # Pure inference: skip building an autograd graph on every env step.
        with torch.no_grad():
            action = self.actor(state).cpu().numpy().flatten()
        if noise != 0:
            action = action + np.random.normal(0, noise, size=action.shape)
        return np.clip(action, -self.max_action, self.max_action)

    def _soft_update(self, source, target):
        """Polyak-average ``source`` parameters into ``target`` in place."""
        with torch.no_grad():
            for src, dst in zip(source.parameters(), target.parameters()):
                dst.mul_(1.0 - self.tau).add_(src, alpha=self.tau)

    def update(self, replay_buffer, batch_size=256):
        """Run one critic step and, every ``policy_freq`` calls, one actor step."""
        self.total_it += 1
        state, action, next_state, reward, not_done = (
            t.to(self.device) for t in replay_buffer.sample(batch_size)
        )

        with torch.no_grad():
            # Target policy smoothing: add clipped noise to the target action.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )
            next_action = (self.actor_target(next_state) + noise).clamp(
                -self.max_action, self.max_action
            )
            # Clipped double-Q: bootstrap from the smaller target estimate.
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates: actor and targets move less often than the critic.
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self._soft_update(self.critic, self.critic_target)
            self._soft_update(self.actor, self.actor_target)

    def save(self, filename):
        """Write network and optimizer state dicts to ``filename + suffix`` files."""
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Restore the files written by ``save`` and resync both target networks."""
        self.critic.load_state_dict(
            torch.load(filename + "_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=self.device))
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor.load_state_dict(
            torch.load(filename + "_actor", map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer", map_location=self.device))
        self.actor_target.load_state_dict(self.actor.state_dict())

def main():
    """Train TD3 on a Gymnasium environment and checkpoint it periodically.

    Command-line flags: ``--env``, ``--seed``, ``--start_timesteps`` (steps of
    uniform random actions before learning starts), ``--eval_freq`` (steps
    between checkpoints) and ``--max_timesteps``.
    """
    import os  # local import: only this entry point touches the filesystem

    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="HalfCheetah-v4", type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--start_timesteps", default=25000, type=int)
    parser.add_argument("--eval_freq", default=5000, type=int)
    parser.add_argument("--max_timesteps", default=1000000, type=int)
    args = parser.parse_args()

    file_name = f"TD3_{args.env}_{args.seed}"
    print(f"Running TD3 on {args.env} with seed {args.seed}")

    # torch.save does not create parent directories, so make sure it exists.
    model_dir = "./models"
    os.makedirs(model_dir, exist_ok=True)

    env = gym.make(args.env)
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy = TD3(state_dim, action_dim, max_action, device)
    replay_buffer = ReplayBuffer(state_dim, action_dim)

    state, _ = env.reset(seed=args.seed)
    episode_reward = 0
    episode_num = 0

    for t in range(args.max_timesteps):
        if t < args.start_timesteps:
            # Warm-up: fill the buffer with uniformly random actions.
            action = env.action_space.sample()
        else:
            action = policy.select_action(np.array(state))

        next_state, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        # Store only true termination. A time-limit truncation is not a
        # terminal state, so the TD target must still bootstrap through it.
        replay_buffer.add(state, action, next_state, reward, float(terminated))

        state = next_state

        if t >= args.start_timesteps:
            policy.update(replay_buffer)

        if terminated or truncated:
            print(f"Episode {episode_num+1} | Timesteps {t+1} | Reward {episode_reward:.2f}")
            state, _ = env.reset()
            episode_reward = 0
            episode_num += 1

        if (t + 1) % args.eval_freq == 0:
            policy.save(f"{model_dir}/{file_name}")

    env.close()


if __name__ == "__main__":
    main()

  • 2.1.1.2 SAC自动温度调整:
  • 通过Lagrange乘子法优化温度系数α,实现最大熵目标的动态平衡(PyTorch优化器嵌套优化)

软演员-评论家算法在最大化期望累积回报的同时引入策略熵正则化项,通过鼓励动作分布的随机性维持探索能力。温度系数控制熵奖励与任务奖励之间的权重比例,其自适应调整机制构建基于对偶优化的约束满足问题,利用Lagrange乘子法在策略迭代过程中动态求解最优温度参数,使策略熵维持在预设的目标值附近,避免手动调参导致的探索强度随训练阶段变化而失配。
Python
"""
SAC (Soft Actor-Critic) 实现脚本
涉及内容:自动温度调整、重参数化技巧、双Q网络、高斯策略
使用方式:python sac.py --env HalfCheetah-v4 --alpha auto
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from torch.distributions import Normal
import argparse

# Clamp range for the policy's predicted log standard deviation.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant keeping the log of the tanh Jacobian term finite.
epsilon = 1e-6


class Actor(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    The network predicts a mean and a clamped log standard deviation per
    action dimension; sampled actions are squashed into ``(-1, 1)`` by tanh.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian."""
        hidden = F.relu(self.l2(F.relu(self.l1(state))))
        mu = self.mean(hidden)
        log_sigma = self.log_std(hidden).clamp(min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mu, log_sigma

    def sample(self, state):
        """Draw a reparameterized action.

        Returns:
            ``(action, log_prob, mean_action)`` where ``log_prob`` has shape
            ``(batch, 1)`` and ``mean_action`` is ``tanh`` of the Gaussian mean.
        """
        mu, log_sigma = self.forward(state)
        gaussian = Normal(mu, log_sigma.exp())
        pre_tanh = gaussian.rsample()  # reparameterization trick
        action = torch.tanh(pre_tanh)
        # Change of variables for the tanh squash, then sum over action dims.
        squash_correction = torch.log(1 - action.pow(2) + epsilon)
        log_prob = (gaussian.log_prob(pre_tanh) - squash_correction).sum(1, keepdim=True)
        return action, log_prob, torch.tanh(mu)

class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action)."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        joint_dim = state_dim + action_dim

        # Layers l1-l3 form the first Q estimator.
        self.l1 = nn.Linear(joint_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

        # Layers l4-l6 form the second Q estimator.
        self.l4 = nn.Linear(joint_dim, hidden_dim)
        self.l5 = nn.Linear(hidden_dim, hidden_dim)
        self.l6 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for a batch of state-action pairs."""
        joint = torch.cat((state, action), dim=1)
        estimates = []
        for first, second, head in ((self.l1, self.l2, self.l3),
                                    (self.l4, self.l5, self.l6)):
            hidden = F.relu(first(joint))
            hidden = F.relu(second(hidden))
            estimates.append(head(hidden))
        return estimates[0], estimates[1]

class ReplayBuffer:
    """Ring buffer of transitions backed by preallocated NumPy arrays.

    New transitions overwrite the oldest once ``max_size`` is reached.
    Arrays are float32 to match the float32 tensors that ``sample`` returns;
    float64 storage would double memory without changing sampled values.

    Args:
        state_dim: Length of a flat state vector.
        action_dim: Length of a flat action vector.
        max_size: Maximum number of transitions kept.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0   # next write position
        self.size = 0  # count of valid entries
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.reward = np.zeros((max_size, 1), dtype=np.float32)
        self.not_done = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, action, next_state, reward, done):
        """Write one transition at ``ptr``; stores ``1 - done`` as the mask."""
        slot = self.ptr
        self.state[slot] = state
        self.action[slot] = action
        self.next_state[slot] = next_state
        self.reward[slot] = reward
        self.not_done[slot] = 1. - done
        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Return ``(state, action, next_state, reward, not_done)`` tensors.

        Indices are drawn uniformly with replacement from the valid entries.

        Raises:
            ValueError: If nothing has been added yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.from_numpy(self.state[ind]),
            torch.from_numpy(self.action[ind]),
            torch.from_numpy(self.next_state[ind]),
            torch.from_numpy(self.reward[ind]),
            torch.from_numpy(self.not_done[ind]),
        )

class SAC:
    """Soft Actor-Critic agent with optional automatic temperature tuning.

    Args:
        state_dim: Length of a flat state vector.
        action_dim: Length of a flat action vector.
        device: Torch device the networks live on.
        discount: Discount factor used in the soft TD target.
        tau: Polyak coefficient for the target critic.
        alpha: Fixed entropy temperature, used only when
            ``automatic_entropy_tuning`` is false.
        automatic_entropy_tuning: Learn ``log_alpha`` against a target entropy.
        target_entropy: Entropy target; defaults to ``-action_dim``.
    """

    def __init__(self, state_dim, action_dim, device,
                 discount=0.99, tau=0.005, alpha=0.2,
                 automatic_entropy_tuning=True, target_entropy=None):
        self.device = device
        self.discount = discount
        self.tau = tau

        self.actor = Actor(state_dim, action_dim).to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.automatic_entropy_tuning = automatic_entropy_tuning
        if not self.automatic_entropy_tuning:
            # Fixed temperature: no learnable log_alpha is created.
            self.alpha = alpha
            return

        # Default entropy target is -dim(A).
        self.target_entropy = -action_dim if target_entropy is None else target_entropy
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)
        self.alpha = self.log_alpha.exp()

    def select_action(self, state, evaluate=False):
        """Return a NumPy action; the squashed mean when ``evaluate`` is true."""
        obs = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        stochastic, _, deterministic = self.actor.sample(obs)
        chosen = deterministic if evaluate else stochastic
        return chosen.detach().cpu().numpy()[0]

    def update(self, replay_buffer, batch_size=256):
        """Run one critic, actor, (optional) temperature and target update."""
        batch = replay_buffer.sample(batch_size)
        state, action, next_state, reward, not_done = (
            item.to(self.device) for item in batch
        )

        # Soft TD target from the target critic and the current policy.
        with torch.no_grad():
            next_action, next_log_prob, _ = self.actor.sample(next_state)
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            soft_value = torch.min(target_Q1, target_Q2) - self.alpha * next_log_prob
            target_Q = reward + not_done * self.discount * soft_value

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Policy step: minimize alpha * log_pi - min Q on fresh policy samples.
        pi, log_pi, _ = self.actor.sample(state)
        qf1_pi, qf2_pi = self.critic(state, pi)
        actor_loss = (self.alpha * log_pi - torch.min(qf1_pi, qf2_pi)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        if self.automatic_entropy_tuning:
            entropy_gap = (log_pi + self.target_entropy).detach()
            alpha_loss = -(self.log_alpha * entropy_gap).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp()

        # Polyak-average the online critic into the target critic.
        pairs = zip(self.critic.parameters(), self.critic_target.parameters())
        for online, target in pairs:
            blended = self.tau * online.data + (1 - self.tau) * target.data
            target.data.copy_(blended)

def main():
    """Train SAC on a Gymnasium environment.

    Command-line flags: ``--env``, ``--seed``, ``--alpha`` (``auto`` to learn
    the temperature, or a number for a fixed one), ``--start_timesteps``
    (random-action warm-up length) and ``--max_timesteps``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="HalfCheetah-v4")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--alpha", default="auto", help="0.2 or auto")
    parser.add_argument("--start_timesteps", default=10000, type=int)
    parser.add_argument("--max_timesteps", default=1000000, type=int)
    args = parser.parse_args()

    automatic_entropy_tuning = (args.alpha == "auto")
    if automatic_entropy_tuning:
        alpha = 0.2  # ignored by SAC when the temperature is learned
    else:
        # Honor the numeric value given on the command line.
        try:
            alpha = float(args.alpha)
        except ValueError:
            parser.error(f"--alpha must be 'auto' or a number, got {args.alpha!r}")

    env = gym.make(args.env)
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy = SAC(state_dim, action_dim, device,
                 automatic_entropy_tuning=automatic_entropy_tuning,
                 alpha=alpha)
    replay_buffer = ReplayBuffer(state_dim, action_dim)

    # NOTE(review): policy actions are tanh-squashed into (-1, 1) and passed to
    # the environment unscaled; confirm the chosen env uses that action range.
    state, _ = env.reset(seed=args.seed)
    for t in range(args.max_timesteps):
        if t < args.start_timesteps:
            action = env.action_space.sample()
        else:
            action = policy.select_action(state)

        next_state, reward, terminated, truncated, _ = env.step(action)
        # Only true termination ends the return; a time-limit truncation must
        # still bootstrap from the next state's value.
        replay_buffer.add(state, action, next_state, reward, float(terminated))
        state = next_state

        if t >= args.start_timesteps:
            policy.update(replay_buffer)

        if terminated or truncated:
            state, _ = env.reset()

    env.close()


if __name__ == "__main__":
    main()
2.1.2 离线强化学习(Offline RL)部署
  • 2.1.2.1 CQL(Conservative Q-Learning)
  • 在标准贝尔曼误差上添加CQL正则项,防止对未见状态-动作对的过估计

保守Q学习算法(CQL)通过修改标准贝尔曼更新规则,在最小化时间差分误差的同时施加正则化约束,抑制对未在离线数据集中出现的状态-动作对的过高估计。该方法在动作价值函数的学习过程中引入保守性惩罚项:压低策略或随机采样动作的Q值,同时抬高数据集中动作的Q值,使学习到的Q值成为真实值的下界估计,从而在策略评估阶段避免对分布外动作的过度乐观预测,确保从静态数据集中提取的策略在实际环境部署时的稳定性与安全性。

Python

"""
CQL (Conservative Q-Learning) 实现脚本
涉及内容:保守正则项、CQL(H)变体、离线数据集加载、SAC基础架构修改
使用方式:python cql.py --dataset_path d4rl_hopper-medium-v2.npz --cql_weight 5.0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import argparse

class Actor(nn.Module):
    """Tanh-squashed diagonal Gaussian policy used by the CQL agent.

    The network predicts a mean and a log standard deviation (clamped to
    ``[-20, 2]``) per action dimension.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian."""
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        mean = self.mean(a)
        log_std = torch.clamp(self.log_std(a), min=-20, max=2)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterized, tanh-squashed action.

        Returns:
            ``(action, log_prob, mean_action)``. ``log_prob`` is the joint
            log-density with shape ``(batch, 1)`` so it lines up with the
            critic's ``(batch, 1)`` Q-values. ``mean_action`` is ``tanh`` of
            the Gaussian mean, i.e. it lives in the same range as ``action``.
        """
        mean, log_std = self.forward(state)
        normal = torch.distributions.Normal(mean, log_std.exp())
        x_t = normal.rsample()
        action = torch.tanh(x_t)
        # Change of variables for the tanh squash.
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)
        # Sum per-dimension terms: leaving (batch, action_dim) here would
        # broadcast against (batch, 1) Q-values in the losses.
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob, torch.tanh(mean)

class Critic(nn.Module):
    """Two independent Q-value MLPs sharing the same (state, action) input."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        width_in = state_dim + action_dim

        # Q1 estimator.
        self.l1 = nn.Linear(width_in, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

        # Q2 estimator.
        self.l4 = nn.Linear(width_in, hidden_dim)
        self.l5 = nn.Linear(hidden_dim, hidden_dim)
        self.l6 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for a batch of state-action pairs."""
        x = torch.cat((state, action), dim=1)
        first_q = self.l3(torch.relu(self.l2(torch.relu(self.l1(x)))))
        second_q = self.l6(torch.relu(self.l5(torch.relu(self.l4(x)))))
        return first_q, second_q

class OfflineDataset(Dataset):
    """In-memory transition dataset loaded from a NumPy ``.npz`` archive.

    The archive must contain the arrays ``observations``, ``actions``,
    ``next_observations``, ``rewards`` and ``terminals``.

    Args:
        data_path: Path of the ``.npz`` file.
    """

    def __init__(self, data_path):
        # Context manager closes the underlying zip file once arrays are read.
        with np.load(data_path) as data:
            self.states = self._as_float(data['observations'])
            self.actions = self._as_float(data['actions'])
            self.next_states = self._as_float(data['next_observations'])
            # reshape(-1, 1) accepts both (N,) and (N, 1) inputs.
            self.rewards = self._as_float(data['rewards']).reshape(-1, 1)
            self.dones = self._as_float(data['terminals']).reshape(-1, 1)
        self.size = len(self.states)

    @staticmethod
    def _as_float(array):
        """Convert a NumPy array of any numeric/bool dtype to a float32 tensor."""
        return torch.as_tensor(np.asarray(array), dtype=torch.float32)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(state, action, next_state, reward, done)`` for one index."""
        return (self.states[idx], self.actions[idx], self.next_states[idx],
                self.rewards[idx], self.dones[idx])

class CQL:
    """Conservative Q-Learning agent built on a SAC-style actor and twin critic.

    Args:
        state_dim: Length of a flat state vector.
        action_dim: Length of a flat action vector.
        device: Torch device the networks live on.
        cql_weight: Multiplier of the conservative regularizer.
        target_action_gap: Stored on the instance; no method here reads it.
        temp: Temperature of the log-sum-exp in the regularizer.
        discount: Discount factor used in the TD target.
        tau: Polyak coefficient for the target critic.
        alpha: Fixed entropy temperature for the target and the actor loss.
    """

    def __init__(self, state_dim, action_dim, device, cql_weight=5.0,
                 target_action_gap=10, temp=1.0, discount=0.99, tau=0.005,
                 alpha=0.2):
        self.device = device
        self.cql_weight = cql_weight
        self.target_action_gap = target_action_gap
        self.temp = temp

        self.actor = Actor(state_dim, action_dim).to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.discount = discount
        self.tau = tau
        self.alpha = alpha

    @staticmethod
    def _joint_log_prob(log_pi):
        """Collapse per-dimension log-probs ``(B, A)`` to a joint ``(B, 1)``.

        A ``(B, 1)`` input is returned unchanged, so this is safe whether or
        not the actor already sums over action dimensions.
        """
        if log_pi.dim() > 1 and log_pi.shape[1] != 1:
            return log_pi.sum(dim=1, keepdim=True)
        return log_pi

    def cql_loss(self, q_values, actions, dataset_actions):
        """Log-sum-exp gap for a table of Q-values indexed by integer actions.

        ``q_values`` is reduced along dim 1 and ``actions`` is used as gather
        indices along the same dim. ``dataset_actions`` is accepted but
        unused. ``update`` does not call this method; it computes the
        regularizer inline from sampled continuous actions.
        """
        cql1_loss = torch.logsumexp(q_values / self.temp, dim=1, keepdim=True) * self.temp
        cql1_loss = cql1_loss - q_values.gather(1, actions.long())
        return cql1_loss.mean()

    def update(self, batch):
        """Run one critic, actor and target-critic update on ``batch``.

        Args:
            batch: ``(state, action, next_state, reward, done)`` tensors.
        """
        state, action, next_state, reward, done = (t.to(self.device) for t in batch)

        with torch.no_grad():
            next_action, next_log_pi, _ = self.actor.sample(next_state)
            # Joint log-prob must be (B, 1) to match the Q-value shape.
            next_log_pi = self._joint_log_prob(next_log_pi)
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2) - self.alpha * next_log_pi
            target_Q = reward + (1 - done) * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)
        bellman_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # CQL regularization: push down Q on sampled actions, push up on data.
        # One uniform action in [-1, 1] per state, created on the batch device.
        random_actions = torch.empty_like(action).uniform_(-1, 1)
        random_q1, random_q2 = self.critic(state, random_actions)

        # Policy actions act as constants here; only the critic is trained.
        with torch.no_grad():
            policy_actions, _, _ = self.actor.sample(state)
        policy_q1, policy_q2 = self.critic(state, policy_actions)

        cat_q1 = torch.cat([random_q1, policy_q1], dim=1)
        cat_q2 = torch.cat([random_q2, policy_q2], dim=1)

        cql1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1).mean() * self.temp
        cql2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1).mean() * self.temp

        cql1_loss = cql1_loss - current_Q1.mean()
        cql2_loss = cql2_loss - current_Q2.mean()

        conservative_loss = cql1_loss + cql2_loss
        total_critic_loss = bellman_loss + self.cql_weight * conservative_loss

        self.critic_optimizer.zero_grad()
        total_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optimizer.step()

        # Actor update (SAC objective with fixed temperature).
        new_actions, log_pi, _ = self.actor.sample(state)
        log_pi = self._joint_log_prob(log_pi)
        qf1_new, qf2_new = self.critic(state, new_actions)
        min_qf = torch.min(qf1_new, qf2_new)
        actor_loss = (self.alpha * log_pi - min_qf).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Soft update of the target critic.
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(param, alpha=self.tau)

def main():
    """Train CQL on an offline ``.npz`` dataset for a fixed number of epochs.

    Command-line flags: ``--dataset_path`` (required), ``--cql_weight`` and
    ``--seed``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_path", required=True)
    cli.add_argument("--cql_weight", default=5.0, type=float)
    cli.add_argument("--seed", default=0, type=int)
    args = cli.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    dataset = OfflineDataset(args.dataset_path)
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

    # Network input sizes are read off the dataset tensors.
    state_dim = dataset.states.shape[1]
    action_dim = dataset.actions.shape[1]

    policy = CQL(state_dim, action_dim, device, cql_weight=args.cql_weight)

    num_epochs = 1000
    for epoch in range(num_epochs):
        for batch in dataloader:
            policy.update(batch)
        # Progress line every tenth epoch.
        if epoch % 10:
            continue
        print(f"Epoch {epoch} completed")


if __name__ == "__main__":
    main()
  • 2.1.2.2 IQL(Implicit Q-Learning):以期望分位数回归(expectile regression)替代max操作,避免OOD动作查询,并给出离线到在线微调代码
隐式Q学习算法(IQL)摒弃了在动作价值估计中显式取最大值的操作,转而通过期望分位数回归(expectile regression)拟合状态价值函数:当期望分位数参数τ大于0.5时,状态价值逼近数据集支撑范围内动作价值的上分位水平,从而隐式近似"对数据内动作取最大"。该方法将策略评估与策略提取解耦,首先仅利用数据集中出现的状态-动作对学习价值函数,随后通过优势加权回归从静态数据中提取策略,避免了直接查询分布外动作价值导致的估计误差,为离线到在线的连续学习提供了稳定的桥梁。
Python
"""
IQL (Implicit Q-Learning) 实现脚本
涉及内容:期望分位数回归、优势加权行为克隆、两步策略提取、离线到在线微调
使用方式:python iql.py --dataset_path d4rl_walker-expert-v2.npz --expectile 0.7 --beta 3.0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import argparse

class Value(nn.Module):
    """State-value network: a two-hidden-layer ReLU MLP with a scalar output."""

    def __init__(self, state_dim, hidden_dim=256):
        super().__init__()
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return V(state) with shape ``(batch, 1)``."""
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        return self.l3(hidden)

class QNetwork(nn.Module):
    """Action-value network over the concatenated (state, action) input."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        joint_dim = state_dim + action_dim
        self.l1 = nn.Linear(joint_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return Q(state, action) with shape ``(batch, 1)``."""
        joint = torch.cat((state, action), dim=1)
        hidden = joint
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)

class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose mean is tanh-bounded to ``[-1, 1]``.

    ``forward`` returns the bounded mean and a log standard deviation clamped
    to ``[-20, 2]``; ``sample`` draws an action from that Gaussian.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return ``(mean, log_std)``; ``mean`` is already in action space."""
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        mean = torch.tanh(self.mean(a))
        log_std = torch.clamp(self.log_std(a), -20, 2)
        return mean, log_std

    def sample(self, state):
        """Return ``(action, mean)`` with ``action`` drawn around ``mean``.

        The mean is already tanh-bounded, so the draw is clipped to
        ``[-1, 1]`` rather than squashed again. A second tanh would cap
        actions near ``tanh(1) ~ 0.76`` and make the sampled action disagree
        with the returned mean.
        """
        mean, log_std = self.forward(state)
        normal = torch.distributions.Normal(mean, log_std.exp())
        action = normal.rsample().clamp(-1.0, 1.0)
        return action, mean

class OfflineDataset(Dataset):
    """In-memory transition dataset read from a NumPy ``.npz`` archive.

    The archive must provide ``observations``, ``actions``,
    ``next_observations``, ``rewards`` and ``terminals``.

    Args:
        data_path: Path of the ``.npz`` file.
    """

    def __init__(self, data_path):
        # Read everything inside the context manager so the archive's file
        # handle is released as soon as the tensors are built.
        with np.load(data_path) as data:
            self.states = self._to_tensor(data['observations'])
            self.actions = self._to_tensor(data['actions'])
            self.next_states = self._to_tensor(data['next_observations'])
            # reshape(-1, 1) tolerates rewards/terminals stored as (N,) or (N, 1).
            self.rewards = self._to_tensor(data['rewards']).reshape(-1, 1)
            self.dones = self._to_tensor(data['terminals']).reshape(-1, 1)

    @staticmethod
    def _to_tensor(array):
        """Convert a NumPy array of any numeric/bool dtype to float32."""
        return torch.as_tensor(np.asarray(array), dtype=torch.float32)

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        """Return ``(state, action, next_state, reward, done)`` for one index."""
        return (self.states[idx], self.actions[idx], self.next_states[idx],
                self.rewards[idx], self.dones[idx])

class IQL:
    """Implicit Q-Learning agent: expectile value regression plus
    advantage-weighted policy extraction.

    Args:
        state_dim: Length of a flat state vector.
        action_dim: Length of a flat action vector.
        device: Torch device the networks live on.
        expectile: Asymmetry of the value regression; values above 0.5
            weight positive residuals more heavily.
        beta: Inverse temperature applied to the advantage in the policy
            weights.
        tau: Polyak coefficient for the target Q-network.
        discount: Discount factor used in the Q target.
    """

    def __init__(self, state_dim, action_dim, device, expectile=0.7, beta=3.0,
                 tau=0.005, discount=0.99):
        self.device = device
        self.expectile = expectile
        self.beta = beta
        self.tau = tau
        self.discount = discount

        self.value_net = Value(state_dim).to(device)
        self.v_optimizer = optim.Adam(self.value_net.parameters(), lr=3e-4)

        self.q_net = QNetwork(state_dim, action_dim).to(device)
        self.q_target = QNetwork(state_dim, action_dim).to(device)
        self.q_target.load_state_dict(self.q_net.state_dict())
        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=3e-4)

        self.policy = GaussianPolicy(state_dim, action_dim).to(device)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=3e-4)

    def expectile_loss(self, diff, expectile):
        """Return the element-wise asymmetric squared loss of ``diff``.

        Positive residuals are weighted by ``expectile``, non-positive ones by
        ``1 - expectile``.
        """
        # Tensor arithmetic avoids torch.where with two Python scalars, which
        # is not accepted by every torch version.
        non_positive = (diff <= 0).to(diff.dtype)
        weight = torch.abs(expectile - non_positive)
        return weight * diff.pow(2)

    def update(self, batch):
        """Run one Q, value, policy and target-Q update on ``batch``.

        Args:
            batch: ``(state, action, next_state, reward, done)`` tensors.
        """
        state, action, next_state, reward, done = (t.to(self.device) for t in batch)

        # Q regression towards r + gamma * V(s'); no action maximization.
        with torch.no_grad():
            next_v = self.value_net(next_state)
            target_q = reward + (1 - done) * self.discount * next_v

        q_loss = F.mse_loss(self.q_net(state, action), target_q)

        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # Value network: expectile regression onto the target Q of data actions.
        with torch.no_grad():
            q_pred = self.q_target(state, action)
        v_pred = self.value_net(state)
        v_loss = self.expectile_loss(q_pred - v_pred, self.expectile).mean()

        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()

        # Policy extraction via advantage-weighted regression.
        with torch.no_grad():
            adv = self.q_net(state, action) - self.value_net(state)
            exp_adv = torch.exp(adv * self.beta).clamp(max=100.0)

        # Maximize the weighted log-likelihood of dataset actions. The loss is
        # the NEGATIVE log-likelihood, so minimizing it pulls the policy
        # towards high-advantage dataset actions rather than away from them.
        mean, log_std = self.policy(state)
        dist = torch.distributions.Normal(mean, log_std.exp())
        log_prob = dist.log_prob(action).sum(dim=1, keepdim=True)
        policy_loss = -(exp_adv * log_prob).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Soft update of the target Q-network.
        with torch.no_grad():
            for param, target_param in zip(self.q_net.parameters(),
                                           self.q_target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(param, alpha=self.tau)

def main():
    """Parse CLI options, build the offline dataset and the IQL agent, then train.

    Training runs for 1000 passes over the data loader and prints the epoch
    index every 10 epochs.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_path", required=True)
    cli.add_argument("--expectile", default=0.7, type=float)
    cli.add_argument("--beta", default=3.0, type=float)
    cli.add_argument("--seed", default=0, type=int)
    args = cli.parse_args()

    torch.manual_seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    dataset = OfflineDataset(args.dataset_path)
    loader = DataLoader(dataset, batch_size=256, shuffle=True)

    # Dimensions are read from the dataset's stored arrays.
    state_dim, action_dim = dataset.states.shape[1], dataset.actions.shape[1]
    agent = IQL(state_dim, action_dim, device, expectile=args.expectile, beta=args.beta)

    num_epochs, log_every = 1000, 10
    for epoch in range(num_epochs):
        for minibatch in loader:
            agent.update(minibatch)
        if epoch % log_every == 0:
            print(f"Epoch {epoch}")

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

2.2 模型基础强化学习(Model-Based RL)

2.2.1 环境动力学模型学习
  • 2.2.1.1 PETS概率性模型:基于高斯混合网络(MDN)的集成模型(Ensemble)实现DON(Distillation via Online Network)对抗模型误差

PETS(Probabilistic Ensembles with Trajectory Sampling,概率集成与轨迹采样)通过构建多个结构相同、独立初始化的动力学模型来预测未来状态转移的分布,每个模型参数化一个高斯混合网络以捕捉环境随机性(偶然不确定性)。集成方法通过各成员预测之间的方差量化认知不确定性,在规划阶段利用模型预测误差的统计特性指导探索行为。蒸馏在线网络技术通过约束集成模型与蒸馏网络之间的预测差异,动态调整模型置信度,防止策略过度拟合特定动力学模型的偏差。

Python

复制

"""
PETS (Probabilistic Ensemble with Trajectory Sampling) 实现脚本
涉及内容:概率集成模型、高斯混合网络、MPC规划、粒子传播、模型不确定性估计
使用方式:python pets.py --env Pendulum-v1 --n_ensemble 5 --plan_horizon 20
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from typing import Tuple
import argparse

class GaussianMixtureNetwork(nn.Module):
    """MLP that parameterizes a diagonal Gaussian mixture over its outputs.

    A three-layer ReLU trunk feeds three linear heads: component means,
    component log standard deviations (clamped to [-2, 2] before
    exponentiation) and unnormalized mixture logits.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Dimensionality of each mixture component.
        n_gaussians: Number of mixture components.
        hidden_dim: Width of the three hidden layers.
    """

    def __init__(self, input_dim, output_dim, n_gaussians=1, hidden_dim=200):
        super().__init__()
        self.n_gaussians = n_gaussians
        self.output_dim = output_dim

        # Shared trunk.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

        # Heads: per-component mean and log-std, plus mixture logits.
        self.mean = nn.Linear(hidden_dim, output_dim * n_gaussians)
        self.log_std = nn.Linear(hidden_dim, output_dim * n_gaussians)
        self.logits = nn.Linear(hidden_dim, n_gaussians)

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mean, std, logits)`` for input ``x``.

        ``mean`` and ``std`` are reshaped to ``(-1, n_gaussians, output_dim)``,
        so any leading dimensions of ``x`` are flattened into the first axis;
        ``logits`` keeps the leading dimensions of ``x``.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))

        mean = self.mean(x).view(-1, self.n_gaussians, self.output_dim)
        # Bound the log-std so exp() stays within [e^-2, e^2].
        log_std = torch.clamp(self.log_std(x), min=-2, max=2)
        std = torch.exp(log_std).view(-1, self.n_gaussians, self.output_dim)
        logits = self.logits(x)
        return mean, std, logits

    def sample(self, x) -> torch.Tensor:
        """Draw one sample per input row; returns a ``(N, output_dim)`` tensor.

        ``N`` is the flattened batch size produced by ``forward`` (``1`` for an
        unbatched 1-D input).
        """
        mean, std, logits = self.forward(x)
        # Pick one mixture component per row from the categorical distribution
        # defined by the logits (flattened to match mean/std's first axis).
        probs = F.softmax(logits.reshape(-1, self.n_gaussians), dim=-1)
        component = torch.multinomial(probs, 1).squeeze(-1)
        # Index rows on the same device as the predictions, using the
        # flattened batch size rather than x.shape[0].
        rows = torch.arange(mean.shape[0], device=mean.device)
        selected_mean = mean[rows, component]
        selected_std = std[rows, component]
        return selected_mean + selected_std * torch.randn_like(selected_mean)

class EnsembleDynamicsModel(nn.Module):
    """Ensemble of ``GaussianMixtureNetwork`` models over (state, action) inputs.

    Each member maps ``state_dim + action_dim`` inputs to ``state_dim + 1``
    outputs; the training code in this file fits them to the state delta
    concatenated with the reward.
    """

    def __init__(self, state_dim, action_dim, n_ensemble=5):
        super().__init__()
        self.n_ensemble = n_ensemble
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Members share an architecture and differ by random initialization.
        self.models = nn.ModuleList([
            GaussianMixtureNetwork(state_dim + action_dim, state_dim + 1)  # next_state + reward
            for _ in range(n_ensemble)
        ])

    def forward(self, state, action, model_idx=None):
        """Return mixture parameters from one member or from all members.

        With ``model_idx`` set, returns that member's ``(mean, std, logits)``.
        Otherwise returns the three outputs stacked along a new leading
        ensemble axis.
        """
        x = torch.cat([state, action], dim=-1)
        if model_idx is not None:
            return self.models[model_idx](x)

        means, stds, logits = zip(*(model(x) for model in self.models))
        return torch.stack(means), torch.stack(stds), torch.stack(logits)

    def predict(self, state, action):
        """Sample one prediction per member and summarize across the ensemble.

        Returns ``(mean_pred, std_pred)``: the mean and standard deviation over
        members of one Gaussian sample each. Only the first mixture component
        of every member is used. With a single member the spread is reported
        as zeros rather than the NaN an unbiased one-sample std would give.
        """
        x = torch.cat([state, action], dim=-1)
        samples = []
        for model in self.models:
            mean, std, _ = model(x)
            first_mean, first_std = mean[:, 0], std[:, 0]
            samples.append(first_mean + first_std * torch.randn_like(first_mean))
        predictions = torch.stack(samples)  # [n_ensemble, batch, state_dim+1]

        mean_pred = predictions.mean(dim=0)
        if predictions.shape[0] > 1:
            std_pred = predictions.std(dim=0)
        else:
            std_pred = torch.zeros_like(mean_pred)
        return mean_pred, std_pred

class MPCPlanner:
    """Cross-Entropy Method (CEM) planner over a learned dynamics model.

    Args:
        model: Object exposing ``predict(state, action) -> (pred, std)`` where
            ``pred[:, :-1]`` is the state delta and ``pred[:, -1]`` the reward.
        action_dim: Dimensionality of a single action.
        horizon: Number of steps in each candidate action sequence.
        n_samples: Candidate sequences drawn per CEM iteration.
        top_k: Number of elite sequences used to refit the distribution.
        n_iterations: CEM refinement iterations per ``plan`` call.
        discount: Per-step discount applied to predicted rewards.
    """

    def __init__(self, model, action_dim, horizon=20, n_samples=1000, top_k=100,
                 n_iterations=5, discount=0.95):
        self.model = model
        self.action_dim = action_dim
        self.horizon = horizon
        self.n_samples = n_samples
        self.top_k = top_k
        self.n_iterations = n_iterations
        self.discount = discount
        # Per-dimension prior from which every plan() call starts its search.
        self.action_mean = np.zeros(action_dim)
        self.action_std = np.ones(action_dim)

    def _evaluate(self, state_batch, candidates):
        """Return the discounted predicted return of each candidate sequence."""
        returns = torch.zeros(candidates.shape[0])
        current_state = state_batch.clone()
        with torch.no_grad():
            for t in range(self.horizon):
                pred, _ = self.model.predict(current_state, candidates[:, t])
                current_state = current_state + pred[:, :-1]
                returns += (self.discount ** t) * pred[:, -1]
        return returns

    def plan(self, state, goal=None):
        """Return the first action of the CEM-optimized sequence as a NumPy array.

        The sampling distribution is re-initialized from the prior on every
        call, so a previous plan's collapsed distribution cannot trap the
        search, and the whole ``(horizon, action_dim)`` sequence is refit from
        the elites rather than only its first step. ``goal`` is accepted for
        interface compatibility and is not used.
        """
        state_batch = (
            torch.as_tensor(state, dtype=torch.float32)
            .reshape(1, -1)
            .repeat(self.n_samples, 1)
        )

        mean = (
            torch.as_tensor(self.action_mean, dtype=torch.float32)
            .expand(self.horizon, self.action_dim)
            .clone()
        )
        std = (
            torch.as_tensor(self.action_std, dtype=torch.float32)
            .expand(self.horizon, self.action_dim)
            .clone()
        )
        # topk cannot select more elites than there are samples.
        n_elite = max(1, min(self.top_k, self.n_samples))

        for _ in range(self.n_iterations):
            noise = torch.randn(self.n_samples, self.horizon, self.action_dim)
            candidates = mean + std * noise
            returns = self._evaluate(state_batch, candidates)

            elite_idx = torch.topk(returns, n_elite).indices
            elites = candidates[elite_idx]

            mean = elites.mean(dim=0)
            if n_elite > 1:
                # Small floor keeps the distribution from collapsing to a point.
                std = elites.std(dim=0) + 0.01
            else:
                std = torch.full_like(mean, 0.01)

        return mean[0].numpy()

class PETSAgent:
    """Couples an ensemble dynamics model, its Adam optimizer and an MPC planner."""

    def __init__(self, state_dim, action_dim, n_ensemble=5, horizon=20):
        dynamics = EnsembleDynamicsModel(state_dim, action_dim, n_ensemble)
        self.model = dynamics
        self.planner = MPCPlanner(dynamics, action_dim, horizon)
        self.optimizer = torch.optim.Adam(dynamics.parameters(), lr=1e-3)

    def train(self, buffer, batch_size=256, epochs=100):
        """Fit every ensemble member by Gaussian negative log-likelihood.

        ``buffer`` is indexed as ``buffer[i] = [state, action, next_state, reward]``.
        Each of the ``epochs`` iterations draws one minibatch (with
        replacement) and takes a single optimizer step on the summed loss of
        all members.
        """
        for _ in range(epochs):
            picks = np.random.choice(len(buffer), batch_size)
            rows = [buffer[i] for i in picks]
            states, actions, next_states = (
                torch.FloatTensor(np.array([row[col] for row in rows]))
                for col in range(3)
            )
            rewards = torch.FloatTensor(np.array([row[3] for row in rows])).unsqueeze(1)

            inputs = torch.cat([states, actions], dim=1)
            # Regression target: state delta and reward, with a singleton
            # component axis so it broadcasts against the mixture outputs.
            targets = torch.cat([next_states - states, rewards], dim=1).unsqueeze(1)

            loss = 0
            for member in self.model.models:
                mean, std, _ = member(inputs)
                scaled_sq = ((targets - mean) / std) ** 2
                log_prob = -0.5 * (scaled_sq + torch.log(std ** 2))
                loss = loss + (-log_prob.mean())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def act(self, state):
        """Return the planner's action for ``state``."""
        return self.planner.plan(state)

def main():
    """Collect random data, fit the PETS dynamics model, then run MPC control.

    The script gathers 1000 random-policy transitions, trains the agent's
    model on them once, and then acts with the planner for 1000 steps while
    appending the new transitions to the same buffer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="Pendulum-v1")
    parser.add_argument("--n_ensemble", default=5, type=int)
    parser.add_argument("--plan_horizon", default=20, type=int)
    args = parser.parse_args()

    env = gym.make(args.env)
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.shape[0]

    agent = PETSAgent(obs_size, act_size, args.n_ensemble, args.plan_horizon)

    # Phase 1: random exploration to seed the buffer.
    replay = []
    obs, _ = env.reset()
    for _ in range(1000):
        act = env.action_space.sample()
        nxt, rew, terminated, truncated, _ = env.step(act)
        replay.append([obs, act, nxt, rew])
        if terminated or truncated:
            obs = env.reset()[0]
        else:
            obs = nxt

    # Phase 2: fit the dynamics model on the collected transitions.
    agent.train(replay)

    # Phase 3: model-predictive control with the learned model.
    obs, _ = env.reset()
    for _ in range(1000):
        act = agent.act(obs)
        nxt, rew, terminated, truncated, _ = env.step(act)
        replay.append([obs, act, nxt, rew])
        obs = env.reset()[0] if (terminated or truncated) else nxt

# Run the PETS demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
  • 2.2.1.2 MBPO(Model-Based Policy Optimization):基于模型的虚拟 rollouts 与真实数据按比例混合(真实数据占比 real_ratio=0.05,即模型生成数据占 0.95)的采样器实现

基于模型的策略优化算法通过训练深度神经网络近似环境转移动力学,利用学习到的模型生成虚拟轨迹以扩展训练数据集。该方法采用短视域模型预测与真实数据混合训练的策略,通过控制虚拟轨迹与真实交互的比例,在保持样本效率的同时抑制模型复合误差的累积。集成模型的使用进一步通过预测方差过滤不可靠的合成样本,确保策略优化过程中使用的状态转移估计满足精度要求。

Python

复制

"""
MBPO (Model-Based Policy Optimization) 实现脚本
涉及内容:集成动力学模型、模型生成rollout、虚拟与真实数据混合、SAC智能体集成
使用方式:python mbpo.py --env HalfCheetah-v4 --real_ratio 0.05 --rollout_length 1
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from collections import deque
import argparse

class EnsembleModel(nn.Module):
    """Ensemble of MLPs predicting the state delta and reward.

    The predictive mean is the average of the members' outputs and the
    predictive variance is the variance of those outputs across members.

    Args:
        state_dim: Dimensionality of the state.
        action_dim: Dimensionality of the action.
        hidden_size: Width of the two hidden layers of every member.
        num_networks: Number of ensemble members.
    """

    def __init__(self, state_dim, action_dim, hidden_size=200, num_networks=7):
        super().__init__()
        self.num_networks = num_networks
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.ensemble = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, state_dim + 1)  # next_state_delta + reward
            ) for _ in range(num_networks)
        ])

        # NOTE: these bounds are registered as parameters but are not read by
        # forward() or loss(); they are kept so the parameter set is unchanged.
        self.max_logvar = nn.Parameter(torch.ones(1, state_dim + 1) / 2.0)
        self.min_logvar = nn.Parameter(-torch.ones(1, state_dim + 1) * 10)

    def forward(self, state, action):
        """Return ``(mean, log_var)``, each of shape ``(batch, state_dim + 1)``."""
        x = torch.cat([state, action], dim=-1)
        outputs = torch.stack([member(x) for member in self.ensemble])  # [num_networks, batch, state_dim+1]

        mean = outputs.mean(dim=0)
        # The unbiased variance of a single member is NaN; use the biased
        # estimate (zero) in that case so the 1e-8 floor below applies.
        var = outputs.var(dim=0, unbiased=outputs.shape[0] > 1)
        log_var = torch.clamp(var, min=1e-8).log()
        return mean, log_var

    def loss(self, state, action, next_state, reward):
        """Return the mean Gaussian NLL (up to constants) of the transition.

        The target is ``next_state - state`` concatenated with ``reward``
        (``reward`` gets a trailing axis added here).
        """
        target = torch.cat([next_state - state, reward.unsqueeze(-1)], dim=-1)
        mean, log_var = self.forward(state, action)

        inv_var = torch.exp(-log_var)
        nll = ((mean - target.detach()) ** 2) * inv_var + log_var
        return nll.mean()

class ModelBuffer:
    """Bounded FIFO store of ``[state, action, next_state, reward, done]`` rows."""

    def __init__(self, capacity=int(1e6)):
        # deque(maxlen=...) drops the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, next_state, reward, done):
        """Append one transition."""
        self.buffer.append([state, action, next_state, reward, done])

    def sample(self, batch_size):
        """Return a random batch as five float tensors.

        Rows are drawn without replacement when the buffer holds at least
        ``batch_size`` transitions and with replacement otherwise, so a
        request larger than the buffer no longer raises.

        Raises:
            ValueError: If the buffer is empty and ``batch_size`` is positive.
        """
        size = len(self.buffer)
        if size == 0 and batch_size > 0:
            raise ValueError("cannot sample from an empty ModelBuffer")
        indices = np.random.choice(size, batch_size, replace=batch_size > size)
        batch = [self.buffer[i] for i in indices]

        states, actions, next_states, rewards, dones = (
            torch.FloatTensor(np.array([row[col] for row in batch]))
            for col in range(5)
        )
        return states, actions, next_states, rewards, dones

    def __len__(self):
        return len(self.buffer)

class SAC:
    """Simplified actor-critic used as the policy learner inside MBPO.

    As implemented here the actor is a deterministic tanh MLP and the two
    critics are trained against each other's minimum without target networks
    or an entropy term.
    """

    def __init__(self, state_dim, action_dim, device):
        self.device = device
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim), nn.Tanh()
        ).to(device)
        self.critic1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        ).to(device)
        self.critic2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        ).to(device)

        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        # One optimizer drives both critics.
        self.critic_opt = torch.optim.Adam(
            list(self.critic1.parameters()) + list(self.critic2.parameters()), lr=3e-4
        )

    def select_action(self, state, noise=0.1):
        """Return the actor's action plus Gaussian noise, clipped to [-1, 1]."""
        state = torch.FloatTensor(state).to(self.device)
        action = self.actor(state).detach().cpu().numpy()
        if noise > 0:
            action += np.random.normal(0, noise, size=action.shape)
        return np.clip(action, -1, 1)

    def update(self, buffer, batch_size=256, real_ratio=0.05):
        """Take one critic step and one actor step on a batch from ``buffer``.

        ``buffer`` must provide ``sample(batch_size)`` returning
        ``(states, actions, next_states, rewards, dones)``. When it also
        defines ``__len__``, the update is skipped while it holds fewer than
        ``batch_size - int(batch_size * real_ratio)`` transitions.
        """
        n_model = batch_size - int(batch_size * real_ratio)

        # Pre-assembled batch wrappers expose sample() but no length; calling
        # len() on them would raise TypeError, so only size-check sized buffers.
        if hasattr(buffer, "__len__") and len(buffer) < n_model:
            return

        states, actions, next_states, rewards, dones = (
            t.to(self.device) for t in buffer.sample(batch_size)
        )

        # Critic target: minimum of the two critics at the actor's next action.
        with torch.no_grad():
            next_inputs = torch.cat([next_states, self.actor(next_states)], 1)
            target_q = torch.min(self.critic1(next_inputs), self.critic2(next_inputs))
            target_q = rewards.unsqueeze(1) + 0.99 * (1 - dones.unsqueeze(1)) * target_q

        critic_inputs = torch.cat([states, actions], 1)
        current_q1 = self.critic1(critic_inputs)
        current_q2 = self.critic2(critic_inputs)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # Actor ascends critic1's value of its own actions.
        actor_loss = -self.critic1(torch.cat([states, self.actor(states)], 1)).mean()
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

class _MixedBatch:
    """One pre-assembled batch exposing the ``sample``/``len`` interface of a buffer."""

    def __init__(self, states, actions, next_states, rewards, dones):
        self.batch = (states, actions, next_states, rewards, dones)

    def sample(self, size):
        # The batch is already mixed; the requested size is ignored.
        return self.batch

    def __len__(self):
        return len(self.batch[0])


class MBPO:
    """Model-based policy optimization loop: learn a model, roll it out, train the policy.

    Args:
        state_dim: Dimensionality of the state.
        action_dim: Dimensionality of the action.
        device: Torch device for the model and policy.
        rollout_length: Number of model steps per synthetic rollout.
        rollout_batch_size: Maximum number of start states per rollout round.
        real_ratio: Fraction of each policy batch drawn from real data.
        ensemble_size: Number of networks in the dynamics ensemble.
    """

    def __init__(self, state_dim, action_dim, device,
                 rollout_length=1, rollout_batch_size=50000,
                 real_ratio=0.05, ensemble_size=7):
        self.device = device
        self.model = EnsembleModel(state_dim, action_dim, num_networks=ensemble_size).to(device)
        self.model_opt = torch.optim.Adam(self.model.parameters(), lr=1e-3)

        self.policy = SAC(state_dim, action_dim, device)
        self.real_buffer = ModelBuffer()
        self.model_buffer = ModelBuffer()

        self.rollout_length = rollout_length
        self.rollout_batch_size = rollout_batch_size
        self.real_ratio = real_ratio
        self.ensemble_size = ensemble_size

    def train_model(self, batch_size=256, epochs=100):
        """Take ``epochs`` gradient steps on the dynamics model from real data.

        Does nothing while the real buffer holds fewer than ``batch_size``
        transitions.
        """
        if len(self.real_buffer) < batch_size:
            return
        for _ in range(epochs):
            states, actions, next_states, rewards, _ = self.real_buffer.sample(batch_size)
            loss = self.model.loss(
                states.to(self.device),
                actions.to(self.device),
                next_states.to(self.device),
                rewards.to(self.device),
            )
            self.model_opt.zero_grad()
            loss.backward()
            self.model_opt.step()

    def generate_rollouts(self):
        """Roll the learned model forward from real states into the model buffer."""
        # Never request more start states than the real buffer holds: sampling
        # without replacement would raise.
        n_init = min(self.rollout_batch_size, len(self.real_buffer))
        if n_init == 0:
            return
        init_states, *_ = self.real_buffer.sample(n_init)
        states = init_states.to(self.device)

        for _ in range(self.rollout_length):
            actions = torch.FloatTensor(
                self.policy.select_action(states.cpu().numpy(), noise=0.2)
            ).to(self.device)

            with torch.no_grad():
                mean, log_var = self.model(states, actions)
                std = torch.exp(0.5 * log_var)
                next_state_delta = mean[:, :-1] + std[:, :-1] * torch.randn_like(mean[:, :-1])
                rewards = mean[:, -1] + std[:, -1] * torch.randn_like(mean[:, -1])

            next_states = states + next_state_delta

            # Move the whole batch to host memory once instead of per row.
            host = (t.cpu().numpy() for t in (states, actions, next_states, rewards))
            for s, a, ns, r in zip(*host):
                # Synthetic transitions are stored as non-terminal.
                self.model_buffer.add(s, a, ns, r, 0.0)
            states = next_states

    def update_policy(self):
        """Train the policy on one batch mixing real and model-generated data."""
        total_size = min(len(self.real_buffer) + len(self.model_buffer), 256)
        n_real = int(total_size * self.real_ratio)
        n_model = total_size - n_real

        if len(self.real_buffer) < n_real or len(self.model_buffer) < n_model:
            return

        real_samples = self.real_buffer.sample(n_real)
        model_samples = self.model_buffer.sample(n_model)

        merged = [
            torch.cat([real_part, model_part], 0).to(self.device)
            for real_part, model_part in zip(real_samples, model_samples)
        ]
        mixed = _MixedBatch(*merged)
        # real_ratio=1.0: the batch is already mixed, so no further split is needed.
        self.policy.update(mixed, batch_size=len(mixed), real_ratio=1.0)

    def run(self, env, total_steps=300000):
        """Interact with ``env`` and interleave model training, rollouts and policy updates."""
        state, _ = env.reset()

        for step in range(total_steps):
            action = self.policy.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)

            # Store only true termination as the done flag: a time-limit
            # truncation must not stop value bootstrapping.
            self.real_buffer.add(state, action, next_state, reward, float(terminated))
            state = next_state

            if terminated or truncated:
                state, _ = env.reset()

            if step > 5000:
                # Refit the model and regenerate rollouts every 250 steps.
                if step % 250 == 0:
                    self.train_model(epochs=5)
                    self.generate_rollouts()
                self.update_policy()

def main():
    """Parse CLI options, build the environment and the MBPO agent, and run training."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--env", default="HalfCheetah-v4")
    cli.add_argument("--real_ratio", default=0.05, type=float)
    cli.add_argument("--rollout_length", default=1, type=int)
    args = cli.parse_args()

    env = gym.make(args.env)
    obs_size = env.observation_space.shape[0]
    act_size = env.action_space.shape[0]
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    trainer = MBPO(
        obs_size,
        act_size,
        device,
        rollout_length=args.rollout_length,
        real_ratio=args.real_ratio,
    )
    trainer.run(env)

# Run MBPO training only when this file is executed as a script.
if __name__ == "__main__":
    main()
2.2.2 潜在空间规划(Latent Space Planning)
  • 2.2.2.1 PlaNet/Dreamer架构:RSSM(Recurrent State-Space Model)的确定性路径与随机先验分解,KL散度损失加权(β=1.0)
循环状态空间模型通过将环境动态分解为确定性的循环路径与随机性的初始状态先验,在高维观测空间中学习紧凑的潜在表征。该架构利用循环神经网络捕获时间依赖性,同时通过变分推断优化潜在状态分布与观测重建之间的平衡。变分自由能的分解将模型学习、表示学习与策略优化统一在单一目标函数下,允许智能体在潜在空间中进行规划而无需原始高维图像输入。
Python
复制
"""
PlaNet/Dreamer 实现脚本
涉及内容:RSSM循环状态空间模型、变分推断、潜在空间规划、图像重建、Actor-Critic在 latent space
使用方式:python dreamer.py --env Walker2d-v4 --image_size 64 --latent_dim 30
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from typing import Dict, List, Tuple
import argparse

class RSSM(nn.Module):
    """Recurrent state-space model: a deterministic GRU path plus a Gaussian stochastic state.

    Args:
        stochastic_dim: Size of the stochastic state ``s``.
        deterministic_dim: Size of the deterministic GRU state ``h``.
        hidden_dim: Width of the GRU input projection; also the size of the
            observation embedding expected by ``observe_step``.
        action_dim: Size of the action. When omitted, the input projection is
            a lazy linear layer whose input size is inferred on the first
            call; run one forward pass before building an optimizer over it.
    """

    def __init__(self, stochastic_dim=30, deterministic_dim=200, hidden_dim=200,
                 action_dim=None):
        super().__init__()
        self.stochastic_dim = stochastic_dim
        self.deterministic_dim = deterministic_dim

        # Project [s_{t-1}, a_{t-1}] to hidden_dim so the GRU input size does
        # not depend on the action dimensionality.
        if action_dim is None:
            self.fc_input = nn.LazyLinear(hidden_dim)
        else:
            self.fc_input = nn.Linear(stochastic_dim + action_dim, hidden_dim)

        # Recurrent cell carrying the deterministic state.
        self.gru = nn.GRUCell(hidden_dim, deterministic_dim)

        # Prior p(s_t | h_t)
        self.fc_prior = nn.Linear(deterministic_dim, stochastic_dim * 2)

        # Posterior q(s_t | h_t, x_t)
        self.fc_posterior = nn.Linear(deterministic_dim + hidden_dim, stochastic_dim * 2)

    def init_state(self, batch_size, device):
        """Return all-zero ``(h, s)`` states for a batch."""
        h = torch.zeros(batch_size, self.deterministic_dim, device=device)
        s = torch.zeros(batch_size, self.stochastic_dim, device=device)
        return h, s

    def _advance(self, h_prev, s_prev, action):
        """Advance the deterministic state using the previous stochastic state and action."""
        x = F.relu(self.fc_input(torch.cat([s_prev, action], dim=-1)))
        return self.gru(x, h_prev)

    def prior(self, h):
        """Return ``(mean, std)`` of the prior over ``s`` given ``h``."""
        mean, std_logit = torch.chunk(self.fc_prior(h), 2, dim=-1)
        # softplus + 0.1 keeps the standard deviation strictly positive.
        return mean, F.softplus(std_logit) + 0.1

    def imagine_step(self, h_prev, action, s_prev=None):
        """Advance one step without an observation, sampling ``s`` from the prior.

        When ``s_prev`` is not supplied, the prior mean at ``h_prev`` stands in
        for it. Returns ``(h, s, prior_mean, prior_std)``.
        """
        if s_prev is None:
            s_prev, _ = self.prior(h_prev)
        h = self._advance(h_prev, s_prev, action)
        prior_mean, prior_std = self.prior(h)
        s = prior_mean + prior_std * torch.randn_like(prior_mean)
        return h, s, prior_mean, prior_std

    def observe_step(self, h_prev, s_prev, action, embedding):
        """Advance one step and sample ``s`` from the posterior given ``embedding``.

        Returns ``(h, s, post_mean, post_std)``.
        """
        h = self._advance(h_prev, s_prev, action)
        posterior_input = torch.cat([h, embedding], dim=-1)
        post_mean, post_std_logit = torch.chunk(self.fc_posterior(posterior_input), 2, dim=-1)
        post_std = F.softplus(post_std_logit) + 0.1
        s = post_mean + post_std * torch.randn_like(post_mean)
        return h, s, post_mean, post_std


class ObservationEncoder(nn.Module):
    """Convolutional encoder mapping an image to a ``latent_dim`` embedding."""

    def __init__(self, image_channels=3, depth=32, latent_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(image_channels, 1*depth, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(1*depth, 2*depth, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(2*depth, 4*depth, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(4*depth, 8*depth, 4, stride=2),
            nn.ReLU(),
            nn.Flatten()
        )
        # Four stride-2, kernel-4, unpadded convs: 64 -> 31 -> 14 -> 6 -> 2.
        self.out_dim = 8*depth * 2 * 2  # Assuming 64x64 input -> 2x2
        self.fc = nn.Linear(self.out_dim, latent_dim)

    def forward(self, x):
        x = self.net(x)
        return self.fc(x)


class ObservationDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent state to a 64x64 image."""

    def __init__(self, state_dim, depth=32, image_channels=3):
        super().__init__()
        self.depth = depth
        # Start from a 1x1 feature map so the four transposed convolutions
        # below end at exactly 64x64, matching the encoder's input size.
        self.fc = nn.Linear(state_dim, 32 * depth)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(32*depth, 4*depth, 5, stride=2),         # 1 -> 5
            nn.ReLU(),
            nn.ConvTranspose2d(4*depth, 2*depth, 5, stride=2),          # 5 -> 13
            nn.ReLU(),
            nn.ConvTranspose2d(2*depth, 1*depth, 6, stride=2),          # 13 -> 30
            nn.ReLU(),
            nn.ConvTranspose2d(1*depth, image_channels, 6, stride=2),   # 30 -> 64
        )

    def forward(self, state):
        x = self.fc(state).view(-1, 32 * self.depth, 1, 1)
        return self.net(x)


class Actor(nn.Module):
    """MLP policy producing tanh-squashed actions from a latent state."""

    def __init__(self, state_dim, action_dim, hidden_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()
        )

    def forward(self, state):
        return self.net(state)


class ValueModel(nn.Module):
    """MLP predicting a scalar value from a latent state."""

    def __init__(self, state_dim, hidden_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        return self.net(state)


class Dreamer:
    """World model (encoder, RSSM, decoder, reward head) plus latent actor-critic.

    Args:
        image_channels: Channels of the image observations.
        action_dim: Dimensionality of the action.
        device: Torch device for every sub-module.
        stochastic_dim: Size of the RSSM stochastic state.
        deterministic_dim: Size of the RSSM deterministic state.
        kl_scale: Weight of the KL term in the world-model loss.
    """

    def __init__(self, image_channels, action_dim, device,
                 stochastic_dim=30, deterministic_dim=200, kl_scale=0.1):
        self.device = device
        self.action_dim = action_dim
        self.kl_scale = kl_scale

        latent_state_dim = stochastic_dim + deterministic_dim

        self.encoder = ObservationEncoder(image_channels, latent_dim=200).to(device)
        self.rssm = RSSM(stochastic_dim, deterministic_dim, hidden_dim=200,
                         action_dim=action_dim).to(device)
        # The decoder must emit as many channels as the observations have.
        self.decoder = ObservationDecoder(latent_state_dim,
                                          image_channels=image_channels).to(device)
        self.reward_model = nn.Sequential(
            nn.Linear(latent_state_dim, 200),
            nn.ReLU(),
            nn.Linear(200, 1)
        ).to(device)

        self.actor = Actor(latent_state_dim, action_dim).to(device)
        self.value_model = ValueModel(latent_state_dim).to(device)

        self.models = [self.encoder, self.rssm, self.decoder, self.reward_model]
        self.behavior = [self.actor, self.value_model]

        self.model_opt = optim.Adam(
            list(self.encoder.parameters()) +
            list(self.rssm.parameters()) +
            list(self.decoder.parameters()) +
            list(self.reward_model.parameters()),
            lr=1e-3
        )
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=8e-5)
        self.value_opt = optim.Adam(self.value_model.parameters(), lr=8e-5)

    def preprocess(self, obs):
        """Scale 8-bit image values to [-0.5, 0.5]."""
        return obs / 255.0 - 0.5

    def update_world_model(self, observations, actions, rewards):
        """Take one gradient step on the world model and return the latent states.

        Args:
            observations: ``[batch, seq, channels, height, width]`` images.
            actions: ``[batch, seq, action_dim]`` actions.
            rewards: ``[batch, seq]`` rewards.

        Returns:
            Detached posterior states ``[batch, seq, deterministic + stochastic]``
            laid out as ``[h, s]`` along the last axis.
        """
        batch_size, seq_len = observations.shape[:2]

        # Encode every frame in one pass, then restore the sequence axis.
        flat_obs = observations.reshape(-1, *observations.shape[2:])
        embeddings = self.encoder(flat_obs).view(batch_size, seq_len, -1)

        h, s = self.rssm.init_state(batch_size, self.device)

        kl_loss = 0
        recon_loss = 0
        reward_loss = 0
        states = []

        for t in range(seq_len):
            h, s, post_mean, post_std = self.rssm.observe_step(
                h, s, actions[:, t], embeddings[:, t]
            )

            # KL between the posterior and the prior at the same h.
            prior_mean, prior_std = self.rssm.prior(h)
            kl = torch.distributions.kl_divergence(
                torch.distributions.Normal(post_mean, post_std),
                torch.distributions.Normal(prior_mean, prior_std)
            ).sum(dim=-1)
            kl_loss += kl.mean()

            state = torch.cat([h, s], dim=-1)
            recon_loss += F.mse_loss(self.decoder(state), observations[:, t])
            reward_loss += F.mse_loss(self.reward_model(state), rewards[:, t].unsqueeze(-1))

            states.append(state)

        total_loss = recon_loss + reward_loss + self.kl_scale * kl_loss

        self.model_opt.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model_opt.param_groups[0]['params'], 100.0)
        self.model_opt.step()

        # The graph was consumed by backward(); hand out detached states so
        # later losses cannot try to backpropagate through it again.
        return torch.stack(states, dim=1).detach()

    def imagine_trajectories(self, initial_states, horizon=15):
        """Roll the actor forward in latent space for ``horizon`` steps.

        Returns ``(states, rewards, values)`` with shapes
        ``[batch, horizon + 1, D]``, ``[batch, horizon, 1]`` and
        ``[batch, horizon, 1]``; rewards and values at index ``t`` are
        evaluated on the state reached after step ``t``.
        """
        states = [initial_states]
        rewards = []
        values = []

        h = initial_states[:, :self.rssm.deterministic_dim]
        s = initial_states[:, self.rssm.deterministic_dim:]

        for _ in range(horizon):
            action = self.actor(torch.cat([h, s], dim=-1))

            h, s, _, _ = self.rssm.imagine_step(h, action, s_prev=s)
            next_state = torch.cat([h, s], dim=-1)

            states.append(next_state)
            rewards.append(self.reward_model(next_state))
            values.append(self.value_model(next_state))

        return torch.stack(states, dim=1), torch.stack(rewards, dim=1), torch.stack(values, dim=1)

    def update_behavior(self, initial_states):
        """Update the actor and the value model from imagined trajectories."""
        # Start imagination from constants: the world-model graph that
        # produced these states has already been freed.
        initial_states = initial_states.detach()
        states, rewards, values = self.imagine_trajectories(initial_states)

        lambda_ = 0.95
        discount = 0.99
        horizon = rewards.shape[1]

        # Lambda-returns, computed backwards. values[:, t] is the value of the
        # state reached after step t, i.e. the successor of the state whose
        # return is being computed, so it is the correct bootstrap for step t.
        returns = [None] * horizon
        next_return = None
        for t in reversed(range(horizon)):
            if t == horizon - 1:
                bootstrap = values[:, t]
            else:
                bootstrap = lambda_ * next_return + (1 - lambda_) * values[:, t]
            next_return = rewards[:, t] + discount * bootstrap
            returns[t] = next_return
        returns = torch.stack(returns, dim=1)

        # Actor loss: maximize the expected imagined return.
        actor_loss = -returns.mean()

        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # Value loss on detached states and targets: its graph only touches
        # the value model, so it is unaffected by the backward pass above.
        value_pred = self.value_model(states[:, :-1].detach())
        value_loss = F.mse_loss(value_pred, returns.detach())

        self.value_opt.zero_grad()
        value_loss.backward()
        self.value_opt.step()

def main():
    """Smoke-test the Dreamer update loop on placeholder data.

    Random transitions are collected from the environment into ``buffer``,
    but the training loop below feeds randomly generated image, action and
    reward sequences to the agent; the collected transitions are not
    consumed. Replace the placeholders with real image sequences for
    actual training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="Walker2d-v4")
    parser.add_argument("--image_size", default=64, type=int)
    parser.add_argument("--latent_dim", default=30, type=int)
    parser.add_argument("--batch_size", default=50, type=int)
    parser.add_argument("--seq_len", default=50, type=int)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Note: For image-based control, use DMC or custom wrapper
    # This example uses state-based as placeholder for image logic
    env = gym.make(args.env)
    action_dim = env.action_space.shape[0]

    # NOTE(review): the leading 3 is presumably the image channel count;
    # confirm against the Dreamer constructor.
    agent = Dreamer(3, action_dim, device, stochastic_dim=args.latent_dim)

    # Simulated experience buffer (in practice, collect from environment)
    buffer = []
    state, _ = env.reset()
    for _ in range(1000):
        action = env.action_space.sample()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.append((state, action, next_state, reward, done))
        state = next_state if not done else env.reset()[0]
    env.close()

    # Training loop
    for _ in range(1000):
        # Placeholder batch; --image_size now controls the frame size
        # instead of a hard-coded 64.
        obs_seq = torch.randn(
            args.batch_size, args.seq_len, 3, args.image_size, args.image_size
        )
        act_seq = torch.randn(args.batch_size, args.seq_len, action_dim)
        rew_seq = torch.randn(args.batch_size, args.seq_len)

        states = agent.update_world_model(
            obs_seq.to(device), act_seq.to(device), rew_seq.to(device)
        )
        agent.update_behavior(states[:, 0])

if __name__ == "__main__":
    main()
  • 2.2.2.2 交叉熵方法(CEM)在图像输入中的动作序列优化:通过迭代采样(精英比例top-k=10%)规划未来H步动作序列
交叉熵方法在潜在空间规划中通过迭代优化动作序列分布,将最优控制问题转化为概率推断任务。该方法从当前动作分布中采样多条候选动作序列,评估其在习得的环境模型中的累积回报,随后根据精英样本的统计特性(均值与标准差)更新分布参数。在图像输入环境下,视觉编码器将高维观测压缩至紧凑潜在表示,规划器在该低维空间中对采样得到的动作序列进行有限步长的前瞻展开(基于采样的打靶法,而非树搜索),并以模型预测控制的方式仅执行优化后序列的第一个动作。
Python
复制
"""
CEM (Cross Entropy Method) 在图像输入中的动作序列优化
涉及内容:图像编码器、动作序列优化、迭代分布更新、精英采样、模型预测控制
使用方式:python cem_image.py --env cartpole_swingup --horizon 20 --n_samples 1000
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import argparse

class ImageEncoder(nn.Module):
    """Convolutional encoder that maps an image batch to a latent vector.

    Three 3x3 stride-2 convolutions (3 -> 32 -> 64 -> 128 channels), each
    followed by ReLU, then a flatten and a linear projection to
    ``latent_dim``. The projection expects 128 * 7 * 7 features, which is
    what a 64x64 input produces.
    """

    def __init__(self, latent_dim=50):
        super().__init__()
        widths = (3, 32, 64, 128)
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.extend((nn.Conv2d(in_ch, out_ch, 3, stride=2), nn.ReLU()))
        stages.append(nn.Flatten())
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(128 * 7 * 7, latent_dim)  # Assuming 64x64 input

    def forward(self, x):
        """Return the latent projection of the flattened conv features."""
        return self.fc(self.conv(x))

class SimpleDynamics(nn.Module):
    """MLP that predicts the next latent and a scalar reward.

    The network consumes ``[latent, action]`` concatenated on the last
    axis and emits ``latent_dim + 1`` values: the next latent followed by
    the reward.
    """

    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim + 1)  # next_latent + reward
        )

    def forward(self, latent, action):
        """Return ``(next_latent, reward)`` for matching leading dims.

        Args:
            latent: tensor of shape ``[..., latent_dim]``.
            action: tensor of shape ``[..., action_dim]``.

        Returns:
            ``next_latent`` of shape ``[..., latent_dim]`` and ``reward``
            of shape ``[...]``.
        """
        out = self.net(torch.cat([latent, action], dim=-1))
        # Ellipsis indexing always splits the feature axis, so inputs with
        # extra leading dims (e.g. [batch, time, dim]) are handled too.
        next_latent = out[..., :-1]
        reward = out[..., -1]
        return next_latent, reward

class CEMPlanner:
    """Cross-entropy-method planner over action sequences in latent space.

    Each iteration samples ``n_samples`` action sequences from a diagonal
    Gaussian, scores them by rolling the dynamics model forward for
    ``horizon`` steps, and refits the Gaussian to the top ``elite_ratio``
    fraction. Actions are clamped to ``[-1, 1]``.
    """

    def __init__(self, encoder, dynamics, action_dim, horizon=20,
                 n_samples=1000, elite_ratio=0.1, iterations=5):
        self.encoder = encoder
        self.dynamics = dynamics
        self.action_dim = action_dim
        self.horizon = horizon
        self.n_samples = n_samples
        self.elite_ratio = elite_ratio
        self.iterations = iterations

    def plan(self, image_obs, goal_latent=None):
        """Return the first action of the optimized sequence.

        Args:
            image_obs: a single image, either ``[C, H, W]`` or already
                batched as ``[1, C, H, W]``.
            goal_latent: optional latent target; when given, the reward is
                the negative distance to it instead of the model reward.

        Returns:
            numpy array of shape ``[action_dim]``.
        """
        # Only add the batch axis when the caller has not done so already.
        if image_obs.dim() == 3:
            image_obs = image_obs.unsqueeze(0)

        # Encode current observation
        with torch.no_grad():
            current_latent = self.encoder(image_obs)

        # Keep the sampling distribution on the model's device and dtype so
        # the dynamics model never sees mixed-device inputs.
        device, dtype = current_latent.device, current_latent.dtype
        mean = torch.zeros(self.horizon, self.action_dim, device=device, dtype=dtype)
        std = torch.ones_like(mean)

        # At least two elites are needed for a finite sample std.
        n_elites = min(self.n_samples, max(2, int(self.n_samples * self.elite_ratio)))

        for _ in range(self.iterations):
            # Sample action sequences
            noise = torch.randn(
                self.n_samples, self.horizon, self.action_dim, device=device, dtype=dtype
            )
            actions = torch.clamp(mean + std * noise, -1, 1)

            # Evaluate sequences (the latent is broadcast inside)
            returns = self.evaluate_sequences(current_latent, actions, goal_latent)

            # Select elites
            elite_actions = actions[torch.topk(returns, n_elites).indices]

            # Update distribution
            mean = elite_actions.mean(dim=0)
            if n_elites > 1:
                std = elite_actions.std(dim=0) + 0.01

        return mean[0].detach().cpu().numpy()  # Return first action

    def evaluate_sequences(self, initial_latents, action_sequences, goal_latent=None):
        """Return the discounted (0.99) return of each action sequence.

        Args:
            initial_latents: ``[1, latent_dim]`` (broadcast to every
                sequence) or ``[n, latent_dim]`` (one start per sequence).
            action_sequences: ``[n, horizon, action_dim]``.
            goal_latent: optional latent target, see ``plan``.

        Returns:
            tensor of shape ``[n]``.

        Raises:
            ValueError: if the latent batch is neither 1 nor ``n``.
        """
        batch_size = action_sequences.shape[0]
        if initial_latents.shape[0] == batch_size:
            latents = initial_latents
        elif initial_latents.shape[0] == 1:
            latents = initial_latents.expand(batch_size, -1)
        else:
            raise ValueError(
                f"initial_latents has batch {initial_latents.shape[0]}, "
                f"expected 1 or {batch_size}"
            )
        total_rewards = torch.zeros(batch_size, device=latents.device, dtype=latents.dtype)

        for t in range(self.horizon):
            with torch.no_grad():
                next_latents, rewards = self.dynamics(latents, action_sequences[:, t])

            if goal_latent is not None:
                # Distance to goal
                rewards = -torch.norm(next_latents - goal_latent, dim=1)

            total_rewards += (0.99 ** t) * rewards
            latents = next_latents

        return total_rewards

class CEMAgent:
    """Image encoder + latent dynamics model driven by a CEM planner."""

    def __init__(self, state_shape, action_dim, device):
        self.device = device
        self.latent_dim = 50

        self.encoder = ImageEncoder(self.latent_dim).to(device)
        self.dynamics = SimpleDynamics(self.latent_dim, action_dim).to(device)
        self.planner = CEMPlanner(self.encoder, self.dynamics, action_dim)

        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.dynamics.parameters()),
            lr=1e-3
        )

    @staticmethod
    def _sample_batch(buffer, batch_size):
        """Draw a batch from a sampler object or a dict of parallel lists.

        Objects exposing ``sample(batch_size)`` are used as-is. Otherwise
        ``buffer`` is treated as a dict with the keys ``images``,
        ``actions``, ``next_images`` and ``rewards``, each a list of equal
        length, and indices are drawn uniformly with replacement.
        """
        if hasattr(buffer, "sample"):
            return buffer.sample(batch_size)
        size = len(buffer['images'])
        if size == 0:
            raise ValueError("cannot train on an empty buffer")
        picks = np.random.randint(0, size, size=batch_size)
        return {
            key: np.stack([np.asarray(buffer[key][i], dtype=np.float32) for i in picks])
            for key in ('images', 'actions', 'next_images', 'rewards')
        }

    def _to_tensor(self, array):
        """Convert array-like data to a float32 tensor on the agent device."""
        return torch.as_tensor(np.asarray(array), dtype=torch.float32).to(self.device)

    def train(self, buffer, epochs=100, batch_size=32):
        """Fit encoder and dynamics on one-step latent and reward prediction.

        Args:
            buffer: sampler object or dict of lists, see ``_sample_batch``.
            epochs: number of gradient steps.
            batch_size: transitions per gradient step.

        Raises:
            ValueError: if a dict buffer is empty.
        """
        for _ in range(epochs):
            batch = self._sample_batch(buffer, batch_size)
            images = self._to_tensor(batch['images'])
            actions = self._to_tensor(batch['actions'])
            next_images = self._to_tensor(batch['next_images'])
            rewards = self._to_tensor(batch['rewards'])

            # Encode
            latents = self.encoder(images)
            next_latents_pred, pred_rewards = self.dynamics(latents, actions)

            # The target latent is computed without gradient (simplified
            # stand-in for a reconstruction loss on the next image).
            with torch.no_grad():
                next_latents_target = self.encoder(next_images)

            # Align reward targets with the prediction shape so a trailing
            # singleton axis cannot trigger silent broadcasting.
            loss = F.mse_loss(next_latents_pred, next_latents_target) + \
                   F.mse_loss(pred_rewards, rewards.reshape(pred_rewards.shape))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def act(self, image):
        """Plan with CEM from a single ``[C, H, W]`` image.

        The image is passed unbatched because the planner adds the batch
        axis itself.
        """
        image_tensor = self._to_tensor(image)
        return self.planner.plan(image_tensor)

def main():
    """Collect random transitions, fit the latent model, then act with CEM.

    The environment must already provide 3-D image observations and a
    continuous (Box) action space; anything else is rejected up front with
    a usage error instead of failing later inside the model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1")
    parser.add_argument("--horizon", default=20, type=int)
    parser.add_argument("--n_samples", default=1000, type=int)
    parser.add_argument("--elite_ratio", default=0.1, type=float)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = gym.make(args.env)

    if not isinstance(env.action_space, gym.spaces.Box):
        parser.error(
            f"--env {args.env!r} has a non-continuous action space; "
            "CEM planning here needs a Box action space"
        )
    obs_shape = env.observation_space.shape
    if obs_shape is None or len(obs_shape) != 3:
        parser.error(
            f"--env {args.env!r} does not produce image observations "
            f"(observation shape {obs_shape}); wrap it with a pixel wrapper"
        )

    # Assume image observation space
    action_dim = env.action_space.shape[0]
    agent = CEMAgent((3, 64, 64), action_dim, device)

    # Forward the CLI planning settings, which were previously parsed but
    # never applied.
    agent.planner.horizon = args.horizon
    agent.planner.n_samples = args.n_samples
    agent.planner.elite_ratio = args.elite_ratio

    # Training phase with random exploration
    buffer = {'images': [], 'actions': [], 'next_images': [], 'rewards': []}
    obs, _ = env.reset()

    for _ in range(5000):
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Store (simplified, assume obs is already image)
        buffer['images'].append(obs)
        buffer['actions'].append(action)
        buffer['next_images'].append(next_obs)
        buffer['rewards'].append(reward)

        obs = next_obs if not done else env.reset()[0]

    # Train model
    agent.train(buffer)

    # Deploy with CEM planning
    # NOTE(review): the planner emits actions in [-1, 1]; rescale if the
    # environment's action bounds differ.
    obs, _ = env.reset()
    for _ in range(1000):
        action = agent.act(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            obs, _ = env.reset()
        else:
            obs = next_obs

    env.close()

if __name__ == "__main__":
    main()

2.3 分布式训练系统架构

2.3.1 IMPALA与SEED的工业级实现
  • 2.3.1.1 演员-学习者(Actor-Learner)分离架构:RingBuffer实现异步数据传输,参数服务器(Parameter Server)的Ray分布式实现
重要性加权演员-学习者架构通过解耦环境交互与策略优化过程,利用分布式采样节点并行生成经验序列,经由环形缓冲区实现异步数据传输。参数服务器维护全局网络权重的集中副本,采样节点(Actor)无需加锁即可周期性拉取参数,因此其行为策略可能略微落后于学习节点的最新策略;学习节点持续消费缓冲区中的轨迹并异步提交梯度,由此产生的策略滞后需要通过离策略校正(见下一小节的V-trace)加以补偿。该架构在异构计算集群上实现了采样与训练的计算重叠,从而提高了硬件吞吐量并缓解了中央处理单元与图形处理器之间的通信瓶颈。
Python
复制
"""
IMPALA/SEED 分布式架构实现脚本
涉及内容:Actor-Learner分离、RingBuffer异步传输、Ray分布式参数服务器、异步梯度更新
使用方式:python impala_seed.py --num_actors 10 --learner_gpus 1
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import ray
from collections import deque
import threading
import queue
import argparse

# Start or connect to a Ray runtime as a module-level side effect at import.
# NOTE(review): ignore_reinit_error presumably makes a repeated init a no-op
# instead of an error -- confirm against the Ray API reference.
ray.init(ignore_reinit_error=True)

@ray.remote
class ParameterServer:
    """Ray actor holding the central policy/value weights.

    Learners push gradients through ``update_parameters``; actors and
    learners pull the current weights through ``get_parameters``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, lr=0.001):
        """Build the networks.

        Args:
            state_dim: observation vector size.
            action_dim: number of discrete actions (softmax outputs).
            hidden_dim: width of the two hidden layers.
            lr: step size of the plain SGD update applied to pushed
                gradients (previously hard-coded to 0.001).
        """
        self.policy = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim), nn.Softmax(dim=-1)
        )
        self.value = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.lr = lr
        # state_dict() tensors share storage with the live parameters, so
        # the in-place updates below are visible through this dict.
        self.parameters = {
            'policy': self.policy.state_dict(),
            'value': self.value.state_dict()
        }
        self.version = 0

    def get_parameters(self):
        """Return ``(state_dicts, version)`` for the current weights."""
        return self.parameters, self.version

    def update_parameters(self, gradients, version):
        """Apply one SGD step and return the new version number.

        Args:
            gradients: ``{'policy': {name: grad}, 'value': {name: grad}}``;
                parameters without a gradient entry are left untouched.
            version: caller's version tag; accepted for interface
                compatibility and not used.
        """
        with torch.no_grad():
            for key, module in (('policy', self.policy), ('value', self.value)):
                module_grads = gradients.get(key, {})
                for name, param in module.named_parameters():
                    grad = module_grads.get(name)
                    if grad is None:
                        continue
                    # Gradients may arrive from a learner on another
                    # device; move them next to the parameter first.
                    grad = torch.as_tensor(grad).to(device=param.device, dtype=param.dtype)
                    param.add_(grad, alpha=-self.lr)
        self.version += 1
        return self.version

class RingBuffer:
    """Bounded blocking FIFO over a fixed-size list.

    One slot is always left empty to tell "full" from "empty", so at most
    ``capacity - 1`` items are stored at a time. ``write`` blocks while
    the buffer is full and ``read`` blocks while it is empty; both
    condition variables share a single lock.
    """

    def __init__(self, capacity=1000):
        """Create the buffer.

        Raises:
            ValueError: if ``capacity`` is below 2. With a single slot the
                buffer would always look full and ``write`` would block
                forever.
        """
        if capacity < 2:
            raise ValueError("capacity must be at least 2")
        self.capacity = capacity
        self.buffer = [None] * capacity
        self.write_idx = 0
        self.read_idx = 0
        self.lock = threading.Lock()
        self.not_empty = threading.Condition(self.lock)
        self.not_full = threading.Condition(self.lock)

    def write(self, data):
        """Append ``data``, blocking while the buffer is full."""
        with self.not_full:
            while (self.write_idx + 1) % self.capacity == self.read_idx:
                self.not_full.wait()
            self.buffer[self.write_idx] = data
            self.write_idx = (self.write_idx + 1) % self.capacity
            self.not_empty.notify()

    def read(self):
        """Remove and return the oldest item, blocking while empty."""
        with self.not_empty:
            while self.write_idx == self.read_idx:
                self.not_empty.wait()
            data = self.buffer[self.read_idx]
            # Drop the reference so consumed trajectories can be freed
            # before the slot is overwritten.
            self.buffer[self.read_idx] = None
            self.read_idx = (self.read_idx + 1) % self.capacity
            self.not_full.notify()
            return data

    def read_batch(self, batch_size):
        """Return a list of ``batch_size`` items, blocking as needed."""
        return [self.read() for _ in range(batch_size)]

@ray.remote
class Actor:
    """Ray actor that samples episodes and ships trajectory chunks."""

    def __init__(self, env_name, ps_handle, actor_id):
        self.env = gym.make(env_name)
        self.ps = ps_handle
        self.actor_id = actor_id
        self.device = torch.device("cpu")
        self.policy = None
        self.version = 0

    @staticmethod
    def _build_policy(policy_state):
        """Build an MLP whose layer sizes match ``policy_state``.

        The sizes come from the weight matrices in the state dict, so the
        local network always mirrors the parameter server's architecture
        (linear layers at even indices, ReLU between them, softmax last).
        """
        weights = [w for name, w in policy_state.items() if name.endswith("weight")]
        layers = []
        for i, weight in enumerate(weights):
            layers.append(nn.Linear(weight.shape[1], weight.shape[0]))
            is_last = i == len(weights) - 1
            layers.append(nn.Softmax(dim=-1) if is_last else nn.ReLU())
        return nn.Sequential(*layers)

    def run(self, buffer_handle, num_steps=100000):
        """Run ``num_steps`` episodes, writing chunks of up to 20 steps.

        Args:
            buffer_handle: Ray actor handle of the shared ring buffer.
            num_steps: number of episodes to play.

        Returns:
            this actor's id.
        """
        # An actor handle is used directly; ray.get() only resolves
        # object references.
        buffer = buffer_handle

        for step in range(num_steps):
            # Fetch latest parameters every N steps
            if step % 400 == 0:
                params, version = ray.get(self.ps.get_parameters.remote())
                if self.policy is None:
                    self.policy = self._build_policy(params['policy'])
                self.policy.load_state_dict(params['policy'])
                self.version = version

            state, _ = self.env.reset()
            done = False
            episode_data = []

            while not done:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0)
                    action_probs = self.policy(state_tensor)
                    action = torch.multinomial(action_probs, 1).item()

                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                episode_data.append((state, action, reward, next_state, done))
                state = next_state

                # Send full chunks, and also flush the shorter tail at
                # episode end so no transition is lost.
                if len(episode_data) >= 20 or done:
                    ray.get(buffer.write.remote(episode_data))
                    episode_data = []

        return self.actor_id

class Learner:
    """Computes actor-critic gradients and pushes them to the server."""

    def __init__(self, state_dim, action_dim, ps_handle, device):
        self.device = device
        self.ps = ps_handle

        self.policy = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim), nn.Softmax(dim=-1)
        ).to(device)

        self.value = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        ).to(device)

        # Used here only to clear gradients; the weight update itself is
        # applied by the parameter server.
        self.optimizer = torch.optim.Adam(
            list(self.policy.parameters()) + list(self.value.parameters()),
            lr=1e-3
        )

        self.gamma = 0.99
        self.clip_rho = 1.0  # stored but not read by this class

    def _as_tensor(self, values, dtype):
        """Stack a sequence of per-transition values into a device tensor."""
        return torch.as_tensor(np.asarray(values, dtype=dtype)).to(self.device)

    def compute_gradients(self, batch):
        """Return one-step actor-critic gradients for a batch.

        Args:
            batch: list of trajectories, each a list of
                ``(state, action, reward, next_state, done)`` tuples.

        Returns:
            ``{'policy': {name: grad}, 'value': {name: grad}}`` with
            detached CPU tensors.

        Raises:
            ValueError: if the batch contains no transitions.
        """
        transitions = [step for traj in batch for step in traj]
        if not transitions:
            raise ValueError("batch contains no transitions")
        states, actions, rewards, next_states, dones = zip(*transitions)

        states = self._as_tensor(states, np.float32)
        actions = self._as_tensor(actions, np.int64)
        rewards = self._as_tensor(rewards, np.float32)
        next_states = self._as_tensor(next_states, np.float32)
        dones = self._as_tensor(dones, np.float32)

        # Compute action probabilities
        action_probs = self.policy(states)
        log_probs = torch.log(action_probs.gather(1, actions.unsqueeze(1)) + 1e-10).squeeze(-1)

        # squeeze(-1) keeps the batch axis even for a single transition.
        value_pred = self.value(states).squeeze(-1)
        with torch.no_grad():
            next_values = self.value(next_states).squeeze(-1)
            td_targets = rewards + self.gamma * next_values * (1 - dones)
            advantages = td_targets - value_pred.detach()

        # Policy gradient
        policy_loss = -(log_probs * advantages).mean()

        # Value loss
        value_loss = F.mse_loss(value_pred, td_targets)

        loss = policy_loss + 0.5 * value_loss
        self.optimizer.zero_grad()
        loss.backward()

        # Ship independent CPU copies: the parameter server keeps its
        # weights on CPU and the local .grad buffers are reused next step.
        return {
            'policy': {name: param.grad.detach().clone().cpu()
                       for name, param in self.policy.named_parameters()},
            'value': {name: param.grad.detach().clone().cpu()
                      for name, param in self.value.named_parameters()}
        }

    def train(self, buffer_handle, num_updates=10000):
        """Run ``num_updates`` read -> gradient -> push -> pull cycles.

        Args:
            buffer_handle: Ray actor handle of the shared ring buffer.
            num_updates: number of gradient pushes.
        """
        # An actor handle is used directly; ray.get() only resolves
        # object references.
        buffer = buffer_handle

        for i in range(num_updates):
            batch = ray.get(buffer.read_batch.remote(32))
            gradients = self.compute_gradients(batch)
            new_version = ray.get(self.ps.update_parameters.remote(gradients, i))

            # Sync local parameters
            params, _ = ray.get(self.ps.get_parameters.remote())
            self.policy.load_state_dict(params['policy'])
            self.value.load_state_dict(params['value'])

            if i % 100 == 0:
                print(f"Learner update {i}, version {new_version}")

def main():
    """Launch the parameter server, ring buffer, actors and the learner."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_actors", default=4, type=int)
    parser.add_argument("--env", default="CartPole-v1")
    parser.add_argument("--learner_gpus", default=0, type=int)
    args = parser.parse_args()

    # The local env is only needed to read the space sizes.
    env = gym.make(args.env)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    env.close()

    ps_handle = ParameterServer.remote(state_dim, action_dim)

    # RingBuffer is a plain class, so wrap it as a Ray actor here. Its
    # read/write calls block on condition variables, so the actor must run
    # calls concurrently; a single-threaded actor would deadlock as soon
    # as one call waits.
    buffer_cls = RingBuffer if hasattr(RingBuffer, "remote") else ray.remote(RingBuffer)
    buffer_handle = buffer_cls.options(
        max_concurrency=args.num_actors + 2
    ).remote(capacity=10000)

    device = torch.device("cuda" if args.learner_gpus > 0 and torch.cuda.is_available() else "cpu")
    learner = Learner(state_dim, action_dim, ps_handle, device)

    actors = [Actor.remote(args.env, ps_handle, i) for i in range(args.num_actors)]

    # Start actors; the returned futures are deliberately not awaited.
    for actor in actors:
        actor.run.remote(buffer_handle)

    # Start learner
    learner.train(buffer_handle)

    # Once the learner stops reading, actors would block forever on a
    # full buffer, so stop them explicitly instead of waiting on them.
    for actor in actors:
        ray.kill(actor)
    ray.shutdown()

if __name__ == "__main__":
    main()
  • 2.3.1.2 V-trace偏置校正:重要性采样比率(ρ)截断(c=1.0)与Actor-Critic共享卷积编码器的梯度分离技术

V-trace偏差校正算法通过截断重要性采样比率,在分布式演员-学习者架构中纠正由于策略滞后引起的价值估计偏差。该方法计算目标策略与行为策略在所选动作上的概率之比作为逐步重要性权重,并对其施加上限截断:截断以引入有界偏差为代价换取有界方差,其中ρ的截断阈值决定价值函数收敛到的不动点(介于行为策略与目标策略之间),c的截断阈值只影响收敛速度。在共享编码器的场景下,本节代码通过停止梯度算子(detach)阻断价值头的损失向共享编码器回传,使表征仅由策略损失驱动,从而避免价值回归的梯度干扰策略特征的学习。

Python

复制

"""
V-trace实现脚本
涉及内容:重要性采样比率截断、V-trace目标计算、Retrace-style校正、策略与价值梯度分离
使用方式:python vtrace.py --batch_size 64 --trajectory_length 20
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse

class PolicyNetwork(nn.Module):
    """Actor-critic network with a shared two-layer ReLU encoder.

    The policy head reads the encoder features directly. The value head
    reads a detached copy, so value-loss gradients never reach the
    encoder.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        trunk = [
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        ]
        self.shared_encoder = nn.Sequential(*trunk)
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return ``(action_probs, value)`` for a batch of states."""
        features = self.shared_encoder(state)
        action_probs = torch.softmax(self.policy_head(features), dim=-1)
        # Gradient isolation: the critic sees the features as constants.
        state_value = self.value_head(features.detach())
        return action_probs, state_value

    def get_features(self, state):
        """Return the shared encoder output for ``state``."""
        return self.shared_encoder(state)

class VTrace:
    """V-trace off-policy targets with truncated importance weights.

    ``clip_rho`` caps the weight on the TD error and on the
    policy-gradient advantage; ``clip_c`` caps the trace coefficient that
    propagates corrections backwards in time. ``rho_bar`` is stored for
    interface compatibility and is not read.
    """

    def __init__(self, clip_rho=1.0, clip_c=1.0, rho_bar=1.0):
        self.clip_rho = clip_rho
        self.clip_c = clip_c
        self.rho_bar = rho_bar

    @staticmethod
    def _taken_action_probs(probs, actions):
        """Reduce ``[T, B, A]`` probabilities to the taken action's ``[T, B]``.

        Inputs that are already ``[T, B]`` are returned unchanged.
        """
        if probs.dim() == actions.dim() + 1:
            return probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        return probs

    def compute_vtrace_target(self, behaviour_policy_probs, target_policy_probs,
                             actions, rewards, values, next_values, dones, gamma=0.99):
        """Compute V-trace value targets and policy-gradient advantages.

        Args:
            behaviour_policy_probs: ``[T, B, num_actions]`` or ``[T, B]``.
            target_policy_probs: ``[T, B, num_actions]`` or ``[T, B]``; the
                two probability inputs may use different layouts.
            actions: ``[T, B]`` integer actions.
            rewards: ``[T, B]``.
            values: ``[T, B]`` estimates V(x_t).
            next_values: ``[T, B]`` estimates V(x_{t+1}).
            dones: ``[T, B]``, 1 where the episode ended at step t.
            gamma: discount factor.

        Returns:
            ``(vs, pg_advantages, clipped_rho)``, each ``[T, B]`` and
            computed without gradient tracking, since they are targets.
        """
        with torch.no_grad():
            # Each input is reduced on its own, so a [T, B] behaviour
            # tensor can be paired with a [T, B, A] target tensor.
            behaviour_probs = self._taken_action_probs(behaviour_policy_probs, actions)
            target_probs = self._taken_action_probs(target_policy_probs, actions)

            rho = target_probs / (behaviour_probs + 1e-10)

            # Clipping for variance reduction
            clipped_rho = torch.clamp(rho, max=self.clip_rho)
            clipped_c = torch.clamp(rho, max=self.clip_c)

            not_done = 1 - dones
            td_errors = clipped_rho * (rewards + gamma * next_values * not_done - values)

            # Backward recursion on the correction (v_t - V(x_t)):
            #   corr_t = delta_t + gamma * c_t * corr_{t+1}
            # The trace is cut at episode ends so nothing leaks across.
            correction = torch.zeros_like(values[-1])
            reversed_corrections = []
            for t in reversed(range(actions.shape[0])):
                correction = td_errors[t] + gamma * clipped_c[t] * not_done[t] * correction
                reversed_corrections.append(correction)
            vs = values + torch.stack(reversed_corrections[::-1], dim=0)

            # Advantage: rho_t * (r_t + gamma * v_{t+1} - V(x_t)); the
            # last step bootstraps from V(x_T).
            vs_next = torch.cat([vs[1:], next_values[-1:]], dim=0)
            pg_advantages = clipped_rho * (rewards + gamma * not_done * vs_next - values)

        return vs, pg_advantages, clipped_rho

class VTraceAgent:
    """Actor-critic learner trained on V-trace corrected targets."""

    def __init__(self, state_dim, action_dim, device, clip_rho=1.0, clip_c=1.0):
        self.device = device
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=3e-4)
        self.vtrace = VTrace(clip_rho=clip_rho, clip_c=clip_c)

    def update(self, trajectories):
        """Run one gradient step on a time-major batch.

        Args:
            trajectories: dict with ``states`` ``[T, B, state_dim]``,
                ``actions`` ``[T, B]``, ``rewards`` ``[T, B]``,
                ``behaviour_probs`` (``[T, B]`` for the taken action, or
                ``[T, B, num_actions]``) and ``dones`` ``[T, B]``. An
                optional ``bootstrap_state`` ``[B, state_dim]`` supplies
                the observation that follows the last step.

        Returns:
            dict of scalar metrics: ``policy_loss``, ``value_loss``,
            ``mean_rho`` and ``mean_adv``.
        """
        states = torch.FloatTensor(trajectories['states']).to(self.device)
        actions = torch.LongTensor(trajectories['actions']).to(self.device)
        rewards = torch.FloatTensor(trajectories['rewards']).to(self.device)
        behaviour_probs = torch.FloatTensor(trajectories['behaviour_probs']).to(self.device)
        dones = torch.FloatTensor(trajectories['dones']).to(self.device)

        # Forward pass; squeeze(-1) keeps T and B even when either is 1.
        target_probs, values = self.policy(states)
        values = values.squeeze(-1)

        # V(x_{t+1}) is V(x_t) shifted by one step. The last step needs the
        # observation after the unroll; without it the final state is
        # reused as an approximation, not the first state of the unroll.
        bootstrap_state = trajectories.get('bootstrap_state')
        if bootstrap_state is None:
            last_obs = states[-1:]
        else:
            last_obs = torch.FloatTensor(bootstrap_state).to(self.device).unsqueeze(0)
        with torch.no_grad():
            _, last_value = self.policy(last_obs)
        next_values = torch.cat([values[1:].detach(), last_value.squeeze(-1)], dim=0)

        # Reduce both policies to the probability of the taken action so
        # the importance ratio is always [T, B].
        action_index = actions.unsqueeze(-1)
        taken_target_probs = target_probs.gather(-1, action_index).squeeze(-1)
        if behaviour_probs.dim() == target_probs.dim():
            behaviour_probs = behaviour_probs.gather(-1, action_index).squeeze(-1)

        # Compute V-trace targets (constants w.r.t. the parameters)
        vs, pg_advantages, rhos = self.vtrace.compute_vtrace_target(
            behaviour_probs, taken_target_probs.detach(), actions, rewards,
            values.detach(), next_values, dones
        )

        # Policy loss
        log_probs = torch.log(taken_target_probs + 1e-10)
        policy_loss = -(log_probs * pg_advantages.detach()).mean()

        # Value loss
        value_loss = F.mse_loss(values, vs.detach())

        # Entropy bonus
        entropy = -(target_probs * torch.log(target_probs + 1e-10)).sum(dim=-1).mean()

        total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 40.0)
        self.optimizer.step()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'mean_rho': rhos.mean().item(),
            'mean_adv': pg_advantages.mean().item()
        }

def main():
    """Drive VTraceAgent.update with randomly generated trajectories."""
    cli = argparse.ArgumentParser()
    int_options = (
        ("--state_dim", 4),
        ("--action_dim", 2),
        ("--batch_size", 64),
        ("--trajectory_length", 20),
    )
    for flag, default in int_options:
        cli.add_argument(flag, default=default, type=int)
    for flag in ("--clip_rho", "--clip_c"):
        cli.add_argument(flag, default=1.0, type=float)
    args = cli.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = VTraceAgent(args.state_dim, args.action_dim, device, args.clip_rho, args.clip_c)

    # Every array is time-major: [trajectory_length, batch_size, ...].
    time_batch = (args.trajectory_length, args.batch_size)

    # Simulate trajectories
    for step in range(1000):
        states = np.random.randn(*time_batch, args.state_dim)
        actions = np.random.randint(0, args.action_dim, time_batch)
        rewards = np.random.randn(*time_batch)
        behaviour_probs = np.random.uniform(0.3, 0.7, time_batch)
        dones = np.random.randint(0, 2, time_batch).astype(float)

        metrics = agent.update({
            'states': states,
            'actions': actions,
            'rewards': rewards,
            'behaviour_probs': behaviour_probs,
            'dones': dones,
        })
        if step % 100 == 0:
            print(f"Step {step}: {metrics}")

if __name__ == "__main__":
    main()
2.3.2 GPU端到端加速
  • 2.3.2.1 Isaac Gym/Brax物理引擎:全GPU并行模拟( thousands of envs )与JAX JIT编译的端到端训练流水线
Isaac Gym与Brax物理引擎利用图形处理器的大规模并行计算能力,实现数千个环境实例的同时模拟,消除中央处理器与图形处理器之间的数据传输延迟。通过将物理模拟、状态计算与策略推理全部置于图形处理器端执行,并结合即时编译技术对计算图进行优化,该方法实现了端到端的训练流水线,在维持物理仿真精度的同时达到百万级帧每秒的采样吞吐量。
Python
复制
"""
Isaac Gym/Brax GPU端到端训练实现脚本
涉及内容:JAX JIT编译、GPU并行环境、PPO算法、向量环境批处理、端到端训练
使用方式:python brax_train.py --env humanoid --num_envs 4096 --jit_compile True
"""

import jax
import jax.numpy as jnp
from jax import random, jit, grad, vmap
import flax.linen as nn
from flax.training import train_state
import optax
import numpy as np
import argparse

class ActorCritic(nn.Module):
    """Two-layer ReLU MLP with a categorical policy head and a value head."""

    # Number of discrete actions, i.e. the width of the logits output.
    action_dim: int

    @nn.compact
    def __call__(self, x):
        """Return ``(logits, value)`` for inputs of shape ``[..., features]``.

        ``logits`` has shape ``[..., action_dim]`` and ``value`` has shape
        ``[...]``.
        """
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)

        logits = nn.Dense(self.action_dim)(x)
        value = nn.Dense(1)(x)
        # Drop only the trailing unit axis; a bare squeeze() would also
        # remove the batch axis when the batch size is 1.
        return logits, jnp.squeeze(value, axis=-1)

class BraxEnv:
    """Stand-in for a batched Brax environment with dummy physics.

    The state and action sizes are fixed to 87 and 21 (Humanoid); a real
    implementation would build the environment from ``env_name``.
    """

    def __init__(self, env_name, num_envs):
        self.env_name = env_name
        self.num_envs = num_envs
        # In real implementation, use brax.envs.create
        self.state_dim = 87  # Humanoid
        self.action_dim = 21

    def reset(self, rng):
        """Return random normal start states of shape ``[num_envs, state_dim]``."""
        return jax.random.normal(rng, (self.num_envs, self.state_dim))

    def step(self, state, action):
        """Advance every environment by one dummy step.

        Args:
            state: ``[num_envs, state_dim]``.
            action: ``[num_envs]`` (one discrete index per environment) or
                ``[num_envs, k]``.

        Returns:
            ``(next_state, reward, done)``; ``done`` is always zero.
        """
        # In real Brax: return env.step(state, action)
        action = jnp.asarray(action)
        if action.ndim == 1:
            # Give each environment its own row; summing a 1-D action
            # vector on the last axis would mix all environments.
            action = action[:, None]
        next_state = state + 0.01 * action.sum(axis=-1, keepdims=True)
        reward = -jnp.sum(next_state**2, axis=-1)  # Dummy reward
        done = jnp.zeros(self.num_envs)
        return next_state, reward, done

@jit
def select_action(params, state, rng):
    """Sample one categorical action per row of ``state``.

    Returns ``(action, log_prob, value, logits, rng)``, where ``rng`` is
    the carried-over half of the split key and ``log_prob`` is the
    log-probability of each sampled action.
    """
    network = ActorCritic(action_dim=21)
    logits, value = network.apply({'params': params}, state)
    rng, sample_key = random.split(rng)
    action = random.categorical(sample_key, logits)
    all_log_probs = jax.nn.log_softmax(logits)
    row_ids = jnp.arange(action.shape[0])
    log_prob = all_log_probs[row_ids, action]
    return action, log_prob, value, logits, rng

@jit
def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95, last_value=None):
    """Generalized advantage estimation over the leading (time) axis.

    Args:
        rewards: ``[T, ...]`` rewards.
        values: ``[T, ...]`` value estimates V(s_t).
        dones: ``[T, ...]``, 1 where the episode ended at step t.
        gamma: discount factor.
        gae_lambda: GAE smoothing coefficient.
        last_value: optional V(s_T), broadcastable to ``values[0]``, used
            to bootstrap the final step; defaults to zeros.

    Returns:
        advantages with the same shape as ``rewards``.
    """
    # Derive both carries from values[0] so any trailing batch shape works
    # (e.g. [T] or [T, num_envs]); lax.scan needs carry shapes that stay
    # fixed across steps.
    zero = jnp.zeros_like(values[0])
    if last_value is None:
        bootstrap = zero
    else:
        bootstrap = jnp.broadcast_to(jnp.asarray(last_value, dtype=values.dtype), zero.shape)

    def _gae_step(carry, transition):
        gae, next_value = carry
        reward, done, value = transition
        not_done = 1.0 - done
        delta = reward + gamma * next_value * not_done - value
        gae = delta + gamma * gae_lambda * not_done * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        _gae_step,
        (zero, bootstrap),
        (rewards, dones, values),
        reverse=True
    )
    return advantages

def ppo_loss(params, batch, clip_epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
    """Clipped PPO objective with value regression and an entropy bonus.

    ``batch`` is the tuple ``(states, actions, old_log_probs, advantages,
    returns)``. The result is ``policy_loss + value_coef * value_loss -
    entropy_coef * entropy``.
    """
    states, actions, old_log_probs, advantages, returns = batch

    network = ActorCritic(action_dim=21)
    logits, values = network.apply({'params': params}, states)
    log_pi = jax.nn.log_softmax(logits)
    chosen_log_pi = log_pi[jnp.arange(len(actions)), actions]

    # Clipped surrogate: pessimistic minimum of the raw and clipped terms.
    ratio = jnp.exp(chosen_log_pi - old_log_probs)
    unclipped_term = ratio * advantages
    clipped_term = jnp.clip(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    policy_loss = -jnp.mean(jnp.minimum(unclipped_term, clipped_term))

    # Squared error between predicted values and returns.
    value_loss = jnp.mean((values - returns) ** 2)

    # Mean entropy of the categorical policy.
    pi = jnp.exp(log_pi)
    entropy = -jnp.mean(jnp.sum(pi * log_pi, axis=-1))

    return policy_loss + value_coef * value_loss - entropy_coef * entropy

@jit
def train_step(state, batch):
    """Apply one gradient step of ``ppo_loss``; return ``(state, loss)``."""
    loss_and_grad = jax.value_and_grad(ppo_loss)
    loss, grads = loss_and_grad(state.params, batch)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, loss

class GPUParallelTrainer:
    """PPO trainer that steps all environments as one batched array."""

    def __init__(self, env_name, num_envs=4096, learning_rate=3e-4):
        self.env = BraxEnv(env_name, num_envs)
        self.num_envs = num_envs

        # Split the root key so the key that initializes the network is
        # never reused for rollouts.
        init_key, self.rng = random.split(random.PRNGKey(0))

        # Initialize network
        dummy_input = jnp.zeros((1, self.env.state_dim))
        net = ActorCritic(action_dim=self.env.action_dim)
        params = net.init(init_key, dummy_input)['params']

        self.state = train_state.TrainState.create(
            apply_fn=net.apply,
            params=params,
            tx=optax.adam(learning_rate)
        )

    def rollout(self, horizon=128):
        """Collect ``horizon`` steps from every environment.

        Returns:
            tuple ``(states, actions, log_probs, advantages, returns)``
            with the time and environment axes flattened together.
        """
        # Vectorized rollout across all envs
        self.rng, reset_key = random.split(self.rng)
        states = self.env.reset(reset_key)

        trajectories = {
            'states': [], 'actions': [], 'rewards': [],
            'log_probs': [], 'values': [], 'dones': []
        }

        for _ in range(horizon):
            self.rng, act_rng = random.split(self.rng)
            actions, log_probs, values, _, _ = select_action(
                self.state.params, states, act_rng
            )

            next_states, rewards, dones = self.env.step(states, actions)

            trajectories['states'].append(states)
            trajectories['actions'].append(actions)
            trajectories['rewards'].append(rewards)
            trajectories['log_probs'].append(log_probs)
            trajectories['values'].append(values)
            trajectories['dones'].append(dones)

            states = next_states

        # Convert to arrays with time as the leading axis
        stacked = {key: jnp.stack(steps, axis=0) for key, steps in trajectories.items()}

        # Compute advantages
        advantages = compute_gae(
            stacked['rewards'],
            stacked['values'],
            stacked['dones']
        )

        returns = advantages + stacked['values']

        # Flatten batch dimensions
        return (
            stacked['states'].reshape(-1, self.env.state_dim),
            stacked['actions'].reshape(-1),
            stacked['log_probs'].reshape(-1),
            advantages.reshape(-1),
            returns.reshape(-1)
        )

    def train(self, num_iterations=1000, batch_size=8192):
        """Alternate rollouts with one shuffled pass of minibatch updates."""
        for iteration in range(num_iterations):
            # Collect data
            batch = self.rollout(horizon=128)

            # Mini-batch updates
            n_samples = batch[0].shape[0]
            indices = np.random.permutation(n_samples)

            loss = None  # stays None only if the rollout is empty
            for start in range(0, n_samples, batch_size):
                mb_indices = indices[start:start + batch_size]
                mb = tuple(x[mb_indices] for x in batch)

                self.state, loss = train_step(self.state, mb)

            if iteration % 10 == 0 and loss is not None:
                print(f"Iteration {iteration}, Loss: {float(loss):.4f}")

def main():
    """Parse command-line options and run GPU-parallel training.

    Recognized options: ``--env`` (default ``humanoid``), ``--num_envs``
    (default 4096) and ``--learning_rate`` (default 3e-4). They are passed
    positionally to ``GPUParallelTrainer`` before ``train()`` is called.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--env", default="humanoid")
    cli.add_argument("--num_envs", type=int, default=4096)
    cli.add_argument("--learning_rate", type=float, default=3e-4)
    options = cli.parse_args()

    GPUParallelTrainer(
        options.env, options.num_envs, options.learning_rate
    ).train()

if __name__ == "__main__":
    main()
  • 2.3.2.2 RLlib中的SampleCollector优化:通过压缩观测(LZ4/ObsCompression)与异步梯度计算减少CPU-GPU传输瓶颈

RLlib中的样本收集器通过压缩观测数据与异步梯度计算优化分布式训练流水线。LZ4压缩算法在数据传输前对高维观测进行实时压缩,降低网络带宽占用;异步梯度计算允许中央处理器在图形处理器执行前向与反向传播的同时准备下一批次数据,通过计算与通信的重叠消除流水线气泡,最大化硬件利用率。

Python

复制

"""
RLlib SampleCollector优化实现脚本
涉及内容:LZ4观测压缩、异步梯度计算、样本压缩与解压、流水线优化
使用方式:python rllib_optimized.py --compress_observations --async_gradients
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import lz4.frame
import threading
import queue
import time
import argparse
from typing import Dict, List, Any

class LZ4Compressor:
    """Stateless helpers that round-trip NumPy observations through LZ4 frames."""

    @staticmethod
    def compress(observation: np.ndarray) -> bytes:
        """Return the LZ4-frame-compressed raw bytes of ``observation``.

        Only the array's bytes are encoded; shape and dtype are not part of
        the payload and must be kept by the caller for ``decompress``.
        """
        raw = observation.tobytes()
        return lz4.frame.compress(raw)

    @staticmethod
    def decompress(compressed: bytes, shape: tuple, dtype: np.dtype) -> np.ndarray:
        """Rebuild an array from a payload produced by ``compress``.

        Args:
            compressed: LZ4 frame holding the array's raw bytes.
            shape: Shape the flat buffer is reshaped to.
            dtype: Element type used to interpret the buffer.

        Returns:
            Array of the requested shape. ``np.frombuffer`` does not copy, so
            the result is a view onto the decompressed buffer.
        """
        raw = lz4.frame.decompress(compressed)
        flat = np.frombuffer(raw, dtype=dtype)
        return flat.reshape(shape)

class AsyncGradientComputer:
    """Compute actor-critic gradients for submitted batches on a worker thread.

    Batches are handed over through a bounded input queue. For every batch
    the worker publishes a ``{parameter_name: gradient}`` dict on
    ``result_queue``. Gradients are taken with ``torch.autograd.grad`` so the
    worker never writes to the shared ``param.grad`` fields, which the
    consumer assigns and clears around ``optimizer.step()``.

    If the worker fails, the exception is kept in ``error`` and re-raised
    (wrapped in ``RuntimeError``) by ``submit_batch`` / ``get_gradients``
    instead of letting callers block forever on a dead thread.
    """

    # Placed on the input queue by stop() to wake an idle worker immediately.
    _STOP = object()
    # Seconds a blocking queue operation waits before re-checking worker state.
    _POLL_SECONDS = 1.0

    def __init__(self, model, device):
        """Create the queues and the (not yet started) worker thread.

        Args:
            model: Module whose forward returns either a ``(logits, values)``
                pair or one tensor laid out as ``[logits..., value]`` along
                the last dimension.
            device: Device the batch tensors are moved to before the forward
                pass.
        """
        self.model = model
        self.device = device
        self.queue = queue.Queue(maxsize=2)  # bounded: back-pressure on producers
        self.result_queue = queue.Queue()
        self.thread = threading.Thread(target=self._compute_loop, daemon=True)
        self.running = False
        self.error = None  # exception that terminated the worker, if any

    def start(self):
        """Start the worker thread."""
        self.running = True
        self.thread.start()

    def stop(self):
        """Ask the worker to exit and wait for it; a no-op if it is not running."""
        self.running = False
        if not self.thread.is_alive():
            return
        try:
            self.queue.put_nowait(self._STOP)
        except queue.Full:
            # Worker is busy; it re-checks `running` after the current batch.
            pass
        self.thread.join()

    def _compute_loop(self):
        """Worker body: turn queued batches into gradient dicts until stopped."""
        while self.running:
            try:
                batch = self.queue.get(timeout=self._POLL_SECONDS)
            except queue.Empty:
                continue
            if batch is self._STOP:
                break
            try:
                grads = self._compute_gradients(batch)
            except Exception as exc:
                # Record the failure so producers/consumers can fail fast,
                # then re-raise so the traceback is still reported.
                self.error = exc
                self.running = False
                raise
            self.result_queue.put(grads)

    def _compute_gradients(self, batch):
        """Return ``{name: grad}`` for one batch without touching ``param.grad``."""
        # Move to GPU and compute
        states = batch['states'].to(self.device)
        actions = batch['actions'].to(self.device)
        returns = batch['returns'].to(self.device)
        advantages = batch['advantages'].to(self.device)

        loss = self._compute_loss(states, actions, returns, advantages)

        trainable = [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]
        grads = torch.autograd.grad(
            loss, [param for _, param in trainable], allow_unused=True
        )
        # Parameters that did not contribute to the loss yield None; skip them.
        return {
            name: grad
            for (name, _), grad in zip(trainable, grads)
            if grad is not None
        }

    def _compute_loss(self, states, actions, returns, advantages):
        """Return the scalar loss ``policy_loss + 0.5 * value_loss``.

        Accepts both model output conventions described in ``__init__``.
        """
        output = self.model(states)
        if isinstance(output, (tuple, list)):
            logits, values = output
        else:
            # Single tensor: the last column is the value, the rest are logits.
            logits, values = output[:, :-1], output[:, -1]
        values = values.reshape(-1)

        log_probs = F.log_softmax(logits, dim=-1)
        # squeeze(1) keeps the batch dimension even for a batch of one.
        log_probs_actions = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        policy_loss = -(log_probs_actions * advantages).mean()
        value_loss = F.mse_loss(values, returns)

        return policy_loss + 0.5 * value_loss

    def _raise_if_failed(self):
        """Raise ``RuntimeError`` if the worker terminated with an exception."""
        if self.error is not None:
            raise RuntimeError("gradient worker thread failed") from self.error

    def submit_batch(self, batch):
        """Queue ``batch`` for the worker, blocking while the queue is full.

        Raises:
            RuntimeError: If the worker thread has failed.
        """
        while True:
            self._raise_if_failed()
            try:
                self.queue.put(batch, timeout=self._POLL_SECONDS)
                return
            except queue.Full:
                continue

    def get_gradients(self):
        """Block until a gradient dict is available and return it.

        Raises:
            RuntimeError: If the worker thread has failed and no result is
                pending.
        """
        while True:
            try:
                return self.result_queue.get(timeout=self._POLL_SECONDS)
            except queue.Empty:
                self._raise_if_failed()

class OptimizedSampleCollector:
    """Episode collector and trainer for a small discrete-action actor-critic.

    Observations can be kept LZ4-compressed between collection and batch
    assembly, and gradient computation can be delegated to an
    ``AsyncGradientComputer`` worker thread.
    """

    def __init__(self, state_shape, action_dim, use_compression=True, use_async=True, device='cuda'):
        """Build the network, optimizer and optional compression/async helpers.

        Args:
            state_shape: Shape of one observation; its element count sizes
                the first linear layer.
            action_dim: Number of discrete actions. The network emits
                ``action_dim`` logits followed by one value estimate.
            use_compression: Store states as LZ4 payloads in collected steps.
            use_async: Create and start an ``AsyncGradientComputer``.
            device: Torch device for the model and training tensors.
        """
        self.state_shape = state_shape
        self.action_dim = action_dim
        self.use_compression = use_compression
        self.device = device

        # np.prod yields a NumPy scalar; hand nn.Linear a plain Python int.
        input_dim = int(np.prod(state_shape))
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim + 1)  # logits + value
        ).to(device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)

        if use_async:
            self.async_computer = AsyncGradientComputer(self.model, device)
            self.async_computer.start()
        else:
            self.async_computer = None

        self.compressor = LZ4Compressor() if use_compression else None
        self.buffer = []

    def collect_sample(self, env, policy):
        """Run one episode in ``env`` and return its list of step dicts.

        Actions are sampled from ``self.model``'s softmax policy; the
        ``policy`` argument is accepted but not used. Each step dict holds
        ``state``, ``action``, ``reward``, ``next_state`` and ``done``; with
        compression enabled the two states are LZ4 payloads and
        ``state_shape`` / ``state_dtype`` are stored alongside so they can be
        decoded later.
        """
        state, _ = env.reset()
        trajectory = []

        done = False
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            with torch.no_grad():
                logits = self.model(state_tensor)[:, :-1]  # drop the value column
                action_probs = F.softmax(logits, dim=-1)
                action = torch.multinomial(action_probs, 1).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            step = {
                'action': action,
                'reward': reward,
                'done': done,
            }
            if self.use_compression:
                step['state'] = self.compressor.compress(state)
                step['next_state'] = self.compressor.compress(next_state)
                step['state_shape'] = state.shape
                step['state_dtype'] = state.dtype
            else:
                step['state'] = state
                step['next_state'] = next_state
            trajectory.append(step)

            state = next_state

        return trajectory

    def _decode_state(self, step):
        """Return the step's observation as an array, decompressing if needed."""
        if self.use_compression:
            return self.compressor.decompress(
                step['state'], step['state_shape'], step['state_dtype']
            )
        return step['state']

    @staticmethod
    def _discounted_returns(rewards, dones, gamma):
        """Return per-step discounted reward-to-go as a float64 array.

        The contribution of step ``t+1`` is dropped when ``dones[t]`` is set.
        The accumulator is always floating point so integer rewards are not
        truncated.
        """
        returns = np.zeros(len(rewards), dtype=np.float64)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running * (1.0 - dones[t])
            returns[t] = running
        return returns

    def process_batch(self, trajectories: List[List[Dict]]):
        """Assemble collected trajectories into one flat training batch.

        Returns are Monte-Carlo discounted returns (gamma = 0.99) with no
        value bootstrap, so an episode that ends by truncation is treated
        like a terminated one. Advantages are those returns standardized per
        trajectory; no value baseline is subtracted.

        Args:
            trajectories: Episodes as produced by ``collect_sample``. Empty
                episodes are skipped.

        Returns:
            Dict of CPU tensors ``states`` (float32), ``actions`` (int64),
            ``returns`` (float32) and ``advantages`` (float32), concatenated
            over all episodes.

        Raises:
            ValueError: If there is no step to batch.
        """
        gamma = 0.99

        all_states = []
        all_actions = []
        all_returns = []
        all_advantages = []

        for traj in trajectories:
            if not traj:
                continue

            states = np.array([self._decode_state(step) for step in traj])
            actions = np.array([step['action'] for step in traj])
            rewards = np.array([step['reward'] for step in traj], dtype=np.float64)
            dones = np.array([step['done'] for step in traj], dtype=np.float64)

            returns = self._discounted_returns(rewards, dones, gamma)

            # Standardize; a constant-return episode only gets centered.
            advantages = returns - returns.mean()
            spread = returns.std()
            if spread > 0:
                advantages = advantages / (spread + 1e-8)

            all_states.append(states)
            all_actions.append(actions)
            all_returns.append(returns)
            all_advantages.append(advantages)

        if not all_states:
            raise ValueError("process_batch requires at least one non-empty trajectory")

        return {
            'states': torch.as_tensor(np.concatenate(all_states), dtype=torch.float32),
            'actions': torch.as_tensor(np.concatenate(all_actions), dtype=torch.long),
            'returns': torch.as_tensor(np.concatenate(all_returns), dtype=torch.float32),
            'advantages': torch.as_tensor(np.concatenate(all_advantages), dtype=torch.float32),
        }

    def _actor_critic_loss(self, states, actions, returns, advantages):
        """Return ``policy_loss + 0.5 * value_loss`` for one mini-batch."""
        output = self.model(states)
        logits = output[:, :-1]
        values = output[:, -1]

        log_probs = F.log_softmax(logits, dim=-1)
        # squeeze(1) keeps the batch dimension even for a batch of one.
        log_probs_actions = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        policy_loss = -(log_probs_actions * advantages).mean()
        value_loss = F.mse_loss(values, returns)
        return policy_loss + 0.5 * value_loss

    def train_step(self, batch):
        """Apply one optimizer update for ``batch``.

        Async mode submits the batch to the worker and, without waiting,
        applies whichever gradient dict is already finished (if any). The
        applied gradients therefore come from an earlier batch and earlier
        weights. Sync mode does a regular forward/backward/step on ``batch``.
        """
        if self.async_computer is not None:
            # Async path: submit and immediately get previous gradients
            self.async_computer.submit_batch(batch)
            try:
                grads = self.async_computer.result_queue.get_nowait()
            except queue.Empty:
                return  # nothing finished yet; pick it up on a later call

            for name, param in self.model.named_parameters():
                if name in grads:
                    param.grad = grads[name]
            self.optimizer.step()
            self.optimizer.zero_grad()
            return

        # Sync path
        loss = self._actor_critic_loss(
            batch['states'].to(self.device),
            batch['actions'].to(self.device),
            batch['returns'].to(self.device),
            batch['advantages'].to(self.device),
        )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

def main():
    """Run the CartPole demo: collect episodes, batch them, train in mini-batches.

    ``--compress_observations`` and ``--async_gradients`` are boolean switches
    (pass the flag alone, without a value). Environments are closed and the
    async worker, if any, is stopped even when the loop raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--compress_observations", action="store_true")
    parser.add_argument("--async_gradients", action="store_true")
    parser.add_argument("--num_envs", default=4, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    args = parser.parse_args()

    import gymnasium as gym

    collector = OptimizedSampleCollector(
        state_shape=(4,),
        action_dim=2,
        use_compression=args.compress_observations,
        use_async=args.async_gradients,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )

    envs = []
    try:
        envs.extend(gym.make("CartPole-v1") for _ in range(args.num_envs))

        # Collection loop
        for iteration in range(100):
            trajectories = [
                collector.collect_sample(env, collector.model) for env in envs
            ]
            batch = collector.process_batch(trajectories)

            # Split into mini-batches
            n = len(batch['states'])
            indices = torch.randperm(n)
            for start in range(0, n, args.batch_size):
                mb_idx = indices[start:start + args.batch_size]
                mb = {k: v[mb_idx] for k, v in batch.items()}
                collector.train_step(mb)

            if iteration % 10 == 0:
                print(f"Iteration {iteration} completed")
    finally:
        # Release simulator resources and the worker thread on every exit path.
        for env in envs:
            env.close()
        if collector.async_computer is not None:
            collector.async_computer.stop()

if __name__ == "__main__":
    main()

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐