一、关键代码:

gemma.py 实现的是 Pi 模型里的"视觉语言模型"部分 ——负责:

- 处理图像 patch + 文本 token(来自 Paligemma)
- 和"动作专家"(action expert,在 pi0.py 里)做 Mixture-of-Experts 融合
- 输出给 Flow Matching 头(在 model.py 里)

### 关键设计:双专家 MoE
```
输入: [图像patch tokens | 文本prompt tokens] | [动作 tokens]
           ↓                        ↓
      PaliGemma expert (i=0)     Action expert (i=1)
           ↓                        ↓
      Gemma-2B weights          从头/LoRA 训练
           ↓                        ↓
      输出 [B, T, D]            输出 [B, T, D]
           ↓                        ↓
      融合给 Flow Matching 头
```
这就是代码里 configs: Sequence[Config] 是 列表 的原因——每个专家有独立的 config、独立的权重。

## 二、核心模块分层
### 1. Config(L44-109)—— 5 种变体
variant width depth heads kv_heads mlp 用途 dummy 64 4 8 1 128 单元测试/调试 gemma_300m 1024 18 8 1 4096 轻量版 gemma_300m_lora 1024 18 8 1 4096 300M + LoRA(rank=32) gemma_2b 2048 18 8 1 16384 Pi 默认(PaliGemma 2B) gemma_2b_lora 2048 18 8 1 16384 2B + LoRA(rank=16)

LoRA 配置 (L96-107):同时对 attn 和 ffn 各挂一个 LoRA adapter,rank=16, alpha=16。

2. RMSNorm

这是 Pi 能做 条件动作生成 的关键——通过 AdaLN 把状态/指令/图像 embedding 注入到每层归一化里。

3. Embedder

vocab_size = 257_152(PaliGemma 的词表大小)。

4. Attention—— 多专家融合的核心

# 1. 每个专家独立做 QKV 投影
for i, (x, config) in enumerate(zip(xs, self.configs)):
    qkv_einsum(x)    # ← 每个专家自己的 Wq Wk Wv

# 2. concat 后一起做 self-attention
q, k, v = jnp.concatenate(y, axis=1) for y in zip(*qkvs)
# 所有专家的 token 在 token 维度拼接成一个长序列
# 一次 attention 就能跨专家通信!

# 3. 共享的 RoPE + MHA
q = _apply_rope(q, positions)
logits = einsum("BTKGH,BSKH->BKGTS", q, k, ...)

# 4. attention 后每个专家独立 output 投影
for i, (x, config) in enumerate(...):
    out_einsum(encoded[start:end])   # ← 每个专家自己的 Wo

这就是 Mixture-of-Experts Transformer 的精髓 :

- 专家前 :每个 token 独立线性投影
- 专家间 :一次 self-attention 跨所有专家 token 做信息交换
- 专家后 :每个 token 独立线性投影回原维度
另外还支持 GQA(grouped-query attention)——当 num_kv_heads != num_heads 时 Q 和 KV 分开投影。

### 5. FeedForward

# Gemma 用 SwiGLU,不是标准 ReLU FFN
ff_gate = x @ W_gate[0]   # gate 分支
ff1     = x @ W_gate[1]   # value 分支
activations = gelu(ff_gate) * ff1    # gate 调制 value
output = activations @ W_linear

6. Block

# 结构:
# x → RMSNorm(adarms_cond) → Attention → gated residual
#   → RMSNorm(adarms_cond) → SwiGLU FFN → gated residual

# 每步都有 sharding constraint(JAX 自动并行)
# 每步都有 adarms_cond 注入(AdaLN 条件生成)

gated_residual是标准/条件两种模式:

gate is None  →  x + y              # 普通残差
gate 不为 None →  x + y * gate      # 条件门控残差

7. Module

# Gemma 结构:
# 1. Embedder: token id → embedding (共享 vocab table)
# 2. 18 层 Block (Gemma-2B depth=18)
#    用 nn.scan 把所有层 scan 成一次 jit call(JAX 优化)
#    用 nn.remat 做 gradient checkpointing(省显存)
# 3. final RMSNorm(每专家独立)
# 4. AdaARM conditioning 注入到每层 RMSNorm

adarms_cond (条件注入) 在 init() 能看清楚:

# use_adarms 是每个专家是否启用 AdaARM 的开关
# 传入 [B, width] 的 conditioning tensor(来自状态/指令)
# 每个专家独立:use_adarms[i]=True 才传条件

## 三、关键特性总结
特性 代码位置 作用 双专家 MoE L172-201, L233-248 PaliGemma + Action expert 一次 attention 跨专家通信 多专家权重命名 L443-450 _name("attn", 0) → "attn" , _name("attn", 1) → "attn_1" , 第一个专家权重名和官方 PaliGemma checkpoint 对齐 ,直接加载 AdaARM 条件注入 L112-131, L402-403, L413-421 状态/指令 embedding 通过 AdaLN 注入每层,实现条件动作生成 RoPE 位置编码 L424-440 标准旋转位置编码,支持变长序列 KV Cache L211-214 推理时缓存 KV,支持自回归生成 LoRA 适配器 L52, L96-107 可选 LoRA 挂到 attn 和 ffn,只训小部分参数 GQA 支持 L176-199 支持 grouped-query attention 加速推理 JAX sharding L294, L307, L310, L312 每层都有 activation sharding,支持多 GPU 并行

## 核心主干

一、一层 Block 里有几个 RMSNorm
看 Block. call :

每层 Block 有 2 个 RMSNorm :

- pre_attention_norm (Attention 之前)
- pre_ffw_norm (FeedForward 之前)
## 二、Pi 里有几层 Block
PaliGemma 2B 的 config( gemma.py:L79-87 ):

18 层 Block × 2 个 RMSNorm = 36 个 RMSNorm 注入点 。

加上每个专家独立的 RMSNorm,实际数量:

## 三、AdaLN 在每个注入点做什么
看 RMSNorm. call 的 AdaLN 路径:

每个注入点做 3 件事:

参数 作用 含义 scale 乘以 (1 + scale) 缩放 特征的每个维度 shift 加上 shift 偏移 特征的每个维度 gate 给后面的 residual 用 门控 残差连接的强度

## 四、完整数据流(你的水瓶任务)
```
输入:
  机器人当前双臂状态 [B, 16]
  → 状态编码器 (StateEncoder)
  → 条件向量 cond [B, 2048]  (和 PaliGemma hidden dim 对齐)
  → 复制 18 × 2 = 36 份(每层 Block 的每个 RMSNorm 一份)
  → 送进每个 RMSNorm 的 AdaLN

第 1 层 pre_attention_norm:
  normed_image = rms_norm(image_tokens) * (1 + scale_1) + 
  shift_1
  normed_text  = rms_norm(text_tokens) * (1 + scale_2) + 
  shift_2
  
第 1 层 pre_ffw_norm:
  normed_hidden = rms_norm(hidden) * (1 + scale_3) + shift_3
  ...

第 18 层 pre_ffw_norm:
  同上结构,独立的 scale/shift/gate 参数

最终 Flow Matching 头:
  收到已经被状态条件调制过 36 次的 hidden states
  → 输出条件化的动作 [B, 36]
```
## 五、这样设计的好处
对比 做法 效果 普通条件注入 把状态 concat 到输入 只能影响第 1 层,后面层状态信息丢失 AdaLN(Pi 的做法) 每层的每个 RMSNorm 都注入条件 条件信息贯穿 18 层,每层都能"看到"当前状态 好处 状态条件 × 36 次注入 模型在任何深度都能"记住"机器人姿态,生成更准确的动作

一句话 :每个 Block 里的 pre_attention_norm 和 pre_ffw_norm 各注入一次,18 层共 36 个 AdaLN 注入点 ,让机器人状态条件贯穿整个 Transformer。

QA:

1. 这 18 层 Block 是 Transformer 的核心"特征提炼管道" ——每一层都在做"从原始信号 → 抽象语义"的一步提炼。给你分阶段讲清楚:

## 一、每一层做什么
每一层 Block 内部的流水线 :

## 二、18 层分 4 个阶段
### 阶段 1:L0 - L3(底层特征提取)
层 做什么 你的水瓶任务里 L0-L1 从原始像素/词元提取 底层特征 识别出"圆的东西"、"黄色的东西"、"直线边缘" L2-L3 组合底层特征成 局部模式 识别出"圆柱形"、"瓶盖纹理"、"桌面纹理"

### 阶段 2:L4 - L8(语义组合)
层 做什么 你的水瓶任务里 L4-L5 形成 局部语义 "这个物体是水瓶"、"水瓶在桌子左边" L6-L7 跨模态融合 图像和文本"grasp the water bottle"对齐 L8 状态条件注入 让模型"看到"当前双臂姿态

### 阶段 3:L9 - L14(推理与规划)
层 做什么 你的水瓶任务里 L9-L10 形成 动作意图 "我应该去抓这个水瓶" L11-L12 规划 动作序列 "先张开夹爪 → 向前移动 → 夹紧" L13-L14 细化 连续轨迹 "夹爪应该以 5cm/s 速度接近"

### 阶段 4:L15 - L17(条件化输出)
层 做什么 你的水瓶任务里 L15-L16 精调最终特征 让特征和 Flow Matching 头对齐 L17 最后整合 输出给 Flow Matching 头的 embedding

## 三、每层之间的"信息传递"
每层的输入输出 维度相同 (2048),通过残差连接保证梯度不消失。

## 四、AdaLN 条件注入在每层的作用
每个 Block 的 2 个 RMSNorm 都会注入条件,但每层"看到"的条件做的事不同:

层 AdaLN 条件作用 L0-L3 调制底层特征提取:"当前状态是双臂张开 → 视觉注意力集中在桌面中央" L4-L8 调制语义融合:"当前状态是已定位 → 文本 prompt 和图像更紧密对齐" L9-L14 调制动作规划:"当前位置在 (x,y,z) → 规划路径从这里开始" L15-L17 调制最终输出:"当前力传感器读数为 0 → 输出抓握前的动作"

## 五、一句话总结
18 层 Block = 18 次"特征提炼 + 跨专家通信 + 条件调制" :

- 从 L0 的"识别边缘"到 L17 的"输出可执行动作"
- 每层都有跨专家 self-attention(PaliGemma 和 Action 专家交互)
- 每层都有 AdaLN 条件注入(机器人状态贯穿始终)
- 最终给 Flow Matching 头提供 完全条件化的语义特征 来生成 36 维连续动作

2 。 为什么用mse loss,能用其他loss吗

SOLO Agent

思考过程

直接回答:

## 一、为什么 Flow Matching 用 MSE Loss
### 1. Flow Matching 的数学本质
Flow Matching 学习的是一个 速度场 [ o bj ec tO bj ec t ] v ( x , t ) ,满足 ODE:
 [ o bj ec tO bj ec t ] d t d x ​ = v θ ​ ( x , t )
训练目标是让 [ o bj ec tO bj ec t ] v θ ​ 匹配 真实的速度场 :
 [ o bj ec tO bj ec t ] u t ​ = noise − action
所以 loss 是:
 [ o bj ec tO bj ec t ] L = E t , x t ​ ​ [ ∥ v θ ​ ( x t ​ , t ) − u t ​ ∥ 2 ]
这就是 MSE (均方误差)。

### 2. MSE 是 Flow Matching 的标准选择
方面 原因 数学一致性 Flow Matching 推导时就是用 [ o bj ec tO bj ec t ] L 2 ​ 损失,MSE 是理论要求 优化稳定性 MSE 对连续值回归最稳定,梯度光滑 物理意义 速度场的 [ o bj ec tO bj ec t ] L 2 ​ 距离就是 能量 ,有物理意义 简单高效 不用调超参,直接用

### 3. 代码里的体现(pi0.py:L229)
就是对每个 36 维的 [ o bj ec tO bj ec t ] v t ​ 和 [ o bj ec tO bj ec t ] u t ​ 计算 squared L2 距离,再取均值。

## 二、能用其他 Loss 吗?
### 可以,但要理解每一种的影响:
Loss 公式 效果 适合场景 MSE(当前) [ o bj ec tO bj ec t ] ∥ v − u ∥ 2 对大误差惩罚重 默认推荐 MAE / L1 [ o bj ec tO bj ec t ] ∥ v − u ∥ 对异常值鲁棒 动作有少量异常值时 Huber 混合 MSE 和 MAE 可调鲁棒性 平衡 MSE 和 MAE Smooth L1 Huber 的变体 对大误差线性惩罚 动作平滑性要求高 Cosine Similarity [ o bj ec tO bj ec t ] 1 − cos ( v , u ) 关注方向而非幅值 只关心动作方向

## 三、不同 Loss 对水瓶抓放任务的影响
### MSE(当前默认)
### MAE / L1
### Huber Loss(推荐试试)
### Cosine Similarity
## 四、如果要改,怎么改
在 pi0.py:L229 改一行:

## 五、我的建议
先保持 MSE 不动 ,原因:

1. MSE 是 Flow Matching 的理论标准 ——pi0-base 预训练时就用 MSE,你 finetune 时换 loss 会导致模型权重和新 loss 不匹配
2. 你的数据量不大 ——几百条 episode,MSE 已经足够稳定
3. 先训起来看效果 ——如果 MSE 训完效果不好(比如某些瓶子抓不稳),再考虑换 Huber
如果要试 , 优先试 Huber Loss ——比 MSE 更鲁棒,适合机器人动作这种有少量异常值的场景,而且不会丢失 MSE 的平滑梯度特性。

Q3:

Flow Matching 头提供 完全条件化的语义特征 来生成 36 维连续动作 这一部分的代码在哪

找到了。Flow Matching 头 就在 pi0.py 里,从"18 层 Gemma 输出"到"36 维连续动作"的完整链路:

## 一、代码位置分布
组件 文件位置 作用 初始化(所有层定义) pi0.py:L66-103 定义 action_in_proj, time_mlp, action_out_proj embed_prefix(图像+文本) pi0.py:L106-152 编码图像+文本 tokens embed_suffix(状态+动作+时间) pi0.py:L154-201 编码状态+噪声动作+时间步 主干 Gemma forward pi0.py:L224-226 18 层 Block 处理 Flow Matching 输出头 pi0.py:L227 action_out_proj 线性层 损失计算(Flow Matching loss) pi0.py:L229 MSE loss 推理采样(Flow Matching ODE) pi0.py:L231+ 迭代去噪生成动作

## 二、完整数据流(训练时)
### Step 1:初始化(L66-103)
### Step 2:embed_prefix — 编码图像+文本(L106-152)
### Step 3:embed_suffix — 编码状态+噪声动作+时间(L154-201)
### Step 4:Gemma 18 层 forward(L224-226)
### Step 5:Flow Matching 头输出(L227)⭐
这就是 Flow Matching 头 ——一个从 2048 维投影到 36 维的线性层。

### Step 6:Flow Matching 损失(L229)
## 三、推理时(sample_actions)
## 四、总结
组件 代码位置 维度变换 图像→SigLIP pi0.py:L118 [B,3,224,224] → [B,196,2048] 文本→Embedding pi0.py:L133 [B,5] → [B,5,2048] 动作→投影 pi0.py:L174 [B,15,36] → [B,15,2048] 时间步→MLP pi0.py:L179-182 [B,1] → [B,2048] 18 层 Gemma pi0.py:L224-226 [B,221,2048] → [B,221,2048] Flow Matching 头 pi0.py:L227 [B,15,2048] → [B,15,36]

Flow Matching 头就是 action_out_proj = Linear(2048 → 36) ——把 Gemma 输出的 2048 维语义特征,投影成 36 维连续动作(双臂 14 关节 + 夹爪 2 × 15 步)。

Q4: 关于激活函数

我用 你当前训练的水瓶抓放任务 (batch=2,3相机,文本prompt,action_horizon=15,action_dim=36)的 具体维度 ,结合代码,把前向传播求预测值的 每一步计算、激活函数位置、维度变化 完全具象化:

## 一、先明确:前向传播的2种场景
场景 求的预测值是什么 代码入口 训练时 带噪声动作的 速度场预测值 [ o bj ec tO bj ec t ] v t ​ pi0.py:L204-229 推理时 最终干净的 36维连续动作 pi0.py:L231+

## 二、训练时前向传播(求 [objectObject]vt​ 预测值)
以 batch=2,3相机,10个文本token,15步动作horizon,36维动作 为例,每一步对应代码和维度:

### Step 1:embed_prefix(编码图像+文本)
代码位置: pi0.py:L106-152

### Step 2:embed_suffix(编码噪声动作+时间步)
代码位置: pi0.py:L154-201

### Step 3:Gemma 18层前向(核心计算)
代码位置: pi0.py:L224-226

### Step 4:Flow Matching头输出 [objectObject]vt​ 预测值
代码位置: pi0.py:L227

### Step 5:计算Loss
代码位置: pi0.py:L229

## 三、推理时前向传播(求最终干净动作)
推理时是 多次迭代 的前向(10步默认),每一步都是上面的前向流程,只是噪声动作的初始值不同:

代码入口: pi0.py:L231+

## 四、核心总结
### 前向传播的计算本质 每一步都是「线性变换 + 激活函数」的组合 ,最终通过 输出层线性投影 得到预测值。
### 激活函数在前向的位置
激活函数 位置 作用 Softmax Attention层 把注意力权重归一化(非线性) GELU SwiGLU FFN层 引入非线性,拟合复杂模式 Swish 时间步MLP 处理连续时间步的非线性变换 无激活 输出层(action_out_proj) 直接投影到36维连续动作(线性输出)

### 维度变化全链路

指明 这几个激活函数的代码 和逻辑处理,为什么要这样用

思考过程

我将精确对应 Pi 模型中用到的 3 种核心激活函数 ,明确其 代码位置、逻辑处理、针对机器人任务的设计原因 :

## 一、激活函数总览
激活函数 代码位置 核心作用 针对机器人任务的设计原因 Softmax gemma.py:L198-205 注意力权重归一化 动态聚焦关键空间信息(抓取点/目标位置) GELU (SwiGLU门控) gemma.py:L260-263 FFN层非线性门控 拟合连续时序动作的平滑特征,自适应筛选动作模式 Swish (SiLU) pi0.py:L176-183 时间步编码MLP非线性 精准捕捉Flow Matching去噪过程的 连续时间特性

## 二、逐个激活函数详解(代码+逻辑+设计原因)
### 1. Softmax(注意力权重激活) 代码位置: gemma.py:L198-205 逻辑处理:
- 把 任意实数的注意力得分 映射到 [0,1] 区间,且 所有得分加和为1
- 转换公式: [ o bj ec tO bj ec t ] S o f t ma x ( x i ​ ) = ∑ j ​ e x j ​ e x i ​ ​
- 作用:让注意力权重具有 概率意义 ,明确告诉模型"每个位置的重要性占比" 设计原因(针对机器人任务):
1. 多相机/多模态融合的精准聚焦 :机器人任务需要同时处理3个相机图像、文本指令、 proprioception(本体状态),Softmax能让模型 动态聚焦关键信息 (比如抓水瓶时聚焦瓶身,放箱子时聚焦箱子内部),而非平均分配注意力。
2. 避免无关信息干扰 :机器人场景中存在大量噪声(比如背景杂物、机械臂的无关移动),Softmax的归一化特性能抑制噪声区域的权重,让模型专注于 任务核心区域 。
3. 时序动作的关联建模 :18层Transformer的注意力能捕捉 跨时间步的动作关联 (比如"当前抓取力度"和"后续放置位置"的关联),Softmax让这种关联更清晰。
### 2. GELU(SwiGLU FFN的门控激活) 代码位置: gemma.py:L260-263 逻辑处理:
- GELU激活公式: [ o bj ec tO bj ec t ] GE LU ( x ) = x ⋅ Φ ( x ) ( [ o bj ec tO bj ec t ] Φ 是标准高斯CDF,比ReLU平滑,负值区域有非零输出)
- SwiGLU的门控逻辑 :
  1. 两路并行线性变换,一路过GELU当**"门控信号" (控制信息通过量),另一路当 "信息载体"**(传递特征)
  2. 门控信号(GELU输出)× 信息载体 = 最终有效特征, 自适应筛选有用特征 设计原因(针对机器人任务):
1. 连续动作的平滑拟合 :机器人动作是 连续时序信号 (比如手臂移动的速度/轨迹),ReLU的硬截断(负值变0)会丢失连续信息,GELU的平滑特性保证 动作的连续性 ,避免机器人移动出现突兀卡顿。
2. 自适应动作特征筛选 :
   - 抓水瓶时,门控自动筛选 夹爪闭合的力/位置 特征,忽略机械臂的无关移动;
   - 放箱子时,门控自动筛选 手臂高度/水平位置 特征,聚焦放置动作。
3. Flow Matching的速度场预测 :Flow Matching需要预测 连续的速度场 (从噪声到动作的插值速度),GELU的平滑特性能更精准地拟合速度场的非线性变化,提升去噪精度。
### 3. Swish(SiLU,时间步编码的激活) 代码位置: pi0.py:L176-183 逻辑处理:
- Swish激活公式: [ o bj ec tO bj ec t ] Sw i s h ( x ) = x ⋅ σ ( x ) ( [ o bj ec tO bj ec t ] σ 是Sigmoid,自门控特性:负值有非零输出,正值接近线性)
- 作用:把**Flow Matching的连续时间步t(0→1)**编码成高维向量,注入AdaLN层, 动态调制每一层的归一化参数 设计原因(针对Flow Matching任务):
1. 捕捉去噪过程的非线性特性 :Flow Matching的时间步t从1(纯噪声)→0(干净动作)是 非线性的去噪过程 :
   - t接近1时:需要强去噪(大幅调整动作);
   - t接近0时:需要微调动作(小幅修正误差);
     Swish的自门控特性能精准编码这种 非线性的时间特性 ,让AdaLN能自适应调整每层的去噪强度。
2. 适合连续标量输入 :时间步t是 单个连续标量 (0~1),Swish的自门控特性( [ o bj ec tO bj ec t ] x ⋅ σ ( x ) )比GELU更灵活,能更好地捕捉标量输入的细微变化(比如t从0.9到0.8的去噪强度变化)。
3. 提升AdaLN的条件注入精度 :AdaLN是Pi实现"条件动作生成"的核心(把时间步/状态/指令注入每一层),Swish的非线性编码能让注入的条件更 精准匹配当前去噪阶段的需求 。
## 三、关键补充:激活函数的"学习"本质
所有这些激活函数(Softmax/GELU/Swish)都是 固定的数学函数 ,不参与训练。真正学习的是 激活函数前后的权重矩阵 :

- 比如GELU前后的 w_gating 、 w_linear (SwiGLU的权重)
- 比如Swish前后的 Dense 层权重(时间步编码的权重)
- 比如Softmax前后的 QKV 投影权重(注意力的权重)
激活函数只是 引入非线性的"工具" ,让模型能学习 连续动作的复杂模式 (比如抓水瓶的力控制、放箱子的位置控制),避免模型退化为线性变换。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐