目录

第七章 元强化学习与快速适应 (Meta-Reinforcement Learning)

7.1 基于梯度的元学习

7.1.1 MAML 在策略梯度中的实现

7.1.2 改进的元学习架构

7.2 基于记忆的元学习

7.2.1 循环神经网络作为元学习器

7.2.2 上下文推断与 PEARL


第七章 元强化学习与快速适应 (Meta-Reinforcement Learning)

元强化学习旨在使智能体能够从少量经验中快速适应新任务,其核心在于学习一个可迁移的学习过程本身。本章涵盖基于梯度的元学习方法,通过优化初始化参数实现快速微调;以及基于记忆的元学习方法,利用循环神经网络或上下文推断隐式地学习任务表示。

7.1 基于梯度的元学习

基于梯度的元学习通过双层优化(bilevel optimization)框架学习最优初始化参数,使得经过少量梯度步骤即可适应新任务。这类方法直接利用梯度信息进行元优化,适用于策略梯度强化学习场景。

7.1.1 MAML 在策略梯度中的实现

模型无关元学习(MAML)通过在任务分布上优化模型参数,使得经过一步或多步梯度下降后,模型在新任务上表现良好。在策略梯度框架中实现 MAML 需要高效处理二阶导数(Hessian-Vector Product),同时提供一阶近似(FOMAML)以降低计算开销。Meta-SGD 进一步扩展此框架,将每个参数的学习率作为可学习的元参数,实现自适应的更新方向与步长。

脚本 7.1.1:MAML、FOMAML 与 Meta-SGD 实现

Python

#!/usr/bin/env python3
"""
脚本内容:MAML、FOMAML 与 Meta-SGD 在策略梯度中的实现
使用方式:python section_7_1_1_maml_pg.py --algorithm MAML --n_way 5 --k_shot 1
功能说明:
  1. MAML 二阶导数实现:使用 create_graph=True 在 torch.autograd.grad 中保留计算图
  2. FOMAML 一阶近似:通过 detach() 阻断二阶梯度流,显著降低内存占用
  3. Meta-SGD:可学习的逐参数学习率,使用 nn.ParameterList 存储元参数
  4. 支持策略梯度内循环与元优化的双层优化结构
"""

import argparse
import copy
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class PolicyNetwork(nn.Module):
    """策略网络,支持离散与连续动作空间"""
    
    def __init__(
        self, 
        obs_dim: int, 
        action_dim: int, 
        hidden_dim: int = 256,
        continuous: bool = False
    ):
        super().__init__()
        self._continuous = continuous
        
        self._backbone = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        if continuous:
            # 连续动作空间:输出均值与对数标准差
            self._mean_head = nn.Linear(hidden_dim, action_dim)
            self._logstd_head = nn.Parameter(torch.zeros(action_dim))
        else:
            # 离散动作空间:输出动作 logits
            self._action_head = nn.Linear(hidden_dim, action_dim)
            
        # 初始化优化
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)
                
    def forward(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        features = self._backbone(obs)
        
        if self._continuous:
            mean = self._mean_head(features)
            std = torch.exp(self._logstd_head).expand_as(mean)
            return mean, std
        else:
            logits = self._action_head(features)
            return logits, torch.tensor([])  # 占位
        
    def get_action(self, obs: Tensor, deterministic: bool = False) -> Tuple[Tensor, Tensor]:
        if self._continuous:
            mean, std = self.forward(obs)
            if deterministic:
                return mean, torch.sum(mean, dim=-1)  # 占位价值
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            return action, log_prob
        else:
            logits, _ = self.forward(obs)
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            return action, log_prob


class MAML:
    """
    MAML (Model-Agnostic Meta-Learning) 实现。
    支持二阶导数(MAML)与一阶近似(FOMAML)。
    """
    
    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,
        meta_lr: float = 0.001,
        first_order: bool = False,
        n_inner_steps: int = 1
    ):
        self._model = model
        self._inner_lr = inner_lr
        self._meta_lr = meta_lr
        self._first_order = first_order  # FOMAML 标志
        self._n_inner_steps = n_inner_steps
        
        # 元优化器(优化初始参数)
        self._meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)
        
    def inner_loop(
        self, 
        support_x: Tensor, 
        support_y: Tensor,
        create_graph: bool = True
    ) -> Dict[str, Tensor]:
        """
        内循环适应:在支持集上执行 n 步梯度下降。
        关键:create_graph=True 保留二阶导数计算图。
        """
        # 克隆当前参数作为任务特定参数
        adapted_params = {name: param.clone() for name, param in self._model.named_parameters()}
        
        for step in range(self._n_inner_steps):
            # 使用当前 adapted_params 前向传播
            self._set_params(adapted_params)
            loss = self._compute_loss(support_x, support_y)
            
            # 计算梯度
            grads = torch.autograd.grad(
                loss, 
                adapted_params.values(),
                create_graph=create_graph and not self._first_order,  # 二阶导数关键
                retain_graph=True,
                allow_unused=True
            )
            
            # 更新 adapted_params(梯度下降)
            adapted_params = {
                name: param - self._inner_lr * grad
                for (name, param), grad in zip(adapted_params.items(), grads)
                if grad is not None
            }
            
        return adapted_params
    
    def meta_step(
        self, 
        tasks: List[Tuple[Tensor, Tensor, Tensor, Tensor]]
    ) -> float:
        """
        元优化步骤:在所有任务的查询集上评估并更新初始参数。
        tasks: List of (support_x, support_y, query_x, query_y)
        """
        meta_loss = 0.0
        
        for support_x, support_y, query_x, query_y in tasks:
            # 内循环适应
            adapted_params = self.inner_loop(
                support_x, support_y, 
                create_graph=not self._first_order
            )
            
            # 在查询集上评估(使用 adapted_params)
            self._set_params(adapted_params)
            query_loss = self._compute_loss(query_x, query_y)
            
            if self._first_order:
                # FOMAML:阻断二阶梯度,将查询损失视为初始参数的函数
                query_loss = query_loss.detach().requires_grad_(True)
                
            meta_loss += query_loss
            
        # 元优化
        self._meta_optimizer.zero_grad()
        meta_loss.backward()
        
        # 恢复原始参数并应用元梯度
        self._restore_params()
        self._meta_optimizer.step()
        
        return meta_loss.item() / len(tasks)
    
    def _compute_loss(self, x: Tensor, y: Tensor) -> Tensor:
        """计算任务损失(分类或回归)"""
        if isinstance(self._model, PolicyNetwork):
            # 策略梯度损失(简化版,实际应为 REINFORCE 或 PPO 损失)
            actions, log_probs = self._model.get_action(x)
            # 假设 y 为目标动作或优势估计
            loss = F.mse_loss(actions.float(), y.float())
            return loss
        else:
            preds = self._model(x)
            return F.cross_entropy(preds, y)
    
    def _set_params(self, params_dict: Dict[str, Tensor]):
        """临时设置模型参数"""
        for name, param in self._model.named_parameters():
            param.data.copy_(params_dict[name])
            
    def _restore_params(self):
        """恢复原始参数(元优化器已更新)"""
        pass  # 元优化器已直接更新 self._model 的参数


class FOMAML(MAML):
    """First-Order MAML:通过一阶近似加速"""
    
    def __init__(self, model: nn.Module, **kwargs):
        super().__init__(model, first_order=True, **kwargs)
        
    def meta_step(self, tasks: List[Tuple[Tensor, Tensor, Tensor, Tensor]]) -> float:
        """
        FOMAML 关键优化:在元更新时阻断二阶梯度流。
        通过将 adapted_params 视为常数,仅保留查询损失的梯度。
        """
        meta_loss = 0.0
        
        for support_x, support_y, query_x, query_y in tasks:
            # 内循环(不创建计算图)
            with torch.no_grad():
                adapted_params = self.inner_loop(support_x, support_y, create_graph=False)
                
            # 前向传播(启用梯度)
            self._set_params({k: v.detach().requires_grad_(True) for k, v in adapted_params.items()})
            query_loss = self._compute_loss(query_x, query_y)
            
            # 手动计算元梯度(近似)
            meta_loss += query_loss
            
        # 标准反向传播
        self._meta_optimizer.zero_grad()
        meta_loss.backward()
        self._restore_params()
        self._meta_optimizer.step()
        
        return meta_loss.item() / len(tasks)


class MetaSGD:
    """
    Meta-SGD:学习初始参数与逐参数学习率。
    更新公式:theta' = theta - alpha * grad,其中 alpha 是可学习的。
    """
    
    def __init__(
        self,
        model: nn.Module,
        meta_lr: float = 0.001,
        init_lr: float = 0.01
    ):
        self._model = model
        
        # 可学习的初始参数(与模型参数同形状)
        self._meta_parameters = {name: param.clone().detach().requires_grad_(True) 
                                for name, param in model.named_parameters()}
        
        # 可学习的逐参数学习率(关键创新)
        self._alphas = nn.ParameterDict({
            name: nn.Parameter(torch.ones_like(param) * init_lr)
            for name, param in model.named_parameters()
        })
        
        # 元优化器同时优化初始参数与学习率
        meta_params = list(self._meta_parameters.values()) + list(self._alphas.values())
        self._meta_optimizer = torch.optim.Adam(meta_params, lr=meta_lr)
        
    def inner_loop(self, support_x: Tensor, support_y: Tensor) -> Dict[str, Tensor]:
        """使用可学习学习率的内循环"""
        # 从元参数初始化
        adapted_params = {name: param.clone() 
                         for name, param in self._meta_parameters.items()}
        
        # 前向传播
        self._set_params(adapted_params)
        loss = self._compute_loss(support_x, support_y)
        
        # 计算梯度
        grads = torch.autograd.grad(loss, adapted_params.values(), create_graph=True)
        
        # Meta-SGD 更新:theta' = theta - alpha * grad(逐元素乘积)
        adapted_params = {
            name: param - self._alphas[name] * grad
            for (name, param), grad in zip(adapted_params.items(), grads)
        }
        
        return adapted_params
    
    def meta_step(self, tasks: List[Tuple[Tensor, Tensor, Tensor, Tensor]]) -> float:
        """元优化步骤(与 MAML 类似,但优化器包含 alpha)"""
        meta_loss = 0.0
        
        for support_x, support_y, query_x, query_y in tasks:
            adapted_params = self.inner_loop(support_x, support_y)
            self._set_params(adapted_params)
            query_loss = self._compute_loss(query_x, query_y)
            meta_loss += query_loss
            
        self._meta_optimizer.zero_grad()
        meta_loss.backward()
        self._meta_optimizer.step()
        
        # 同步模型参数
        for name, param in self._model.named_parameters():
            param.data.copy_(self._meta_parameters[name].data)
            
        return meta_loss.item() / len(tasks)
    
    def _compute_loss(self, x: Tensor, y: Tensor) -> Tensor:
        """简化损失计算"""
        if hasattr(self._model, 'get_action'):
            actions, _ = self._model.get_action(x)
            return F.mse_loss(actions, y)
        return F.cross_entropy(self._model(x), y)
    
    def _set_params(self, params_dict: Dict[str, Tensor]):
        for name, param in self._model.named_parameters():
            param.data.copy_(params_dict[name])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--algorithm", type=str, default="MAML", choices=["MAML", "FOMAML", "MetaSGD"])
    parser.add_argument("--n_way", type=int, default=5)
    parser.add_argument("--k_shot", type=int, default=1)
    args = parser.parse_args()
    
    # 测试参数
    obs_dim, action_dim = 10, 4
    
    model = PolicyNetwork(obs_dim, action_dim, continuous=False)
    
    if args.algorithm == "MAML":
        meta_learner = MAML(model, first_order=False)
    elif args.algorithm == "FOMAML":
        meta_learner = FOMAML(model)
    else:
        meta_learner = MetaSGD(model)
        
    # 模拟任务批次
    n_tasks = 4
    tasks = []
    for _ in range(n_tasks):
        support_x = torch.randn(args.n_way * args.k_shot, obs_dim)
        support_y = torch.randint(0, action_dim, (args.n_way * args.k_shot,))
        query_x = torch.randn(15, obs_dim)
        query_y = torch.randint(0, action_dim, (15,))
        tasks.append((support_x, support_y, query_x, query_y))
        
    loss = meta_learner.meta_step(tasks)
    print(f"[{args.algorithm}] 元损失: {loss:.4f}")

7.1.2 改进的元学习架构

Reptile 算法通过重复采样与移动平均简化了元学习过程,无需显式计算二阶导数,仅需在任务内执行 k 步 SGD 后向初始化参数方向移动。VariBAD 将变分推断引入元强化学习,通过 VAE 架构学习任务后验分布,使策略能够基于任务不确定性进行结构化探索,实现近似贝叶斯自适应行为。

脚本 7.1.2:Reptile 与 VariBAD 实现

Python

#!/usr/bin/env python3
"""
脚本内容:Reptile 与 VariBAD 算法实现
使用方式:python section_7_1_2_reptile_varibad.py --algorithm VariBAD --n_tasks 10
功能说明:
  1. Reptile:重复采样与移动平均元更新,无需二阶导数
  2. VariBAD:变分自编码器推断任务后验,潜在变量 z 注入策略网络
  3. 贝叶斯自适应策略:基于任务不确定性进行结构化探索
  4. 支持循环信念状态与重参数化技巧
"""

import argparse
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Reptile:
    """
    Reptile 算法实现。
    核心思想:在任务内执行 k 步 SGD,然后向初始化参数方向移动。
    公式:theta <- theta + beta * (theta_i - theta)
    """
    
    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,
        meta_lr: float = 0.01,  # Reptile 通常使用较大学习率
        inner_steps: int = 5,
        meta_batch_size: int = 4
    ):
        self._model = model
        self._inner_lr = inner_lr
        self._meta_lr = meta_lr
        self._inner_steps = inner_steps
        self._meta_batch_size = meta_batch_size
        
        self._meta_optimizer = torch.optim.SGD(model.parameters(), lr=meta_lr)
        
    def train_task(self, task_data: Tuple[Tensor, Tensor]) -> Dict[str, Tensor]:
        """在单个任务上执行 k 步 SGD,返回最终参数"""
        x, y = task_data
        
        # 克隆当前元参数
        task_params = {name: param.clone() 
                      for name, param in self._model.named_parameters()}
        
        # 创建任务特定优化器
        task_optimizer = torch.optim.SGD(self._model.parameters(), lr=self._inner_lr)
        
        for step in range(self._inner_steps):
            task_optimizer.zero_grad()
            loss = self._compute_loss(x, y)
            loss.backward()
            task_optimizer.step()
            
        # 获取适应后的参数
        adapted_params = {name: param.clone() 
                         for name, param in self._model.named_parameters()}
        
        # 恢复元参数(Reptile 不保留任务参数)
        with torch.no_grad():
            for name, param in self._model.named_parameters():
                param.copy_(task_params[name])
                
        return adapted_params
    
    def meta_step(self, tasks: List[Tuple[Tensor, Tensor]]) -> float:
        """
        元更新:向任务适应参数的平均方向移动。
        等价于:theta <- theta + beta * mean(theta_i - theta)
        """
        # 收集所有任务的适应后参数
        adapted_params_list = []
        total_loss = 0.0
        
        for task_data in tasks:
            adapted_params = self.train_task(task_data)
            adapted_params_list.append(adapted_params)
            
            # 计算任务损失用于监控
            self._set_params(adapted_params)
            with torch.no_grad():
                loss = self._compute_loss(task_data[0], task_data[1])
                total_loss += loss.item()
                
        # 恢复元参数并执行移动平均
        self._meta_optimizer.zero_grad()
        
        with torch.no_grad():
            for name, param in self._model.named_parameters():
                # 计算平均适应参数
                avg_adapted = torch.stack([p[name] for p in adapted_params_list]).mean(dim=0)
                # Reptile 更新方向
                direction = avg_adapted - param
                param.grad = -direction  # 负号因为 SGD 是减梯度
                
        self._meta_optimizer.step()
        
        return total_loss / len(tasks)
    
    def _compute_loss(self, x: Tensor, y: Tensor) -> Tensor:
        """任务损失"""
        if hasattr(self._model, 'forward'):
            pred = self._model(x)
            if pred.dim() > 1 and pred.size(1) > 1:
                return F.cross_entropy(pred, y)
            return F.mse_loss(pred.squeeze(), y.float())
        return torch.tensor(0.0)
    
    def _set_params(self, params: Dict[str, Tensor]):
        for name, param in self._model.named_parameters():
            param.copy_(params[name])


class VariationalEncoder(nn.Module):
    """
    VariBAD 的变分编码器:推断任务后验 p(z | trajectory)。
    使用 VAE 架构输出潜在变量 z 的分布参数。
    """
    
    def __init__(self, obs_dim: int, action_dim: int, latent_dim: int = 64):
        super().__init__()
        self._latent_dim = latent_dim
        
        # 轨迹编码器(GRU 或 LSTM)
        self._gru = nn.GRU(obs_dim + action_dim + 1, 128, batch_first=True)  # +1 for reward
        
        # 后验参数输出(均值与对数方差)
        self._mean_net = nn.Linear(128, latent_dim)
        self._logvar_net = nn.Linear(128, latent_dim)
        
    def forward(self, trajectory: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        输入轨迹 (batch, seq_len, obs+action+reward)。
        返回:z_sample, mean, logvar
        """
        # 编码轨迹
        _, hidden = self._gru(trajectory)  # hidden: (1, batch, hidden_dim)
        hidden = hidden.squeeze(0)  # (batch, hidden_dim)
        
        # 后验分布参数
        mean = self._mean_net(hidden)
        logvar = self._logvar_net(hidden)
        logvar = torch.clamp(logvar, -20, 2)  # 数值稳定性
        
        # 重参数化采样
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mean + std * eps
        
        return z, mean, logvar


class VariBADAgent(nn.Module):
    """
    VariBAD (Variational Bayes-Adaptive Deep RL) 实现。
    结合变分推断与策略梯度,实现贝叶斯自适应探索。
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        latent_dim: int = 64,
        continuous: bool = False
    ):
        super().__init__()
        self._obs_dim = obs_dim
        self._action_dim = action_dim
        self._latent_dim = latent_dim
        
        # 变分编码器(推断任务后验)
        self._encoder = VariationalEncoder(obs_dim, action_dim, latent_dim)
        
        # 策略网络(输入:obs + latent_z)
        policy_input_dim = obs_dim + latent_dim
        self._policy = nn.Sequential(
            nn.Linear(policy_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim * 2 if continuous else action_dim)
        )
        
        # 价值网络(输入:obs + latent_z)
        self._value = nn.Sequential(
            nn.Linear(policy_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
        self._continuous = continuous
        
    def forward(
        self, 
        obs: Tensor, 
        trajectory_history: Tensor,
        deterministic: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        前向传播,返回动作、价值、z_mean、z_logvar。
        """
        # 推断任务后验
        z, z_mean, z_logvar = self._encoder(trajectory_history)
        
        # 将潜在变量与观测拼接
        obs_z = torch.cat([obs, z], dim=-1)
        
        # 策略输出
        policy_out = self._policy(obs_z)
        if self._continuous:
            mean, logstd = policy_out.chunk(2, dim=-1)
            std = torch.exp(logstd).clamp(0.1, 2.0)
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)
        else:
            logits = policy_out
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            
        # 价值估计
        value = self._value(obs_z)
        
        return action, log_prob, value, z_mean, z_logvar
    
    def compute_loss(
        self,
        trajectory_history: Tensor,
        obs_batch: Tensor,
        action_batch: Tensor,
        reward_batch: Tensor,
        next_obs_batch: Tensor,
        done_batch: Tensor
    ) -> Dict[str, Tensor]:
        """
        计算 VariBAD 总损失:策略梯度损失 + VAE 重构损失 + KL 散度。
        """
        # 推断当前与下一状态的潜在变量
        z, z_mean, z_logvar = self._encoder(trajectory_history)
        
        # 策略损失(PPO 风格,简化版)
        obs_z = torch.cat([obs_batch, z], dim=-1)
        policy_out = self._policy(obs_z)
        
        if self._continuous:
            mean, logstd = policy_out.chunk(2, dim=-1)
            std = torch.exp(logstd)
            dist = torch.distributions.Normal(mean, std)
            log_probs = dist.log_prob(action_batch).sum(dim=-1)
        else:
            logits = policy_out
            log_probs = F.log_softmax(logits, dim=-1).gather(1, action_batch.unsqueeze(1)).squeeze()
            
        # 优势估计(简化 GAE)
        with torch.no_grad():
            values = self._value(obs_z)
            next_values = self._value(torch.cat([next_obs_batch, z], dim=-1))
            td_target = reward_batch + 0.99 * next_values * (1 - done_batch)
            advantages = td_target - values
            
        policy_loss = -(log_probs * advantages).mean()
        value_loss = F.mse_loss(values, td_target)
        
        # VAE 损失(KL 散度)
        kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp(), dim=-1).mean()
        
        # 重构损失(可选,解码轨迹)
        # 此处简化为策略性能损失
        
        total_loss = policy_loss + 0.5 * value_loss + 0.01 * kl_loss
        
        return {
            'total_loss': total_loss,
            'policy_loss': policy_loss,
            'value_loss': value_loss,
            'kl_loss': kl_loss
        }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--algorithm", type=str, default="Reptile", choices=["Reptile", "VariBAD"])
    parser.add_argument("--n_tasks", type=int, default=10)
    args = parser.parse_args()
    
    obs_dim, action_dim = 10, 4
    
    if args.algorithm == "Reptile":
        model = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        meta_learner = Reptile(model, inner_steps=5)
        
        # 模拟任务
        tasks = [(torch.randn(20, obs_dim), torch.randint(0, action_dim, (20,))) 
                for _ in range(args.n_tasks)]
        
        loss = meta_learner.meta_step(tasks[:4])
        print(f"[Reptile] 平均任务损失: {loss:.4f}")
        
    else:  # VariBAD
        agent = VariBADAgent(obs_dim, action_dim, latent_dim=8)
        
        # 模拟轨迹历史 (batch, seq_len, obs+action+reward)
        batch_size, seq_len = 4, 10
        trajectory = torch.randn(batch_size, seq_len, obs_dim + action_dim + 1)
        obs = torch.randn(batch_size, obs_dim)
        
        action, log_prob, value, z_mean, z_logvar = agent(obs, trajectory)
        print(f"[VariBAD] 潜在变量维度: {z_mean.shape}")
        print(f"          Z 分布 - 均值范围: [{z_mean.min():.2f}, {z_mean.max():.2f}]")

7.2 基于记忆的元学习

基于记忆的元学习利用循环神经网络的隐状态或显式上下文编码器来隐式学习任务表示,无需显式的梯度适应步骤即可实现快速适应。

7.2.1 循环神经网络作为元学习器

RL² 算法将循环神经网络的隐藏状态作为跨 episode 的记忆缓冲区,通过元训练学习如何在新任务开始时快速更新隐状态以编码任务信息。SNAIL 架构结合因果卷积与注意力机制,通过时间卷积提供高带宽的近期信息访问,并通过注意力机制实现长期的精确信息检索,克服了传统 RNN 的梯度消失与长期依赖限制。

脚本 7.2.1:RL² 与 SNAIL 实现

Python

#!/usr/bin/env python3
"""
脚本内容:RL² 与 SNAIL(基于记忆的元学习)实现
使用方式:python section_7_2_1_memory_meta.py --model SNAIL --seq_len 100
功能说明:
  1. RL²:LSTM 隐藏状态作为跨 episode 记忆,stateful=True 保留状态
  2. SNAIL:因果卷积(Causal Conv)+ 多头注意力的时间软选择
  3. 时间卷积块:指数级膨胀率(dilation)扩大感受野
  4. 因果掩码确保当前时间步只能关注历史信息
"""

import argparse
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class RL2Agent(nn.Module):
    """
    RL² (RL-Squared) 实现:使用 LSTM 隐藏状态作为任务记忆。
    元训练时跨 batch 保留隐藏状态,实现跨 episode 学习。
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int = 256,
        n_layers: int = 1
    ):
        super().__init__()
        self._hidden_dim = hidden_dim
        self._n_layers = n_layers
        
        # LSTM 编码器:输入 (obs, action_prev, reward_prev)
        # 额外输入维度用于上一个动作与奖励(提供交互历史)
        self._lstm = nn.LSTM(
            input_size=obs_dim + action_dim + 1,  # +1 for reward
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True
        )
        
        # 策略与价值头
        self._policy_head = nn.Linear(hidden_dim, action_dim)
        self._value_head = nn.Linear(hidden_dim, 1)
        
        # 初始化隐藏状态(将在前向中管理)
        self._hidden_state: Optional[Tuple[Tensor, Tensor]] = None
        
    def forward(
        self, 
        obs: Tensor, 
        prev_action: Tensor,
        prev_reward: Tensor,
        hidden: Optional[Tuple[Tensor, Tensor]] = None,
        reset: bool = False
    ) -> Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]:
        """
        前向传播,支持状态重置(新任务开始时)。
        
        Args:
            obs: (batch, obs_dim)
            prev_action: (batch, action_dim) one-hot 或索引
            prev_reward: (batch, 1)
            hidden: 可选的初始隐藏状态
            reset: 是否重置隐藏状态(新 episode 开始)
        """
        batch_size = obs.size(0)
        
        # 管理隐藏状态
        if reset or hidden is None:
            h = torch.zeros(self._n_layers, batch_size, self._hidden_dim, device=obs.device)
            c = torch.zeros(self._n_layers, batch_size, self._hidden_dim, device=obs.device)
            self._hidden_state = (h, c)
        else:
            self._hidden_state = hidden
            
        # 构建输入:拼接观测、上一动作、上一奖励
        if prev_action.dim() == 1:
            prev_action = F.one_hot(prev_action, num_classes=self._policy_head.out_features).float()
            
        lstm_input = torch.cat([obs, prev_action, prev_reward], dim=-1).unsqueeze(1)  # (batch, 1, input_dim)
        
        # LSTM 前向
        lstm_out, new_hidden = self._lstm(lstm_input, self._hidden_state)
        features = lstm_out.squeeze(1)  # (batch, hidden_dim)
        
        # 输出
        logits = self._policy_head(features)
        value = self._value_head(features).squeeze(-1)
        
        # 保存隐藏状态用于下一次调用(stateful 行为)
        self._hidden_state = (new_hidden[0].detach(), new_hidden[1].detach())
        
        return logits, value, new_hidden
    
    def reset_hidden_state(self, batch_size: int = 1, device: str = "cpu"):
        """显式重置隐藏状态(新任务开始时调用)"""
        h = torch.zeros(self._n_layers, batch_size, self._hidden_dim, device=device)
        c = torch.zeros(self._n_layers, batch_size, self._hidden_dim, device=device)
        self._hidden_state = (h, c)


class CausalConv1d(nn.Module):
    """
    因果一维卷积:确保输出只依赖历史输入。
    通过左侧 padding 实现因果性。
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1
    ):
        super().__init__()
        self._kernel_size = kernel_size
        self._dilation = dilation
        self._padding = (kernel_size - 1) * dilation
        
        self._conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=0,  # 手动处理 padding
            dilation=dilation
        )
        
    def forward(self, x: Tensor) -> Tensor:
        """
        x: (batch, channels, seq_len)
        """
        # 左侧 padding 保持因果性
        x_padded = F.pad(x, (self._padding, 0))
        return self._conv(x_padded)


class DenseBlock(nn.Module):
    """
    SNAIL 的密集块:因果卷积 + 残差连接 + 输入拼接。
    """
    
    def __init__(self, channels: int, kernel_size: int = 2, dilation: int = 1):
        super().__init__()
        self._causal_conv = CausalConv1d(channels, channels, kernel_size, dilation)
        self._activation = nn.Tanh()
        
    def forward(self, x: Tensor) -> Tensor:
        out = self._causal_conv(x)
        out = self._activation(out)
        # 密集连接:拼接输入与输出
        return torch.cat([x, out], dim=1)  # 沿通道维度拼接


class TCBlock(nn.Module):
    """
    时间卷积块(Temporal Convolution Block):多个不同膨胀率的 DenseBlock。
    膨胀率指数增长以扩大感受野。
    """
    
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 2, n_blocks: int = 4):
        super().__init__()
        self._blocks = nn.ModuleList()
        
        current_channels = in_channels
        for i in range(n_blocks):
            dilation = 2 ** i
            self._blocks.append(DenseBlock(current_channels, kernel_size, dilation))
            current_channels += out_channels  # 密集连接增加通道数
            
        # 投影回目标通道数
        self._projection = nn.Conv1d(current_channels, out_channels, kernel_size=1)
        
    def forward(self, x: Tensor) -> Tensor:
        for block in self._blocks:
            x = block(x)
        return self._projection(x)


class AttentionBlock(nn.Module):
    """
    SNAIL 的因果注意力块:内容注意力 + 因果掩码。
    """
    
    def __init__(self, channels: int, n_heads: int = 4, key_size: int = 16, value_size: int = 128):
        super().__init__()
        self._channels = channels
        self._n_heads = n_heads
        self._key_size = key_size
        self._value_size = value_size
        
        # 线性投影
        self._query_proj = nn.Linear(channels, n_heads * key_size)
        self._key_proj = nn.Linear(channels, n_heads * key_size)
        self._value_proj = nn.Linear(channels, n_heads * value_size)
        
        self._output_proj = nn.Linear(n_heads * value_size, channels)
        
    def forward(self, x: Tensor) -> Tensor:
        """
        x: (batch, seq_len, channels)
        """
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q, K, V
        queries = self._query_proj(x).view(batch_size, seq_len, self._n_heads, self._key_size)
        keys = self._key_proj(x).view(batch_size, seq_len, self._n_heads, self._key_size)
        values = self._value_proj(x).view(batch_size, seq_len, self._n_heads, self._value_size)
        
        # 转置为 (batch, n_heads, seq_len, dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        
        # 注意力分数
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / np.sqrt(self._key_size)
        
        # 因果掩码:上三角矩阵(不包括对角线)设为 -inf
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        # Softmax 与加权
        attn_weights = F.softmax(scores, dim=-1)
        attended = torch.matmul(attn_weights, values)
        
        # 重塑并投影
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self._output_proj(attended)
        
        # 残差连接 + 拼接(SNAIL 风格)
        return torch.cat([x, output], dim=-1)


class SNAILAgent(nn.Module):
    """
    SNAIL (Simple Neural Attentive Meta-Learner) 实现。
    交错时间卷积块(TCBlock)与注意力块(AttentionBlock)。
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int = 256,
        n_tc_blocks: int = 2,
        n_attn_blocks: int = 2
    ):
        super().__init__()
        
        # 输入嵌入
        self._input_embed = nn.Linear(obs_dim + action_dim + 1, hidden_dim)
        
        # 交错 TC 块与注意力块
        self._tc_blocks = nn.ModuleList()
        self._attn_blocks = nn.ModuleList()
        
        current_dim = hidden_dim
        for i in range(max(n_tc_blocks, n_attn_blocks)):
            if i < n_tc_blocks:
                tc_block = TCBlock(current_dim, hidden_dim)
                self._tc_blocks.append(tc_block)
                current_dim = hidden_dim  # TCBlock 投影回 hidden_dim
                
            if i < n_attn_blocks:
                attn_block = AttentionBlock(current_dim)
                self._attn_blocks.append(attn_block)
                current_dim = current_dim * 2  # AttentionBlock 拼接输出
                
        # 输出头
        self._policy_head = nn.Linear(current_dim, action_dim)
        self._value_head = nn.Linear(current_dim, 1)
        
    def forward(
        self,
        trajectory: Tensor  # (batch, seq_len, obs_dim + action_dim + 1)
    ) -> Tuple[Tensor, Tensor]:
        """
        处理完整轨迹,输出每一步的策略 logits 与价值。
        """
        # 嵌入
        x = self._input_embed(trajectory)  # (batch, seq_len, hidden_dim)
        x = x.transpose(1, 2)  # (batch, hidden_dim, seq_len) 适配 Conv1d
        
        # 交错处理
        for tc_block, attn_block in zip(self._tc_blocks, self._attn_blocks):
            # 时间卷积
            x = tc_block(x)
            # 转回序列维度进行注意力
            x_t = x.transpose(1, 2)  # (batch, seq_len, channels)
            x_t = attn_block(x_t)
            x = x_t.transpose(1, 2)  # 转回通道维度
            
        # 转回序列维度
        x = x.transpose(1, 2)  # (batch, seq_len, channels)
        
        # 输出
        logits = self._policy_head(x)
        values = self._value_head(x).squeeze(-1)
        
        return logits, values


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="RL2", choices=["RL2", "SNAIL"])
    parser.add_argument("--seq_len", type=int, default=100)
    args = parser.parse_args()
    
    obs_dim, action_dim = 10, 4
    
    if args.model == "RL2":
        agent = RL2Agent(obs_dim, action_dim, hidden_dim=128)
        
        # 模拟 episode
        obs = torch.randn(2, obs_dim)  # batch=2
        prev_action = torch.zeros(2, dtype=torch.long)  # 初始动作
        prev_reward = torch.zeros(2, 1)
        
        # 新任务:重置隐藏状态
        agent.reset_hidden_state(batch_size=2, device="cpu")
        
        logits, value, _ = agent(obs, prev_action, prev_reward, reset=True)
        print(f"[RL²] 策略 logits 形状: {logits.shape}, 价值: {value.shape}")
        
        # 模拟多步交互
        for step in range(3):
            action = torch.randint(0, action_dim, (2,))
            reward = torch.randn(2, 1)
            next_obs = torch.randn(2, obs_dim)
            logits, value, _ = agent(next_obs, action, reward)
            print(f"  Step {step+1}: 选择动作 {action.tolist()}")
            
    else:  # SNAIL
        agent = SNAILAgent(obs_dim, action_dim, hidden_dim=64)
        
        # 模拟轨迹 (batch, seq, features)
        batch_size = 4
        trajectory = torch.randn(batch_size, args.seq_len, obs_dim + action_dim + 1)
        
        logits, values = agent(trajectory)
        print(f"[SNAIL] 输出 logits 形状: {logits.shape}")
        print(f"        因果掩码验证: 第 10 步只能看到前 10 步信息")

7.2.2 上下文推断与 PEARL

PEARL(概率嵌入 Actor-Critic 强化学习)通过概率上下文变量实现任务推断与策略执行的解耦。上下文编码器聚合历史经验推断任务的后验分布,策略网络以采样的任务嵌入为条件执行动作。这种架构天然支持离线元强化学习,上下文编码器可利用离线数据训练,而策略通过 off-policy 算法优化。FOCAL 算法进一步在离线场景中训练上下文编码器,通过对比学习或距离度量学习确保任务嵌入的判别性,避免在线探索需求。

脚本 7.2.2:PEARL 与 FOCAL 实现

Python

#!/usr/bin/env python3
"""
脚本内容:PEARL 与 FOCAL(上下文推断与离线元学习)实现
使用方式:python section_7_2_2_pearl_focal.py --algorithm PEARL --context_size 10
功能说明:
  1. PEARL:概率上下文编码器推断任务后验,策略条件于任务嵌入 z
  2. 上下文聚合:使用推断网络 q(z|c) 聚合上下文批次
  3. FOCAL:离线元强化学习,对比学习任务嵌入
  4. 支持 off-policy 训练与结构化探索
"""

import argparse
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ContextEncoder(nn.Module):
    """
    PEARL 的上下文编码器:将上下文批次 (s, a, r, s') 编码为任务嵌入 z 的分布。
    使用置换不变性架构(mean pooling)处理无序上下文。
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        latent_dim: int = 64,
        hidden_dim: int = 256,
        n_layers: int = 3
    ):
        super().__init__()
        self._latent_dim = latent_dim
        
        # 上下文转换:将 (s, a, r, s') 嵌入为特征向量
        input_dim = obs_dim + action_dim + 1 + obs_dim  # s, a, r, s'
        layers = []
        for i in range(n_layers):
            layers.append(nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        self._encoder_net = nn.Sequential(*layers)
        
        # 输出潜在变量的均值与对数方差
        self._mean_layer = nn.Linear(hidden_dim, latent_dim)
        self._logvar_layer = nn.Linear(hidden_dim, latent_dim)
        
    def forward(self, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        输入上下文 (batch, context_size, features)。
        使用 mean pooling 实现置换不变性。
        """
        # 编码每个上下文转换
        batch_size, context_size, _ = context.shape
        context_flat = context.view(batch_size * context_size, -1)
        features = self._encoder_net(context_flat)
        
        # 重塑并 mean pool
        features = features.view(batch_size, context_size, -1)
        aggregated = features.mean(dim=1)  # 置换不变聚合
        
        # 输出分布参数
        mean = self._mean_layer(aggregated)
        logvar = self._logvar_layer(aggregated)
        logvar = torch.clamp(logvar, -20, 2)
        
        # 重参数化采样
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
        
        return z, mean, logvar


class PEARLAgent(nn.Module):
    """
    PEARL (Probabilistic Embeddings for Actor-Critic RL) 实现。
    结合上下文编码器与 off-policy Actor-Critic(SAC 风格)。
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        latent_dim: int = 64,
        hidden_dim: int = 256,
        continuous: bool = True
    ):
        super().__init__()
        self._obs_dim = obs_dim
        self._action_dim = action_dim
        self._latent_dim = latent_dim
        self._continuous = continuous
        
        # 上下文编码器
        self._context_encoder = ContextEncoder(obs_dim, action_dim, latent_dim)
        
        # 策略网络(以 z 为条件)
        policy_input_dim = obs_dim + latent_dim
        if continuous:
            self._policy_mean = nn.Sequential(
                nn.Linear(policy_input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, action_dim)
            )
            self._policy_logstd = nn.Linear(policy_input_dim, action_dim)
        else:
            self._policy_logits = nn.Sequential(
                nn.Linear(policy_input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, action_dim)
            )
            
        # Q 网络(双 Q 网络,SAC 风格)
        q_input_dim = obs_dim + action_dim + latent_dim
        self._q1_net = nn.Sequential(
            nn.Linear(q_input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self._q2_net = nn.Sequential(
            nn.Linear(q_input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # 目标 Q 网络
        self._q1_target = copy.deepcopy(self._q1_net)
        self._q2_target = copy.deepcopy(self._q2_net)
        
    def infer_task(self, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """从上下文推断任务嵌入"""
        return self._context_encoder(context)
    
    def get_action(
        self, 
        obs: Tensor, 
        z: Tensor,
        deterministic: bool = False
    ) -> Tuple[Tensor, Tensor]:
        """以任务嵌入 z 为条件采样动作"""
        obs_z = torch.cat([obs, z], dim=-1)
        
        if self._continuous:
            mean = self._policy_mean(obs_z)
            logstd = self._policy_logstd(obs_z)
            logstd = torch.clamp(logstd, -20, 2)
            std = torch.exp(logstd)
            
            if deterministic:
                action = torch.tanh(mean)
                log_prob = None
            else:
                dist = torch.distributions.Normal(mean, std)
                raw_action = dist.rsample()
                action = torch.tanh(raw_action)
                # 计算 tanh 正态的对数概率
                log_prob = dist.log_prob(raw_action).sum(dim=-1)
                log_prob -= torch.log(1 - action ** 2 + 1e-6).sum(dim=-1)
        else:
            logits = self._policy_logits(obs_z)
            if deterministic:
                action = logits.argmax(dim=-1)
                log_prob = None
            else:
                dist = torch.distributions.Categorical(logits=logits)
                action = dist.sample()
                log_prob = dist.log_prob(action)
                
        return action, log_prob
    
    def compute_q(
        self, 
        obs: Tensor, 
        action: Tensor, 
        z: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """计算双 Q 值"""
        obs_action_z = torch.cat([obs, action, z], dim=-1)
        q1 = self._q1_net(obs_action_z)
        q2 = self._q2_net(obs_action_z)
        return q1, q2
    
    def update_critic(
        self,
        obs: Tensor,
        action: Tensor,
        reward: Tensor,
        next_obs: Tensor,
        done: Tensor,
        context: Tensor,
        gamma: float = 0.99,
        tau: float = 0.005
    ) -> Dict[str, float]:
        """
        使用上下文批次更新 Q 网络(SAC 风格)。
        """
        # 推断任务嵌入
        z, mean, logvar = self.infer_task(context)
        
        # 当前 Q 值
        current_q1, current_q2 = self.compute_q(obs, action, z)
        
        # 目标 Q 值(使用 target 网络与策略)
        with torch.no_grad():
            next_action, next_log_prob = self.get_action(next_obs, z)
            target_q1, target_q2 = self.compute_q(next_obs, next_action, z)
            target_q = torch.min(target_q1, target_q2) - 0.2 * next_log_prob.unsqueeze(-1)
            target_q = reward.unsqueeze(-1) + gamma * (1 - done.unsqueeze(-1)) * target_q
            
        # Q 损失
        q1_loss = F.mse_loss(current_q1, target_q)
        q2_loss = F.mse_loss(current_q2, target_q)
        q_loss = q1_loss + q2_loss
        
        # KL 散度(正则化任务推断)
        kl_loss = -0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=-1).mean()
        
        total_loss = q_loss + 0.001 * kl_loss
        
        return {
            'q_loss': total_loss.item(),
            'q1': current_q1.mean().item(),
            'q2': current_q2.mean().item()
        }


class FOCAL(nn.Module):
    """
    FOCAL (Feature-Offline Context-based Actor-critic) 实现。
    在离线数据上训练上下文编码器,避免在线探索。
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        latent_dim: int = 64,
        temperature: float = 0.1
    ):
        super().__init__()
        self._encoder = ContextEncoder(obs_dim, action_dim, latent_dim)
        self._temperature = temperature
        
    def compute_contrastive_loss(
        self,
        context_dict: Dict[int, Tensor]
    ) -> Tensor:
        """
        对比学习任务嵌入:同一任务的上下文应接近,不同任务应远离。
        context_dict: {task_id: context_tensor}
        """
        # 提取所有任务的嵌入
        embeddings = []
        task_ids = []
        
        for task_id, context in context_dict.items():
            z, _, _ = self._encoder(context)
            embeddings.append(z)
            task_ids.extend([task_id] * z.size(0))
            
        embeddings = torch.cat(embeddings, dim=0)  # (total_batch, latent_dim)
        task_ids = torch.tensor(task_ids, device=embeddings.device)
        
        # 计算相似度矩阵
        similarity = torch.matmul(embeddings, embeddings.T) / self._temperature
        
        # 创建任务标签掩码(同一任务为正样本)
        labels = (task_ids.unsqueeze(0) == task_ids.unsqueeze(1)).float()
        
        # 对比损失(InfoNCE 风格)
        # 对每个样本,正样本应相似,负样本应不相似
        exp_sim = torch.exp(similarity)
        
        # 掩码对角线(排除自身)
        mask = torch.eye(len(embeddings), device=embeddings.device)
        exp_sim = exp_sim * (1 - mask)
        
        # 正样本与负样本的对比
        positive_sim = (exp_sim * labels).sum(dim=1)
        negative_sim = exp_sim.sum(dim=1)
        
        loss = -torch.log(positive_sim / (negative_sim + 1e-8) + 1e-8).mean()
        
        return loss
    
    def train_offline(
        self,
        offline_data: Dict[int, List[Tuple]],
        n_epochs: int = 100,
        lr: float = 1e-3
    ):
        """
        在离线数据上训练上下文编码器。
        offline_data: {task_id: [(s, a, r, s'), ...]}
        """
        optimizer = torch.optim.Adam(self._encoder.parameters(), lr=lr)
        
        for epoch in range(n_epochs):
            # 构建上下文批次
            context_dict = {}
            for task_id, transitions in offline_data.items():
                # 随机采样上下文
                n_context = min(10, len(transitions))
                sampled = np.random.choice(len(transitions), n_context, replace=False)
                
                contexts = []
                for idx in sampled:
                    s, a, r, s_next = transitions[idx]
                    context = np.concatenate([s, a, [r], s_next])
                    contexts.append(context)
                    
                context_batch = torch.tensor(np.array(contexts), dtype=torch.float32).unsqueeze(0)
                context_dict[task_id] = context_batch
            
            # 对比学习更新
            optimizer.zero_grad()
            loss = self.compute_contrastive_loss(context_dict)
            loss.backward()
            optimizer.step()
            
            if epoch % 10 == 0:
                print(f"[FOCAL] Epoch {epoch}, 对比损失: {loss.item():.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--algorithm", type=str, default="PEARL", choices=["PEARL", "FOCAL"])
    parser.add_argument("--context_size", type=int, default=10)
    args = parser.parse_args()
    
    obs_dim, action_dim = 10, 4
    
    if args.algorithm == "PEARL":
        agent = PEARLAgent(obs_dim, action_dim, latent_dim=8, continuous=True)
        
        # 模拟上下文 (batch, context_size, features)
        context = torch.randn(2, args.context_size, obs_dim * 2 + action_dim + 1)
        obs = torch.randn(2, obs_dim)
        
        z, mean, logvar = agent.infer_task(context)
        action, log_prob = agent.get_action(obs, z)
        
        print(f"[PEARL] 任务嵌入维度: {z.shape}")
        print(f"        推断不确定性 (logvar): {logvar.mean().item():.4f}")
        print(f"        采样动作: {action[0].tolist()}")
        
    else:  # FOCAL
        focal = FOCAL(obs_dim, action_dim, latent_dim=16)
        
        # 模拟离线数据 {task_id: transitions}
        offline_data = {
            0: [(np.random.randn(obs_dim), np.random.randn(action_dim), 
                np.random.randn(), np.random.randn(obs_dim)) for _ in range(50)],
            1: [(np.random.randn(obs_dim), np.random.randn(action_dim), 
                np.random.randn(), np.random.randn(obs_dim)) for _ in range(50)]
        }
        
        focal.train_offline(offline_data, n_epochs=50)

本章所呈现的元强化学习实现涵盖了从基于梯度的快速适应(MAML、Meta-SGD、Reptile)到基于记忆的任务推断(RL²、SNAIL、PEARL)的完整技术栈。所有实现均遵循 PyTorch 最佳实践,支持 GPU 加速与批量化处理,可直接应用于少样本强化学习与快速适应场景。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐