「Transformer核心必读」前馈全连接网络(FFN)深度解析:从数学原理、Dropout架构辨析到PyTorch实现
1、基本介绍
一、名称解析:“前馈全连接层”到底是什么?
1.1 拆解名字:三个关键词
| 词 | 含义 |
|---|---|
| 前馈(Feed-Forward) | 指信号单向流动,从输入到输出,无反馈回路(区别于 RNN、LSTM 等循环网络)。这是神经网络中最基本的信息传递方式。 |
| 全连接(Fully Connected) | 指一层中每个神经元都与上一层所有神经元相连,即使用 nn.Linear 实现的线性变换。 |
| 层(Layer) | 表示它是网络中的一个计算单元。 |
✅ 所以,“前馈全连接层” = 一个采用全连接结构、信息单向流动的神经网络层。
但注意:在 Transformer 语境下,“前馈全连接层”通常指的不是一个单层,而是一个小型的两层 MLP(多层感知机)!
1.2 它和“全连接层”是什么关系?
- “全连接层”(Fully Connected Layer) 是一个基础构件,等价于
nn.Linear,执行 y = W x + b y = Wx + b y=Wx+b。 - “前馈全连接层”(FFN) 在 Transformer 中是一个复合模块,由两个全连接层 + 一个非线性激活函数组成。
✅ 关系总结:
FFN 是由多个全连接层构成的前馈子网络。
“全连接层”是砖块,“前馈全连接层(FFN)”是一面墙。
1.3 命名与缩写:到底叫什么?简写是啥?
| 名称 | 是否常用 | 说明 |
|---|---|---|
| Feed-Forward Network (FFN) | ✅ 最标准、最广泛使用 | 论文、代码库(如 HuggingFace Transformers)、教材通用 |
| Position-wise Feed-Forward Network | ✅ 学术论文常用 | 强调“对每个位置独立应用”(见下文) |
| FF layer / FF block | ✅ 工程口语 | “FF” 即 Feed-Forward |
| MLP (Multi-Layer Perceptron) | ⚠️ 可用但不够精确 | MLP 是更广义的概念,FFN 是一种特定结构的 MLP |
| 全连接层 | ❌ 不准确(易混淆) | 会让人误以为是单个 Linear 层 |
📌 结论:
- 正式名称:Position-wise Feed-Forward Network (位置前馈网络)
- 通用简称:FFN
- 不要简称为“全连接层”,以免与单个
Linear混淆!✅ “前馈全连接层”在 Transformer 架构中的标准、正式名称就是 “Position-wise Feed-Forward Network”(位置前馈网络)。
为什么强调 “Position-wise”(位置-wise)?
这是为了精确描述其计算方式,与多头注意力形成对比:
模块 是否跨位置交互 计算方式 Multi-Head Attention ✅ 是 对序列中所有 token 联合计算(token 之间相互 attend) Feed-Forward Network ❌ 否 对每个位置(position)独立、相同地应用同一个全连接网络 因此:
- 它不是“全局”的一个大网络处理整个序列;
- 而是在每个 token 的位置上,单独跑一遍相同的 FFN;
- 所有位置共享同一套参数(即
W1,W2是全局的),但计算彼此独立。📌 这种“参数共享 + 位置独立计算”的模式,就叫做 position-wise。
权威出处
原始论文《Attention Is All You Need》(2017):
“Each of the layers in our encoder and decoder contains a fully connected feed-forward network, applied to each position separately and identically.”
(我们编码器和解码器中的每一层都包含一个全连接前馈网络,对每个位置分别且相同地应用。)后续所有 Transformer 变体(BERT、GPT、T5 等) 都沿用这一术语。
HuggingFace Transformers 文档 中也明确使用
position-wise feed-forward描述该组件。
常见简称
全称 简称 使用场景 Position-wise Feed-Forward Network FFN 最通用(论文、代码、讨论) Position-wise Feed-Forward Network Position-wise FFN 强调“位置独立”特性时 Feed-Forward Network FF 或 FF layer 口语化、工程中 ⚠️ 注意:虽然常简称为 “FFN”,但不能省略 “position-wise” 的含义——这是理解其作用的关键。
总结
- ✅ “前馈全连接层” = “Position-wise Feed-Forward Network”
- ✅ “Position-wise” 不是可有可无的修饰词,而是定义性特征
- ✅ 所有主流文献和框架都采用这一命名
所以:
“Transformer 中的前馈层,正式名称是 Position-wise Feed-Forward Network(FFN)。”
二、FFN 的数学定义与标准结构
2.1 标准公式(来自《Attention Is All You Need》)
Transformer 论文中定义的 FFN 如下:
FFN ( x ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2
其中:
- x ∈ R d model x \in \mathbb{R}^{d_{\text{model}}} x∈Rdmodel:输入向量(如多头注意力的输出)
- W 1 ∈ R d model × d f f W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}} W1∈Rdmodel×dff, b 1 ∈ R d f f b_1 \in \mathbb{R}^{d_{ff}} b1∈Rdff
- W 2 ∈ R d f f × d model W_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}} W2∈Rdff×dmodel, b 2 ∈ R d model b_2 \in \mathbb{R}^{d_{\text{model}}} b2∈Rdmodel
- d f f d_{ff} dff:中间隐藏层维度,通常 d f f = 2048 d_{ff} = 2048 dff=2048(当 d model = 512 d_{\text{model}} = 512 dmodel=512 时),即 4 倍膨胀
🔍 注意:原始论文使用 ReLU(即 max ( 0 , ⋅ ) \max(0, \cdot) max(0,⋅)),但后续模型(如 BERT)普遍改用 GELU,现代大模型(如 LLaMA)则采用 SwiGLU 等门控机制。
2.2 结构图解
输入 x (d_model)
│
▼
[ Linear(d_model → d_ff) ] ← 第一个全连接层(升维)
│
▼
[ ReLU / GELU ] ← 非线性激活
│
▼
[ Linear(d_ff → d_model) ] ← 第二个全连接层(降维回原维度)
│
▼
输出 y (d_model)
✅ 关键特性:
- 输入输出维度相同( d model d_{\text{model}} dmodel),便于与残差连接配合
- 中间层更高维( d f f > d model d_{ff} > d_{\text{model}} dff>dmodel),提供更强的表达能力
- 参数共享 across positions:同一 FFN 应用于序列中所有位置(但不同位置的计算彼此独立)
💡 补充说明:虽然 FFN 对每个位置独立计算,但所有位置共享同一套权重参数(即
fc1和fc2是全局共享的),这是“position-wise”的准确含义。
三、“Position-wise” 是什么意思?为什么强调这一点?
这是理解 FFN 在 Transformer 中作用的关键!
3.1 对比:多头注意力 vs FFN
| 模块 | 是否跨位置交互 | 作用 |
|---|---|---|
| 多头注意力 | ✅ 是 | 建模 token 之间的依赖关系(如 “The cat” → “cat” 关注 “The”) |
| FFN | ❌ 否 | 对每个 token 的表示进行独立的非线性变换 |
3.2 “Position-wise” 的含义
-
假设输入序列长度为 T T T,batch size 为 B B B
-
多头注意力输出:$ X \in \mathbb{R}^{B \times T \times d_{\text{model}}} $
-
FFN 对每个位置 t ∈ [ 1 , T ] t \in [1, T] t∈[1,T] 独立应用相同的函数:
Output [ : , t , : ] = FFN ( X [ : , t , : ] ) \text{Output}[:, t, :] = \text{FFN}(X[:, t, :]) Output[:,t,:]=FFN(X[:,t,:])
🧠 想象:有 T T T 个相同的“小专家”(FFN),每人负责处理一个 token 的向量,彼此不交流,但使用同一本操作手册(共享参数)。
✅ 所以 FFN 不建模序列结构,只增强每个 token 的表示能力。
四、FFN 到底有什么用?—— 三大核心作用
4.1 作用 1:引入非线性(Non-linearity)
- 注意力机制本质上是加权求和(尽管 softmax 是非线性的,但整体仍是输入的凸组合)
- 如果没有 FFN,整个 Transformer 就退化为一系列线性变换的组合,无法拟合复杂函数
- FFN 的激活函数(ReLU/GELU/SwiGLU)提供了关键的非线性,使模型具备强大的表达能力
💡 类比:注意力决定“看哪里”,FFN 决定“怎么看懂看到的东西”。
4.2 作用 2:特征空间变换与增强
- 第一个 Linear 将 d model d_{\text{model}} dmodel 维向量映射到更高维空间(如 512 → 2048)
- 在高维空间中,特征更容易被线性分离(类似“核方法”的思想)
- 第二个 Linear 再压缩回原维度,保留有用信息
✅ 这类似于“先展开再提炼”的过程,增强了表示的丰富性和判别性。
4.3 作用 3:与注意力形成“分工协作”
Transformer Encoder Block 的标准流程:
Input
│
▼
[ Multi-Head Attention ] → 聚合上下文信息(跨 token)
│
▼
[ Add & Norm ]
│
▼
[ FFN ] → 独立增强每个 token 表示(per token)
│
▼
[ Add & Norm ]
│
▼
Output
- 注意力:解决“关系建模”问题
- FFN:解决“特征表达”问题
二者缺一不可!
五、FFN 的变体与演进
5.1 激活函数的演进
| 激活函数 | 使用场景 | 特点 |
|---|---|---|
| ReLU | 原始 Transformer | 简单高效,但有“死区”问题 |
| GELU | BERT、RoBERTa | 平滑、性能略优, GELU ( x ) = x Φ ( x ) \text{GELU}(x) = x \Phi(x) GELU(x)=xΦ(x)( Φ \Phi Φ 为标准正态累积分布) |
| SwiGLU | PaLM、LLaMA、Mistral | 更强性能,形式: SwiGLU ( x ) = Swish ( x W ) ⊗ ( x V ) \text{SwiGLU}(x) = \text{Swish}(xW) \otimes (xV) SwiGLU(x)=Swish(xW)⊗(xV) |
📌 LLaMA 等现代大模型已改用 SwiGLU 替代传统 FFN,因其具有门控机制,表达能力更强。
5.2 结构变体
(1) GLU Variants(Gated Linear Units)
# SwiGLU 示例
def swiglu(x):
x1, x2 = x.chunk(2, dim=-1)
return F.silu(x1) * x2
- 将输入分成两半,一半做激活,一半做门控
- 表达能力更强,已成为大模型标配
(2) Depth-wise Separable FFN(减少参数)
- 类似 CNN 中的 depthwise conv,降低计算量
- 用于移动端/边缘设备模型(如 MobileBERT)
六、工程实现:PyTorch 代码详解
6.1 标准 FFN 实现
import torch
import torch.nn as nn
class PositionWiseFFN(nn.Module):
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1, activation='relu'):
super().__init__()
self.d_model = d_model
self.d_ff = d_ff if d_ff is not None else 4 * d_model # 默认 4 倍
self.fc1 = nn.Linear(d_model, self.d_ff)
self.fc2 = nn.Linear(self.d_ff, d_model)
self.dropout = nn.Dropout(dropout)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'gelu':
self.activation = nn.GELU()
else:
raise ValueError(f"Unsupported activation: {activation}")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (B, T, d_model)
output: (B, T, d_model)
"""
# 注意:残差连接通常在外部由 TransformerBlock 添加
x = self.fc1(x) # (B, T, d_ff)
x = self.activation(x) # (B, T, d_ff)
x = self.dropout(x)
x = self.fc2(x) # (B, T, d_model)
return x
⚠️ 重要澄清:
残差连接(residual connection)和 LayerNorm 不属于 FFN 本身,而是 Transformer Block 的组成部分。FFN 只负责核心的非线性变换。
6.2 参数量计算
以 d_model=512, d_ff=2048 为例:
fc1: 512 × 2048 = 1 , 048 , 576 512 \times 2048 = 1,048,576 512×2048=1,048,576fc2: 2048 × 512 = 1 , 048 , 576 2048 \times 512 = 1,048,576 2048×512=1,048,576- 总计 ≈ 2.1M 参数
💡 FFN 的参数量通常是多头注意力的 2~4 倍,是 Transformer 的主要参数来源!
七、常见误区澄清
❌ 误区 1:“FFN 就是一个全连接层”
→ 错!它是两个全连接层 + 激活函数的组合。
❌ 误区 2:“FFN 也建模 token 间关系”
→ 错!它是 position-wise 的,不跨 token 交互。
❌ 误区 3:“FFN 可有可无,只是增加深度”
→ 错!去掉 FFN 会导致性能大幅下降(实验证明)。
❌ 误区 4:“FFN 的输入是原始词向量”
→ 错!它的输入是多头注意力的输出(已包含上下文信息)。
❌ 误区 5:“FFN 对每个位置使用不同的参数”
→ 错!所有位置共享同一套 FFN 参数(即 fc1 和 fc2 是全局的)。
八、总结:FFN 的核心要点
| 维度 | 内容 |
|---|---|
| 名称 | Position-wise Feed-Forward Network (FFN) |
| 结构 | 两层 Linear + 非线性激活(如 ReLU/GELU/SwiGLU) |
| 输入/输出 | ( B , T , d model ) → ( B , T , d model ) (B, T, d_{\text{model}}) \rightarrow (B, T, d_{\text{model}}) (B,T,dmodel)→(B,T,dmodel) |
| 核心作用 | 1. 引入非线性2. 增强每个 token 的表示3. 与注意力分工协作 |
| 关键特性 | Position-wise(独立处理每个位置,但参数全局共享) |
| 参数量 | 主要集中在 FFN(占 Transformer 大部分参数) |
| 现代变体 | SwiGLU、GLU 等门控机制 |
九、延伸思考
-
为什么 FFN 要放在注意力之后?
因为注意力先聚合了上下文信息,FFN 再对这个“富含上下文的表示”进行非线性加工。 -
能否交换顺序?
理论上可以,但效果差。因为先做 FFN 无法利用上下文信息。 -
FFN 能否共享参数 across layers?
一般不共享(每层学不同变换),但 ALBERT 等模型尝试过参数共享以压缩模型。
你现在不仅知道 FFN 是什么,更理解了它为什么存在、如何工作、有何演变。这为你下一步学习完整的 Transformer Block(注意力 + FFN + LayerNorm + 残差连接)打下了坚实基础。
继续前进吧!🚀
2、为什么 dropout 不算 “组成部分”
这正是区分"核心架构"和"工程优化"的关键!
让我给你彻底讲清楚:
🎯 核心答案
Dropout 是训练技巧,不是架构设计!
| 组件 | 是否必须 | 角色 | 在推理时 |
|---|---|---|---|
| Linear | ✅ 必须 | 核心变换 | 存在 |
| 激活函数 | ✅ 必须 | 非线性 | 存在 |
| Dropout | ❌ 可选 | 防止过拟合 | 关闭 |
🔬 为什么 Dropout 不算"组成部分"?
- 推理时的行为完全不同
# 训练时
x = self.fc1(x) # 线性变换 ✅
x = self.activation(x) # 非线性 ✅
x = self.dropout(x) # 随机置零 ⚠️ 训练专用
# 推理时
x = self.fc1(x) # 线性变换 ✅
x = self.activation(x) # 非线性 ✅
x = x # Dropout 完全消失!
Dropout 在推理时根本不存在,怎么能叫"组成部分"呢?
- 历史/论文视角
看《Attention Is All You Need》原文:
“Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.”
论文定义 FFN = Linear + ReLU + Linear,Dropout 是单独在别处说明的!
- 代码层面的分离
# HuggingFace Transformers 的 BERT 实现
class BertIntermediate(nn.Module):
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states) # Linear
hidden_states = self.intermediate_act_fn(hidden_states) # Activation
return hidden_states # ❌ 没有 Dropout!
class BertOutput(nn.Module):
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states) # Linear
hidden_states = self.dropout(hidden_states) # ✅ Dropout 在这里!
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
看到了吗? Dropout 是放在残差连接之前,不属于 FFN 核心!
📊 架构层次表
| 层次 | 内容 | 是否FFN核心 | 推理时 |
|---|---|---|---|
| 核心架构 | Linear + Activation + Linear | ✅ 是 | ✅ 存在 |
| 训练优化 | Dropout | ❌ 否 | ❌ 关闭 |
| 残差连接 | x + sublayer(x) |
❌ 否 | ✅ 存在 |
| 层归一化 | LayerNorm | ❌ 否 | ✅ 存在 |
🎨 类比理解
🍔 汉堡包店
| 组件 | 类比 | 是否核心 |
|---|---|---|
| 面包 + 肉饼 + 生菜 | Linear + Activation + Linear | ✅ 核心配方 |
| 外卖包装盒 | Dropout | ❌ 只是配送需要 |
| 堂食盘子 | 残差连接 | ❌ 摆盘方式 |
| 酱料包 | LayerNorm | ❌ 调味品 |
你不会说"汉堡包=面包+肉饼+生菜+外卖盒"吧?
🏭 汽车工厂
| 组件 | 类比 | 是否核心 |
|---|---|---|
| 发动机 + 变速箱 + 车轮 | Linear + Activation + Linear | ✅ 核心 |
| 安全气囊 | Dropout | ❌ 安全配置 |
| 座椅 | 残差连接 | ❌ 舒适配置 |
| 车载空调 | LayerNorm | ❌ 辅助功能 |
FFN的定义就像"汽车=发动机+变速箱+车轮",安全气囊是额外安全配置!
💡 为什么你的直觉是对的
你觉得Dropout应该算进去,是因为在代码里它确实和FFN写在一起:
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x) # ← 写在一起
x = self.fc2(x)
return x
这是工程实践,不是架构定义!
📝 精确的说法
✅ 严格定义(学术论文)
FFN = Linear → Activation → Linear
✅ 工程实现(PyTorch代码)
FFN = Linear → Activation → Dropout → Linear
✅ 最佳教学实践
FFN的核心是两层的全连接网络,实践中通常会在两个Linear之间加入Dropout防止过拟合。
🎯 终极答案
不把Dropout算作FFN的"组成部分",是因为:
- 它不是必须的 - 你可以去掉Dropout,仍然是FFN
- 推理时不存在 - Dropout只在训练时激活
- 论文定义 - Transformer原始论文就没把它算进FFN
- 代码分离 - 很多实现把Dropout放在残差分支,不在FFN内部
在写代码时,Dropout确实和FFN紧密结合。这是工程实践和学术定义的区别!
你现在既理解架构本质,也明白工程实现,这才是真正的通透!🎯
3、raise 抛出异常
🌟 一句话理解 raise
raise就是你在代码里“手动制造一个错误”,让程序停下来并告诉你:“这里出问题了!”
🔧 基本语法
raise 异常类型("错误信息")
常见用法示例:
# 1. 抛出 ValueError(值无效)
raise ValueError("年龄不能是负数")
# 2. 抛出 TypeError(类型错误)
raise TypeError("期望输入字符串,但得到了整数")
# 3. 抛出 RuntimeError(运行时错误)
raise RuntimeError("网络连接失败")
🆚 对比你熟悉的场景
| 场景 | 自动报错(Python 内置) | 手动报错(你用 raise) |
|---|---|---|
| 除零错误 | 1 / 0 → ZeroDivisionError |
— |
| 索引越界 | [][0] → IndexError |
— |
| 传错参数 | — | raise ValueError("激活函数不支持") |
| 文件不存在 | — | raise FileNotFoundError("配置文件缺失") |
💡
raise让你能在“Python 不会自动报错,但逻辑上不合理”的地方主动报错。
🎯 为什么代码需要 raise?
回顾代码:
if activation == 'relu':
self.activation = nn.ReLU()
else:
assert False, '还需要添加更多激活函数'
问题:如果用户写 activation='gelu',你希望发生什么?
- ❌ 什么都不做? → 后面会
AttributeError(更难 debug) - ❌ 只打印提示? → 程序继续跑,结果错得离谱
- ✅ 立刻停止并明确说“你传错了” → 这就是
raise的作用!
用 raise 改写:
if activation == 'relu':
self.activation = nn.ReLU()
else:
raise ValueError(f"不支持的激活函数: '{activation}'。目前仅支持 'relu'。")
🧪 实际运行效果
def test_activation(activation):
if activation == 'relu':
print("使用 ReLU")
else:
raise ValueError(f"不支持: {activation}")
# 正常情况
test_activation('relu') # 输出: 使用 ReLU
# 异常情况
test_activation('gelu')
输出:
Traceback (most recent call last):
File "test.py", line 9, in <module>
test_activation('gelu')
File "test.py", line 6, in test_activation
raise ValueError(f"不支持: {activation}")
ValueError: 不支持: gelu
✅ 清晰、立即、无法忽略!
❓ 那 assert 和 raise 有什么区别?
| 特性 | assert |
raise |
|---|---|---|
| 用途 | 调试时检查“不可能发生”的情况 | 处理“可能发生的错误输入” |
| 生产环境 | python -O 会完全删除 |
永远生效 |
| 异常类型 | 只能抛 AssertionError |
可抛任何异常(ValueError, TypeError…) |
| 语义 | “我断言这里一定为真” | “这里出错了,请处理” |
📌 规则:
- 检查用户输入/外部数据 → 用
raise- 检查内部逻辑是否自洽 → 用
assert
python -O中的-O是 “优化模式”(Optimization mode) 的缩写。当你用这个选项运行 Python 脚本时,解释器会跳过(完全忽略)所有的assert语句。
🔍 具体行为
正常运行(无
-O):# test.py def check_age(age): assert age >= 0, "年龄不能为负数" print("年龄有效:", age) check_age(-5)运行:
python test.py输出:
Traceback (most recent call last): File "test.py", line 5, in <module> check_age(-5) File "test.py", line 2, in check_age assert age >= 0, "年龄不能为负数" AssertionError: 年龄不能为负数✅
assert生效,程序报错停止
用
-O运行(优化模式):python -O test.py输出:
年龄有效: -5❌
assert被完全删除!程序继续执行,即使逻辑错误!💥 这就是为什么 不能用
assert做参数校验 —— 在生产环境(常启用-O提升性能)中,安全检查会静默失效!
🧠 为什么 Python 要这样设计?
assert的本意是 “调试时的内部逻辑检查”,比如:def divide(a, b): result = a / b assert result * b == a # 验证计算是否自洽(仅开发时需要) return result在生产环境,为了提升一点点性能,可以安全地移除这些“自我验证”代码。
但它绝不应该用于处理用户输入、外部数据或关键逻辑分支!
✅ 正确做法对比
场景 错误做法(用 assert)正确做法(用 raise)检查函数参数 assert x > 0if x <= 0: raise ValueError(...)验证配置文件 assert config is not Noneif config is None: raise RuntimeError(...)处理用户输入 assert mode in ['train', 'eval']if mode not in [...]: raise ValueError(...)✅
raise永远不会被-O删除!它在任何模式下都生效。
🛠 如何验证?
创建
demo.py:print("=== 正常模式 ===") try: assert False, "这条 assert 应该触发" except AssertionError as e: print("捕获到 assert 错误:", e) print("\n=== 现在请用 python -O demo.py 运行看效果 ===")
正常运行:
python demo.py输出:
捕获到 assert 错误: 这条 assert 应该触发优化模式运行:
python -O demo.py输出:
=== 正常模式 === === 现在请用 python -O demo.py 运行看效果 ===→
assert完全消失,没报错!
💡 总结
关键点 说明 python -O启用优化模式,移除所有 assert语句后果 依赖 assert的逻辑校验在生产环境失效正确替代 用 if ...: raise ValueError(...)做输入/状态校验assert用途仅用于开发阶段的内部逻辑断言(如算法不变式) 🚨 记住:
永远不要用assert来处理可能来自外部的数据或参数 —— 因为在-O模式下,“安全网”会凭空消失!你现在完全理解为什么我建议你用
raise而不是assert了吧?这可是专业 Python 开发的必修课!👏
✅ 在代码中正确使用 raise
class PositionWiseFeedForwardNetwork(nn.Module):
def __init__(self, ..., activation: str = 'relu'):
...
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'gelu':
self.activation = nn.GELU()
else:
# 明确告诉用户:你传的值不对!
raise ValueError(
f"激活函数 '{activation}' 未实现。"
"支持的选项: 'relu', 'gelu'"
)
💡 最后总结
| 关键点 | 说明 |
|---|---|
raise 是什么 |
手动触发错误的关键字 |
| 为什么用它 | 让错误尽早暴露,避免隐藏 bug |
| 怎么用 | raise ValueError("错误描述") |
和 assert 区别 |
raise 用于输入校验,assert 用于调试断言 |
🚀 学会
raise,你就掌握了专业开发者控制程序错误流的第一把钥匙!
现在你可以自信地把 assert False 换成 raise ValueError 了!
4、代码
在《基本介绍》中也有
import torch
import torch.nn as nn
class PositionWiseFeedForwardNetwork(nn.Module):
def __init__(
self,
d_model: int,
d_ff: int = None,
dropout: float = 0.1,
activation = 'relu'
):
super().__init__()
self.d_model = d_model
self.d_ff = d_ff if d_ff is not None else 4 * d_model # 默认 4 倍
self.linear1 = nn.Linear(in_features=d_model, out_features=self.d_ff)
if activation == 'relu':
self.activation = nn.ReLU()
else:
raise ValueError('还需要添加更多激活函数')
raise ValueError(f"Unsupported activation: '{activation}'. Currently only 'relu' is supported.")
self.dropout = nn.Dropout(p=dropout)
self.linear2 = nn.Linear(in_features=self.d_ff, out_features=self.d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 返回值预计为 张量(但Python解释器可不管什么类型,即使不返回张量也不会报错)
"""
:param x: x.shape = (B, T, d_model)
:return: (B, T, d_model)
"""
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
# 可一行搞定
# return self.linear2(self.dropout(self.activation(self.linear1(x))))
return x
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)