1、基本介绍

一、名称解析:“前馈全连接层”到底是什么?

1.1 拆解名字:三个关键词

含义
前馈(Feed-Forward) 指信号单向流动,从输入到输出,无反馈回路(区别于 RNN、LSTM 等循环网络)。这是神经网络中最基本的信息传递方式。
全连接(Fully Connected) 指一层中每个神经元都与上一层所有神经元相连,即使用 nn.Linear 实现的线性变换。
层(Layer) 表示它是网络中的一个计算单元。

✅ 所以,“前馈全连接层” = 一个采用全连接结构、信息单向流动的神经网络层

但注意:在 Transformer 语境下,“前馈全连接层”通常指的不是一个单层,而是一个小型的两层 MLP(多层感知机)


1.2 它和“全连接层”是什么关系?

  • “全连接层”(Fully Connected Layer) 是一个基础构件,等价于 nn.Linear,执行 y = W x + b y = Wx + b y=Wx+b
  • “前馈全连接层”(FFN) 在 Transformer 中是一个复合模块,由两个全连接层 + 一个非线性激活函数组成。

关系总结

FFN 是由多个全连接层构成的前馈子网络
“全连接层”是砖块,“前馈全连接层(FFN)”是一面墙。


1.3 命名与缩写:到底叫什么?简写是啥?

名称 是否常用 说明
Feed-Forward Network (FFN) 最标准、最广泛使用 论文、代码库(如 HuggingFace Transformers)、教材通用
Position-wise Feed-Forward Network ✅ 学术论文常用 强调“对每个位置独立应用”(见下文)
FF layer / FF block ✅ 工程口语 “FF” 即 Feed-Forward
MLP (Multi-Layer Perceptron) ⚠️ 可用但不够精确 MLP 是更广义的概念,FFN 是一种特定结构的 MLP
全连接层 不准确(易混淆) 会让人误以为是单个 Linear

📌 结论

  • 正式名称Position-wise Feed-Forward Network (位置前馈网络)
  • 通用简称FFN
  • 不要简称为“全连接层”,以免与单个 Linear 混淆!

“前馈全连接层”在 Transformer 架构中的标准、正式名称就是 “Position-wise Feed-Forward Network”(位置前馈网络)。


为什么强调 “Position-wise”(位置-wise)?

这是为了精确描述其计算方式,与多头注意力形成对比:

模块 是否跨位置交互 计算方式
Multi-Head Attention ✅ 是 对序列中所有 token 联合计算(token 之间相互 attend)
Feed-Forward Network ❌ 否 对每个位置(position)独立、相同地应用同一个全连接网络

因此:

  • 它不是“全局”的一个大网络处理整个序列;
  • 而是在每个 token 的位置上,单独跑一遍相同的 FFN
  • 所有位置共享同一套参数(即 W1, W2 是全局的),但计算彼此独立

📌 这种“参数共享 + 位置独立计算”的模式,就叫做 position-wise


权威出处

  1. 原始论文《Attention Is All You Need》(2017)

    “Each of the layers in our encoder and decoder contains a fully connected feed-forward network, applied to each position separately and identically.”
    (我们编码器和解码器中的每一层都包含一个全连接前馈网络,对每个位置分别且相同地应用。)

  2. 后续所有 Transformer 变体(BERT、GPT、T5 等) 都沿用这一术语。

  3. HuggingFace Transformers 文档 中也明确使用 position-wise feed-forward 描述该组件。


常见简称

全称 简称 使用场景
Position-wise Feed-Forward Network FFN 最通用(论文、代码、讨论)
Position-wise Feed-Forward Network Position-wise FFN 强调“位置独立”特性时
Feed-Forward Network FFFF layer 口语化、工程中

⚠️ 注意:虽然常简称为 “FFN”,但不能省略 “position-wise” 的含义——这是理解其作用的关键。


总结

  • “前馈全连接层” = “Position-wise Feed-Forward Network”
  • ✅ “Position-wise” 不是可有可无的修饰词,而是定义性特征
  • ✅ 所有主流文献和框架都采用这一命名

所以:

“Transformer 中的前馈层,正式名称是 Position-wise Feed-Forward Network(FFN)。”


二、FFN 的数学定义与标准结构

2.1 标准公式(来自《Attention Is All You Need》)

Transformer 论文中定义的 FFN 如下:

FFN ( x ) = max ⁡ ( 0 , x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2

其中:

  • x ∈ R d model x \in \mathbb{R}^{d_{\text{model}}} xRdmodel:输入向量(如多头注意力的输出)
  • W 1 ∈ R d model × d f f W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}} W1Rdmodel×dff, b 1 ∈ R d f f b_1 \in \mathbb{R}^{d_{ff}} b1Rdff
  • W 2 ∈ R d f f × d model W_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}} W2Rdff×dmodel, b 2 ∈ R d model b_2 \in \mathbb{R}^{d_{\text{model}}} b2Rdmodel
  • d f f d_{ff} dff中间隐藏层维度,通常 d f f = 2048 d_{ff} = 2048 dff=2048(当 d model = 512 d_{\text{model}} = 512 dmodel=512 时),即 4 倍膨胀

🔍 注意:原始论文使用 ReLU(即 max ⁡ ( 0 , ⋅ ) \max(0, \cdot) max(0,)),但后续模型(如 BERT)普遍改用 GELU,现代大模型(如 LLaMA)则采用 SwiGLU 等门控机制。


2.2 结构图解

输入 x (d_model)
     │
     ▼
[ Linear(d_model → d_ff) ]   ← 第一个全连接层(升维)
     │
     ▼
[ ReLU / GELU ]              ← 非线性激活
     │
     ▼
[ Linear(d_ff → d_model) ]   ← 第二个全连接层(降维回原维度)
     │
     ▼
输出 y (d_model)

关键特性

  • 输入输出维度相同 d model d_{\text{model}} dmodel),便于与残差连接配合
  • 中间层更高维 d f f > d model d_{ff} > d_{\text{model}} dff>dmodel),提供更强的表达能力
  • 参数共享 across positions:同一 FFN 应用于序列中所有位置(但不同位置的计算彼此独立)

💡 补充说明:虽然 FFN 对每个位置独立计算,但所有位置共享同一套权重参数(即 fc1fc2 是全局共享的),这是“position-wise”的准确含义。


三、“Position-wise” 是什么意思?为什么强调这一点?

这是理解 FFN 在 Transformer 中作用的关键!

3.1 对比:多头注意力 vs FFN

模块 是否跨位置交互 作用
多头注意力 ✅ 是 建模 token 之间的依赖关系(如 “The cat” → “cat” 关注 “The”)
FFN ❌ 否 对每个 token 的表示进行独立的非线性变换

3.2 “Position-wise” 的含义

  • 假设输入序列长度为 T T T,batch size 为 B B B

  • 多头注意力输出:$ X \in \mathbb{R}^{B \times T \times d_{\text{model}}} $

  • FFN 对每个位置 t ∈ [ 1 , T ] t \in [1, T] t[1,T] 独立应用相同的函数
    Output [ : , t , : ] = FFN ( X [ : , t , : ] ) \text{Output}[:, t, :] = \text{FFN}(X[:, t, :]) Output[:,t,:]=FFN(X[:,t,:])

🧠 想象:有 T T T 个相同的“小专家”(FFN),每人负责处理一个 token 的向量,彼此不交流,但使用同一本操作手册(共享参数)

所以 FFN 不建模序列结构,只增强每个 token 的表示能力


四、FFN 到底有什么用?—— 三大核心作用

4.1 作用 1:引入非线性(Non-linearity)

  • 注意力机制本质上是加权求和(尽管 softmax 是非线性的,但整体仍是输入的凸组合)
  • 如果没有 FFN,整个 Transformer 就退化为一系列线性变换的组合,无法拟合复杂函数
  • FFN 的激活函数(ReLU/GELU/SwiGLU)提供了关键的非线性,使模型具备强大的表达能力

💡 类比:注意力决定“看哪里”,FFN 决定“怎么看懂看到的东西”。


4.2 作用 2:特征空间变换与增强

  • 第一个 Linear 将 d model d_{\text{model}} dmodel 维向量映射到更高维空间(如 512 → 2048)
  • 在高维空间中,特征更容易被线性分离(类似“核方法”的思想)
  • 第二个 Linear 再压缩回原维度,保留有用信息

✅ 这类似于“先展开再提炼”的过程,增强了表示的丰富性和判别性。


4.3 作用 3:与注意力形成“分工协作”

Transformer Encoder Block 的标准流程:

Input
  │
  ▼
[ Multi-Head Attention ]  → 聚合上下文信息(跨 token)
  │
  ▼
[ Add & Norm ]
  │
  ▼
[ FFN ]                   → 独立增强每个 token 表示(per token)
  │
  ▼
[ Add & Norm ]
  │
  ▼
Output
  • 注意力:解决“关系建模”问题
  • FFN:解决“特征表达”问题

二者缺一不可!


五、FFN 的变体与演进

5.1 激活函数的演进

激活函数 使用场景 特点
ReLU 原始 Transformer 简单高效,但有“死区”问题
GELU BERT、RoBERTa 平滑、性能略优, GELU ( x ) = x Φ ( x ) \text{GELU}(x) = x \Phi(x) GELU(x)=xΦ(x) Φ \Phi Φ 为标准正态累积分布)
SwiGLU PaLM、LLaMA、Mistral 更强性能,形式: SwiGLU ( x ) = Swish ( x W ) ⊗ ( x V ) \text{SwiGLU}(x) = \text{Swish}(xW) \otimes (xV) SwiGLU(x)=Swish(xW)(xV)

📌 LLaMA 等现代大模型已改用 SwiGLU 替代传统 FFN,因其具有门控机制,表达能力更强。


5.2 结构变体

(1) GLU Variants(Gated Linear Units)

# SwiGLU 示例
def swiglu(x):
    x1, x2 = x.chunk(2, dim=-1)
    return F.silu(x1) * x2
  • 将输入分成两半,一半做激活,一半做门控
  • 表达能力更强,已成为大模型标配

(2) Depth-wise Separable FFN(减少参数)

  • 类似 CNN 中的 depthwise conv,降低计算量
  • 用于移动端/边缘设备模型(如 MobileBERT)

六、工程实现:PyTorch 代码详解

6.1 标准 FFN 实现

import torch
import torch.nn as nn

class PositionWiseFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1, activation='relu'):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff if d_ff is not None else 4 * d_model  # 默认 4 倍
        
        self.fc1 = nn.Linear(d_model, self.d_ff)
        self.fc2 = nn.Linear(self.d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, d_model)
        output: (B, T, d_model)
        """
        # 注意:残差连接通常在外部由 TransformerBlock 添加
        x = self.fc1(x)          # (B, T, d_ff)
        x = self.activation(x)   # (B, T, d_ff)
        x = self.dropout(x)
        x = self.fc2(x)          # (B, T, d_model)
        return x

⚠️ 重要澄清
残差连接(residual connection)和 LayerNorm 不属于 FFN 本身,而是 Transformer Block 的组成部分。FFN 只负责核心的非线性变换。


6.2 参数量计算

d_model=512, d_ff=2048 为例:

  • fc1: 512 × 2048 = 1 , 048 , 576 512 \times 2048 = 1,048,576 512×2048=1,048,576
  • fc2: 2048 × 512 = 1 , 048 , 576 2048 \times 512 = 1,048,576 2048×512=1,048,576
  • 总计 ≈ 2.1M 参数

💡 FFN 的参数量通常是多头注意力的 2~4 倍,是 Transformer 的主要参数来源!


七、常见误区澄清

❌ 误区 1:“FFN 就是一个全连接层”

→ 错!它是两个全连接层 + 激活函数的组合。

❌ 误区 2:“FFN 也建模 token 间关系”

→ 错!它是 position-wise 的,不跨 token 交互

❌ 误区 3:“FFN 可有可无,只是增加深度”

→ 错!去掉 FFN 会导致性能大幅下降(实验证明)。

❌ 误区 4:“FFN 的输入是原始词向量”

→ 错!它的输入是多头注意力的输出(已包含上下文信息)。

❌ 误区 5:“FFN 对每个位置使用不同的参数”

→ 错!所有位置共享同一套 FFN 参数(即 fc1fc2 是全局的)。


八、总结:FFN 的核心要点

维度 内容
名称 Position-wise Feed-Forward Network (FFN)
结构 两层 Linear + 非线性激活(如 ReLU/GELU/SwiGLU)
输入/输出 ( B , T , d model ) → ( B , T , d model ) (B, T, d_{\text{model}}) \rightarrow (B, T, d_{\text{model}}) (B,T,dmodel)(B,T,dmodel)
核心作用 1. 引入非线性2. 增强每个 token 的表示3. 与注意力分工协作
关键特性 Position-wise(独立处理每个位置,但参数全局共享)
参数量 主要集中在 FFN(占 Transformer 大部分参数)
现代变体 SwiGLU、GLU 等门控机制

九、延伸思考

  • 为什么 FFN 要放在注意力之后?
    因为注意力先聚合了上下文信息,FFN 再对这个“富含上下文的表示”进行非线性加工。

  • 能否交换顺序?
    理论上可以,但效果差。因为先做 FFN 无法利用上下文信息。

  • FFN 能否共享参数 across layers?
    一般不共享(每层学不同变换),但 ALBERT 等模型尝试过参数共享以压缩模型。


你现在不仅知道 FFN 是什么,更理解了它为什么存在、如何工作、有何演变。这为你下一步学习完整的 Transformer Block(注意力 + FFN + LayerNorm + 残差连接)打下了坚实基础。

继续前进吧!🚀


2、为什么 dropout 不算 “组成部分”

这正是区分"核心架构"和"工程优化"的关键!

让我给你彻底讲清楚:


🎯 核心答案

Dropout 是训练技巧,不是架构设计!

组件 是否必须 角色 在推理时
Linear 必须 核心变换 存在
激活函数 必须 非线性 存在
Dropout ❌ 可选 防止过拟合 关闭

🔬 为什么 Dropout 不算"组成部分"?

  1. 推理时的行为完全不同
# 训练时
x = self.fc1(x)        # 线性变换 ✅
x = self.activation(x) # 非线性 ✅
x = self.dropout(x)    # 随机置零 ⚠️ 训练专用

# 推理时
x = self.fc1(x)        # 线性变换 ✅
x = self.activation(x) # 非线性 ✅
x = x                  # Dropout 完全消失!

Dropout 在推理时根本不存在,怎么能叫"组成部分"呢?


  1. 历史/论文视角

看《Attention Is All You Need》原文:

“Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.”

论文定义 FFN = Linear + ReLU + Linear,Dropout 是单独在别处说明的!


  1. 代码层面的分离
# HuggingFace Transformers 的 BERT 实现
class BertIntermediate(nn.Module):
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # Linear
        hidden_states = self.intermediate_act_fn(hidden_states)  # Activation
        return hidden_states  # ❌ 没有 Dropout!

class BertOutput(nn.Module):
    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)  # Linear
        hidden_states = self.dropout(hidden_states)  # ✅ Dropout 在这里!
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

看到了吗? Dropout 是放在残差连接之前,不属于 FFN 核心!


📊 架构层次表

层次 内容 是否FFN核心 推理时
核心架构 Linear + Activation + Linear ✅ 存在
训练优化 Dropout ❌ 否 ❌ 关闭
残差连接 x + sublayer(x) ❌ 否 ✅ 存在
层归一化 LayerNorm ❌ 否 ✅ 存在

🎨 类比理解

🍔 汉堡包店

组件 类比 是否核心
面包 + 肉饼 + 生菜 Linear + Activation + Linear ✅ 核心配方
外卖包装盒 Dropout ❌ 只是配送需要
堂食盘子 残差连接 ❌ 摆盘方式
酱料包 LayerNorm ❌ 调味品

你不会说"汉堡包=面包+肉饼+生菜+外卖盒"吧?


🏭 汽车工厂

组件 类比 是否核心
发动机 + 变速箱 + 车轮 Linear + Activation + Linear ✅ 核心
安全气囊 Dropout ❌ 安全配置
座椅 残差连接 ❌ 舒适配置
车载空调 LayerNorm ❌ 辅助功能

FFN的定义就像"汽车=发动机+变速箱+车轮",安全气囊是额外安全配置!


💡 为什么你的直觉是对的

你觉得Dropout应该算进去,是因为在代码里它确实和FFN写在一起

def forward(self, x):
    x = self.fc1(x)
    x = self.act(x)
    x = self.dropout(x)  # ← 写在一起
    x = self.fc2(x)
    return x

这是工程实践,不是架构定义!


📝 精确的说法

严格定义(学术论文)

FFN = Linear → Activation → Linear

工程实现(PyTorch代码)

FFN = Linear → Activation → Dropout → Linear

最佳教学实践

FFN的核心是两层的全连接网络,实践中通常会在两个Linear之间加入Dropout防止过拟合。


🎯 终极答案

不把Dropout算作FFN的"组成部分",是因为:

  1. 它不是必须的 - 你可以去掉Dropout,仍然是FFN
  2. 推理时不存在 - Dropout只在训练时激活
  3. 论文定义 - Transformer原始论文就没把它算进FFN
  4. 代码分离 - 很多实现把Dropout放在残差分支,不在FFN内部

在写代码时,Dropout确实和FFN紧密结合。这是工程实践和学术定义的区别!

你现在既理解架构本质,也明白工程实现,这才是真正的通透!🎯


3、raise 抛出异常

🌟 一句话理解 raise

raise 就是你在代码里“手动制造一个错误”,让程序停下来并告诉你:“这里出问题了!”


🔧 基本语法

raise 异常类型("错误信息")

常见用法示例:

# 1. 抛出 ValueError(值无效)
raise ValueError("年龄不能是负数")

# 2. 抛出 TypeError(类型错误)
raise TypeError("期望输入字符串,但得到了整数")

# 3. 抛出 RuntimeError(运行时错误)
raise RuntimeError("网络连接失败")

🆚 对比你熟悉的场景

场景 自动报错(Python 内置) 手动报错(你用 raise
除零错误 1 / 0ZeroDivisionError
索引越界 [][0]IndexError
传错参数 raise ValueError("激活函数不支持")
文件不存在 raise FileNotFoundError("配置文件缺失")

💡 raise 让你能在“Python 不会自动报错,但逻辑上不合理”的地方主动报错


🎯 为什么代码需要 raise

回顾代码:

if activation == 'relu':
    self.activation = nn.ReLU()
else:
    assert False, '还需要添加更多激活函数'

问题:如果用户写 activation='gelu',你希望发生什么?

  • 什么都不做? → 后面会 AttributeError(更难 debug)
  • 只打印提示? → 程序继续跑,结果错得离谱
  • 立刻停止并明确说“你传错了” → 这就是 raise 的作用!

raise 改写:

if activation == 'relu':
    self.activation = nn.ReLU()
else:
    raise ValueError(f"不支持的激活函数: '{activation}'。目前仅支持 'relu'。")

🧪 实际运行效果

def test_activation(activation):
    if activation == 'relu':
        print("使用 ReLU")
    else:
        raise ValueError(f"不支持: {activation}")

# 正常情况
test_activation('relu')  # 输出: 使用 ReLU

# 异常情况
test_activation('gelu')

输出

Traceback (most recent call last):
  File "test.py", line 9, in <module>
    test_activation('gelu')
  File "test.py", line 6, in test_activation
    raise ValueError(f"不支持: {activation}")
ValueError: 不支持: gelu

清晰、立即、无法忽略


❓ 那 assertraise 有什么区别?

特性 assert raise
用途 调试时检查“不可能发生”的情况 处理“可能发生的错误输入”
生产环境 python -O完全删除 永远生效
异常类型 只能抛 AssertionError 可抛任何异常(ValueError, TypeError…)
语义 “我断言这里一定为真” “这里出错了,请处理”

📌 规则

  • 检查用户输入/外部数据 → 用 raise
  • 检查内部逻辑是否自洽 → 用 assert

python -O 中的 -O“优化模式”(Optimization mode) 的缩写。当你用这个选项运行 Python 脚本时,解释器会跳过(完全忽略)所有的 assert 语句


🔍 具体行为

正常运行(无 -O):

# test.py
def check_age(age):
    assert age >= 0, "年龄不能为负数"
    print("年龄有效:", age)

check_age(-5)

运行:

python test.py

输出:

Traceback (most recent call last):
  File "test.py", line 5, in <module>
    check_age(-5)
  File "test.py", line 2, in check_age
    assert age >= 0, "年龄不能为负数"
AssertionError: 年龄不能为负数

assert 生效,程序报错停止


-O 运行(优化模式):

python -O test.py

输出:

年龄有效: -5

assert 被完全删除!程序继续执行,即使逻辑错误!

💥 这就是为什么 不能用 assert 做参数校验 —— 在生产环境(常启用 -O 提升性能)中,安全检查会静默失效


🧠 为什么 Python 要这样设计?

  • assert 的本意是 “调试时的内部逻辑检查”,比如:

    def divide(a, b):
        result = a / b
        assert result * b == a  # 验证计算是否自洽(仅开发时需要)
        return result
    
  • 生产环境,为了提升一点点性能,可以安全地移除这些“自我验证”代码。

  • 但它绝不应该用于处理用户输入、外部数据或关键逻辑分支


✅ 正确做法对比

场景 错误做法(用 assert 正确做法(用 raise
检查函数参数 assert x > 0 if x <= 0: raise ValueError(...)
验证配置文件 assert config is not None if config is None: raise RuntimeError(...)
处理用户输入 assert mode in ['train', 'eval'] if mode not in [...]: raise ValueError(...)

raise 永远不会被 -O 删除!它在任何模式下都生效。


🛠 如何验证?

创建 demo.py

print("=== 正常模式 ===")
try:
    assert False, "这条 assert 应该触发"
except AssertionError as e:
    print("捕获到 assert 错误:", e)

print("\n=== 现在请用 python -O demo.py 运行看效果 ===")
  1. 正常运行:

    python demo.py
    

    输出:捕获到 assert 错误: 这条 assert 应该触发

  2. 优化模式运行:

    python -O demo.py
    

    输出:

    === 正常模式 ===
    === 现在请用 python -O demo.py 运行看效果 ===
    

    assert 完全消失,没报错!


💡 总结

关键点 说明
python -O 启用优化模式,移除所有 assert 语句
后果 依赖 assert 的逻辑校验在生产环境失效
正确替代 if ...: raise ValueError(...) 做输入/状态校验
assert 用途 仅用于开发阶段的内部逻辑断言(如算法不变式)

🚨 记住
永远不要用 assert 来处理可能来自外部的数据或参数 —— 因为在 -O 模式下,“安全网”会凭空消失!

你现在完全理解为什么我建议你用 raise 而不是 assert 了吧?这可是专业 Python 开发的必修课!👏


✅ 在代码中正确使用 raise

class PositionWiseFeedForwardNetwork(nn.Module):
    def __init__(self, ..., activation: str = 'relu'):
        ...
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        else:
            # 明确告诉用户:你传的值不对!
            raise ValueError(
                f"激活函数 '{activation}' 未实现。"
                "支持的选项: 'relu', 'gelu'"
            )

💡 最后总结

关键点 说明
raise 是什么 手动触发错误的关键字
为什么用它 让错误尽早暴露,避免隐藏 bug
怎么用 raise ValueError("错误描述")
assert 区别 raise 用于输入校验assert 用于调试断言

🚀 学会 raise,你就掌握了专业开发者控制程序错误流的第一把钥匙

现在你可以自信地把 assert False 换成 raise ValueError 了!


4、代码

在《基本介绍》中也有

import torch
import torch.nn as nn


class PositionWiseFeedForwardNetwork(nn.Module):
    def __init__(
            self,
            d_model: int,
            d_ff: int = None,
            dropout: float = 0.1,
            activation = 'relu'
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff if d_ff is not None else 4 * d_model   # 默认 4 倍

        self.linear1 = nn.Linear(in_features=d_model, out_features=self.d_ff)
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            raise ValueError('还需要添加更多激活函数')
            raise ValueError(f"Unsupported activation: '{activation}'. Currently only 'relu' is supported.")
        self.dropout = nn.Dropout(p=dropout)

        self.linear2 = nn.Linear(in_features=self.d_ff, out_features=self.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 返回值预计为 张量(但Python解释器可不管什么类型,即使不返回张量也不会报错)
        """
        :param x: x.shape = (B, T, d_model)
        :return: (B, T, d_model)
        """
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)

        x = self.linear2(x)

        # 可一行搞定
        # return self.linear2(self.dropout(self.activation(self.linear1(x))))

        return x


Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐