第 02 篇:Tensor 是一切的基础——维度变换完全指南
目录
第三步:reshape 和 view——改变形状,不改变数据
第四步:squeeze 和 unsqueeze——维度的增删
第五步:transpose 和 permute——调整轴的顺序
重要警告:permute/transpose 之后内存不再连续
第六步:flatten 和 unflatten——维度的合并与分裂
第八步:广播(Broadcasting)——最容易踩坑也最强大的机制
第九步:矩阵乘法家族——@ 、matmul、bmm、einsum
最终大战:Multi-Head Attention 的完整维度变换流程
如果你去读任何一个大型模型的源码——LLaMA、BERT、ViT、Stable Diffusion——你会发现充斥着这些东西:
x = x.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
x = x.reshape(B, T, self.n_heads * self.head_dim)
q = q.permute(0, 2, 1, 3).contiguous()
attn = torch.einsum('bhid,bhjd->bhij', q, k)
x = x.unsqueeze(0).expand(B, -1, -1)
mask = mask[:, None, None, :]
out = out.flatten(2)
这些东西不是装饰,是核心逻辑。每一行都在改变 Tensor 的形状、调整轴的顺序、合并或拆分维度——任何一行理解错了,整个模型就跑不通。
更糟糕的是,这些操作单独看都很简单,但当它们组合在一起、维度在五六层操作之间反复变化的时候,大多数人就开始犯晕了。
这篇文章就是要把这件事彻底搞清楚。
我们的目标不是背 API,而是建立一套看到任何维度变换代码都能立刻反应过来"它在把什么形状的数据变成什么形状"的直觉。整篇文章以一个从简单到复杂的矩阵乘法贯穿,最后落到一个完整的 Multi-Head Attention 的维度变换全过程——这是大型模型里最密集、最难懂的维度操作场景。
第一步:建立正确的维度观
在学任何具体操作之前,先要解决一个认知问题:怎么理解多维 Tensor?
大多数人遇到二维矩阵觉得很自然(行列),但到了三维、四维、五维就开始凭空想象,越想越绕。
有一个更实用的理解方式:把 Tensor 的每一个维度理解为一个"循环层"。
一个形状为 (2, 3, 4) 的 Tensor,可以这样理解:
- 最外层循环 2 次(第 0 维)
- 中间循环 3 次(第 1 维)
- 最内层循环 4 次(第 2 维)
在深度学习里,这些维度通常有具体含义,最常见的命名约定是:
(B, C, H, W) —— 图像:batch size, channels, height, width
(B, T, D) —— 序列:batch size, sequence length, embedding dim
(B, H, T, D) —— 注意力:batch, heads, sequence, head_dim
养成一个习惯:每次操作之后,在注释里写下当前 Tensor 的形状和每个维度的含义。 这不是强迫症,是大型模型代码能不能看懂的关键。看源码时,一旦某行不知道输入是什么形状,整个后续逻辑就全垮了。
import torch
# 一个典型的 NLP 输入
x = torch.randn(8, 512, 768) # (B=8, T=512, D=768)
# B: batch size,这批数据有 8 个样本
# T: sequence length,每个样本最多 512 个 token
# D: embedding dimension,每个 token 被表示为 768 维向量
第二步:形状查询——你的第一个工具
在做任何操作之前,先学会看。
x = torch.randn(8, 512, 768)
# 三种查看形状的方式,效果相同
print(x.shape) # torch.Size([8, 512, 768])
print(x.size()) # torch.Size([8, 512, 768])
print(x.shape[0]) # 8 —— 取具体某个维度的值
# 维度数量
print(x.ndim) # 3
print(x.dim()) # 3 —— 和 ndim 等价
# 元素总数
print(x.numel()) # 8 * 512 * 768 = 3145728
# 数据类型
print(x.dtype) # torch.float32
torch.Size 本质上就是一个整数元组,支持所有 Python 元组操作:
B, T, D = x.shape # 解包,非常常用
print(B, T, D) # 8 512 768
这种解包写法在模型源码里极其普遍,原因是它让后续的维度变换代码更可读:不需要写 x.shape[0],直接用 B、T、D 这样有语义的变量名。
第三步:reshape 和 view——改变形状,不改变数据
这是最基础的维度变换,也是理解后续一切的起点。
核心原则:只要元素总数相同,可以把 Tensor 变成任何形状。
x = torch.arange(24) # 一个有 24 个元素的一维向量
print(x.shape) # torch.Size([24])
# 变成 (2, 12)
a = x.reshape(2, 12)
print(a.shape) # torch.Size([2, 12])
# 变成 (2, 3, 4)
b = x.reshape(2, 3, 4)
print(b.shape) # torch.Size([2, 3, 4])
# 变成 (4, 6)
c = x.reshape(4, 6)
print(c.shape) # torch.Size([4, 6])
# -1 的作用:让 PyTorch 自动推算这个维度的大小
d = x.reshape(2, -1) # 24 / 2 = 12,所以 -1 自动变成 12
print(d.shape) # torch.Size([2, 12])
e = x.reshape(-1, 6) # 24 / 6 = 4
print(e.shape) # torch.Size([4, 6])
# 可以有且只有一个 -1
# x.reshape(-1, -1) ← 报错!不能同时有两个 -1
-1 这个用法在模型代码里无处不在,因为 batch size 在不同场景下是变化的,你不想把它硬编码进形状参数里。
view vs reshape:一个必须搞清楚的区别
view 和 reshape 从效果上几乎一样,但有一个关键差异:
view 要求 Tensor 在内存上必须是连续的(contiguous);reshape 不要求。
x = torch.arange(24).reshape(4, 6)
# 对连续 Tensor,view 和 reshape 效果相同
y = x.view(2, 12) # OK
z = x.reshape(2, 12) # OK
# 但是 transpose/permute 之后,内存布局会变得"不连续"
x_t = x.transpose(0, 1) # shape: (6, 4),但内存不连续
print(x_t.is_contiguous()) # False
# 此时 view 会报错
# x_t.view(24) # RuntimeError: view size is not compatible...
# reshape 会自动处理(内部会做一次内存拷贝)
x_t_flat = x_t.reshape(24) # OK,shape: (24,)
# 或者手动调用 contiguous() 再 view
x_t_cont = x_t.contiguous()
print(x_t_cont.is_contiguous()) # True
x_t_view = x_t_cont.view(24) # OK
这个细节在大模型源码里非常关键。你经常会看到这样的模式:
# LLaMA 源码中的典型写法
x = x.transpose(1, 2).contiguous().view(B, T, -1)
# ↑ 关键!transpose 后内存不连续,
# 必须先 contiguous() 才能 view()
实践建议:在搭模型的过程中,优先用 reshape,因为它更宽容。只有在明确知道自己在做性能优化的时候,才考虑用 view(因为 view 保证不拷贝内存,是零拷贝操作)。
第四步:squeeze 和 unsqueeze——维度的增删
这两个操作是一对:unsqueeze 在某个位置插入一个大小为 1 的新维度,squeeze 把大小为 1 的维度去掉。
x = torch.randn(8, 512) # (B, T)
# unsqueeze:在指定位置插入一个维度 1
x1 = x.unsqueeze(0) # 在第 0 维前面插入 → (1, 8, 512)
x2 = x.unsqueeze(1) # 在第 1 维前面插入 → (8, 1, 512)
x3 = x.unsqueeze(2) # 在第 2 维前面插入 → (8, 512, 1)
x4 = x.unsqueeze(-1) # 在最后一维后面插入 → (8, 512, 1),等价于 unsqueeze(2)
x5 = x.unsqueeze(-2) # 倒数第 2 维前面插入 → (8, 1, 512),等价于 unsqueeze(1)
print(x1.shape) # torch.Size([1, 8, 512])
print(x2.shape) # torch.Size([8, 1, 512])
print(x3.shape) # torch.Size([8, 512, 1])
squeeze 是反操作,去掉大小为 1 的维度:
x = torch.randn(1, 8, 1, 512, 1)
# 不指定维度:去掉所有大小为 1 的维度
y = x.squeeze()
print(y.shape) # torch.Size([8, 512])
# 指定维度:只去掉指定的那个(如果它大小是 1)
z = x.squeeze(0) # 去掉第 0 维(大小为 1)
print(z.shape) # torch.Size([8, 1, 512, 1])
w = x.squeeze(1) # 试图去掉第 1 维(大小为 8,不是 1)
print(w.shape) # torch.Size([1, 8, 1, 512, 1]) ← 无变化!
unsqueeze 在模型里最常见的用途:为广播做准备
这一点非常重要,放在广播那一节里配合解释。先记住这个场景:
# attention mask 的处理:从 (B, T) 变成 (B, 1, 1, T)
# 用来和 (B, H, T, T) 的注意力分数做广播相加
mask = torch.zeros(8, 512) # (B, T)
mask = mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T)
# 或者等价的链式写法
mask = torch.zeros(8, 512)[:, None, None, :] # 用索引切片 None 也可以插入维度
print(mask.shape) # torch.Size([8, 1, 1, 512])
None 作为索引用来插入维度是另一种常见写法,效果和 unsqueeze 一样。你在 Transformer 源码里会经常见到 mask[:, None, None, :] 这样的写法,就是在给 mask 加上 head 维度以便和注意力矩阵广播。
第五步:transpose 和 permute——调整轴的顺序
reshape 只改变维度大小,不改变数据在维度间的排列顺序。如果你需要交换或者重排轴,要用 transpose 和 permute。
transpose:交换两个指定的维度
x = torch.randn(8, 512, 768) # (B, T, D)
# transpose(dim0, dim1):交换两个维度
y = x.transpose(0, 1) # 交换第 0 和第 1 维 → (T, B, D) = (512, 8, 768)
z = x.transpose(1, 2) # 交换第 1 和第 2 维 → (B, D, T) = (8, 768, 512)
print(y.shape) # torch.Size([512, 8, 768])
print(z.shape) # torch.Size([8, 768, 512])
# 对二维矩阵,transpose() 不传参数就是转置
A = torch.randn(3, 4)
print(A.T.shape) # torch.Size([4, 3])
print(A.transpose(0,1).shape) # torch.Size([4, 3]) 等价
permute:一次性重排所有维度的顺序
当你需要同时移动多个维度的时候,permute 比反复 transpose 更清晰。
x = torch.randn(8, 512, 12, 64) # (B, T, H, D_head)
# 注意力中经常需要把 H 维度移到 T 前面,变成 (B, H, T, D_head)
# 用两次 transpose:
y = x.transpose(1, 2) # (B, H, T, D_head) —— 这里其实只需要一次
# 更一般的情况:用 permute,直接指定新的维度顺序
# permute(0, 2, 1, 3) 意思是:
# 新的第 0 维 = 旧的第 0 维(B)
# 新的第 1 维 = 旧的第 2 维(H)
# 新的第 2 维 = 旧的第 1 维(T)
# 新的第 3 维 = 旧的第 3 维(D_head)
z = x.permute(0, 2, 1, 3) # (B, H, T, D_head)
print(z.shape) # torch.Size([8, 12, 512, 64])
# 再比如一个 5 维 Tensor
a = torch.randn(2, 3, 4, 5, 6)
b = a.permute(4, 0, 2, 1, 3) # 新顺序是 (6维, 2维, 4维, 3维, 5维)
print(b.shape) # torch.Size([6, 2, 4, 3, 5])
重要警告:permute/transpose 之后内存不再连续
这一点前面提过,但要单独强调,因为它是初学者非常容易碰到的报错来源:
x = torch.randn(8, 512, 768)
y = x.transpose(1, 2) # (8, 768, 512)
print(y.is_contiguous()) # False ← 注意!
# 直接 view 会报错
# y.view(8, -1) ← RuntimeError
# 解决方案 1:先 contiguous()
y.contiguous().view(8, -1) # OK
# 解决方案 2:用 reshape(自动处理)
y.reshape(8, -1) # OK
为什么 transpose 之后内存不连续?
这涉及到 PyTorch Tensor 的内存模型。一个 Tensor 包含两个东西:实际的数据存储(storage)和描述如何解释数据的元信息(stride、shape)。
stride 是一个整数元组,表示"沿着某个维度前进一步,需要在内存里跳过多少个元素"。
x = torch.randn(3, 4)
print(x.stride()) # (4, 1) —— 沿第 0 维(行)走一步要跳 4 个,沿第 1 维(列)走一步跳 1 个
y = x.transpose(0, 1) # shape: (4, 3)
print(y.stride()) # (1, 4) —— stride 也跟着变了
# 内存里数据没动,只是 stride 的解读方式变了
# 但 view 要求数据在内存中必须按 C 连续顺序排列(stride[-1]=1, stride[-2]=shape[-1], ...)
# y 的 stride 是 (1, 4),不满足这个要求,所以 view 会报错
contiguous() 会把数据重新排列到内存中,使其满足连续要求,代价是一次内存拷贝。
第六步:flatten 和 unflatten——维度的合并与分裂
flatten 把若干连续的维度合并成一个。
x = torch.randn(8, 196, 768) # (B, T, D),比如 ViT 的 patch 序列
# flatten() 默认展平所有维度
y = x.flatten()
print(y.shape) # torch.Size([8 * 196 * 768]) = torch.Size([1204224])
# flatten(start_dim, end_dim):只展平指定范围内的维度
z = x.flatten(1, 2) # 把第 1 和第 2 维合并 → (B, T*D) = (8, 150528)
print(z.shape) # torch.Size([8, 150528])
# 在 ViT 里常见的用法:把 (B, H, W, C) 的 patch 展平为 (B, H*W, C)
patches = torch.randn(8, 14, 14, 768) # (B, grid_h, grid_w, D)
seq = patches.flatten(1, 2) # (B, 196, 768)
print(seq.shape) # torch.Size([8, 196, 768])
unflatten 是反操作,把一个维度拆分成多个维度:
x = torch.randn(8, 196, 768)
# unflatten(dim, sizes):把第 1 维的 196 拆成 (14, 14)
y = x.unflatten(1, (14, 14))
print(y.shape) # torch.Size([8, 14, 14, 768])
# 另一个例子:把 head 维度拆出来
# (B, T, H*D_head) → (B, T, H, D_head)
z = torch.randn(8, 512, 768) # H=12, D_head=64, 12*64=768
w = z.unflatten(2, (12, 64))
print(w.shape) # torch.Size([8, 512, 12, 64])
第七步:expand 和 repeat——维度的复制扩展
这两个操作都是用来"复制"数据的,但原理和内存消耗完全不同。
expand:零拷贝的扩展(推荐)
expand 只能在大小为 1 的维度上扩展,而且它是零拷贝的——不实际复制数据,只是修改 stride 让同一块内存被"访问多次"。
x = torch.tensor([[1], [2], [3]]) # shape: (3, 1)
# 把大小为 1 的维度扩展为 4
y = x.expand(3, 4)
print(y)
# tensor([[1, 1, 1, 1],
# [2, 2, 2, 2],
# [3, 3, 3, 3]])
print(y.shape) # torch.Size([3, 4])
# -1 表示这个维度保持不变
z = x.expand(-1, 4) # 等价于 expand(3, 4)
print(z.shape) # torch.Size([3, 4])
# 在 attention 里常见的用法:
# 把 (1, T, D) 的 position embedding 扩展到 batch 里每个样本
pos_emb = torch.randn(1, 512, 768) # (1, T, D)
pos_emb_batch = pos_emb.expand(8, -1, -1) # (B, T, D) = (8, 512, 768)
print(pos_emb_batch.shape) # torch.Size([8, 512, 768])
# 实际上 8 个样本共享同一块内存,没有数据被复制!
repeat:真实的数据复制
repeat 接受每个维度重复的次数,会真实地复制数据:
x = torch.tensor([[1, 2], [3, 4]]) # (2, 2)
# repeat(2, 3):第 0 维重复 2 次,第 1 维重复 3 次
y = x.repeat(2, 3)
print(y)
# tensor([[1, 2, 1, 2, 1, 2],
# [3, 4, 3, 4, 3, 4],
# [1, 2, 1, 2, 1, 2],
# [3, 4, 3, 4, 3, 4]])
print(y.shape) # torch.Size([4, 6])
实践建议:能用 expand 的地方不用 repeat,因为 expand 是零拷贝,节省内存也更快。repeat 在某些需要真实独立副本(比如需要分别修改每个副本)的场景才用。
第八步:广播(Broadcasting)——最容易踩坑也最强大的机制
广播是指:当两个形状不完全相同的 Tensor 进行运算时,PyTorch 会自动"扩展"较小的那个,使两者形状匹配,然后做运算。
广播遵循以下规则(从尾部维度开始对齐):
- 如果两个 Tensor 维度数不同,短的那个在最前面补 1,直到维度数相同;
- 对于每个维度,如果两者大小相同,直接匹配;如果一方是 1,则扩展到另一方的大小;如果两者都不是 1 且大小不同,报错。
# 最简单的例子
a = torch.tensor([1.0, 2.0, 3.0]) # shape: (3,)
b = torch.tensor(2.0) # shape: () —— 标量
print((a + b).shape) # torch.Size([3]) —— 标量被广播到每个元素
# 矩阵 + 向量
A = torch.randn(4, 3) # (4, 3)
v = torch.randn(3) # (3,)
# PyTorch 把 v 的形状从 (3,) 补为 (1, 3),再广播为 (4, 3)
print((A + v).shape) # torch.Size([4, 3])
# 三维的情况
x = torch.randn(8, 512, 768) # (B, T, D)
bias = torch.randn(768) # (D,)
# bias 形状:(768,) → 补为 (1, 1, 768) → 广播为 (8, 512, 768)
print((x + bias).shape) # torch.Size([8, 512, 768])
广播的实际应用:attention mask
这是大模型里最典型的广播场景,理解它可以让你读懂所有 Transformer 源码里的 mask 操作。
# attention scores 的形状:(B, H, T_q, T_k)
attn_scores = torch.randn(8, 12, 512, 512) # (B, H, T, T)
# padding mask 的形状:(B, T_k) —— 标记哪些位置是填充,应当被忽略
padding_mask = torch.zeros(8, 512) # 0 表示有效位置,1 表示填充位置
padding_mask[:, 300:] = 1 # 假设 300 之后都是填充
# 想把 mask 加到 attn_scores 上,需要让形状对齐
# 目标:(B, T_k) → (B, 1, 1, T_k),这样可以广播到 (B, H, T_q, T_k)
mask_4d = padding_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T_k)
print(mask_4d.shape) # torch.Size([8, 1, 1, 512])
# 广播加法:(8, 12, 512, 512) + (8, 1, 1, 512) → (8, 12, 512, 512)
attn_scores = attn_scores + mask_4d * (-1e9) # 填充位置加一个极大负值,softmax 后趋近于 0
print(attn_scores.shape) # torch.Size([8, 12, 512, 512])
这里的广播逻辑:(8, 1, 1, 512) 中的两个 1 在计算时会自动扩展为 12 和 512,也就是对所有 head(H)和所有 query 位置(T_q)施加相同的 mask。
广播和形状不匹配报错的区别
a = torch.randn(3, 4)
b = torch.randn(3, 5)
# a + b 会报错吗?
# 从尾部对齐:4 vs 5,两者都不是 1 且不相等 → RuntimeError!
a = torch.randn(3, 1)
b = torch.randn(3, 5)
# 从尾部对齐:1 vs 5,1 可以广播为 5 → OK,结果是 (3, 5)
print((a + b).shape) # torch.Size([3, 5])
第九步:矩阵乘法家族——@ 、matmul、bmm、einsum
矩阵乘法是深度学习里最核心的计算,但它在不同维度下有多种写法,每种都有自己的适用场景。
基础:二维矩阵乘法
A = torch.randn(3, 4) # (3, 4)
B = torch.randn(4, 5) # (4, 5)
# 三种等价写法
C1 = torch.mm(A, B) # (3, 5) —— 只支持二维
C2 = torch.matmul(A, B) # (3, 5) —— 支持任意维度
C3 = A @ B # (3, 5) —— @ 是 matmul 的语法糖
print(C1.shape) # torch.Size([3, 5])
带 batch 维度的矩阵乘法:bmm
bmm(Batch Matrix Multiplication)处理三维输入,第一维是 batch 维度,对 batch 里的每个矩阵分别做乘法:
A = torch.randn(8, 3, 4) # (B, 3, 4)
B = torch.randn(8, 4, 5) # (B, 4, 5)
C = torch.bmm(A, B) # (B, 3, 5) = (8, 3, 5)
print(C.shape) # torch.Size([8, 3, 5])
# 等价于:对 batch 里每个样本,独立做 (3,4) × (4,5) = (3,5)
matmul 的广播能力
matmul(以及 @)在处理高维 Tensor 时,除了最后两维做矩阵乘法,前面的维度都做广播:
# 四维的情况:(B, H, T, D) × (B, H, D, T) → (B, H, T, T)
# 这正是 attention 里计算 Q×K^T 的操作
Q = torch.randn(8, 12, 512, 64) # (B, H, T, D_head)
K = torch.randn(8, 12, 512, 64) # (B, H, T, D_head)
# K 需要转置最后两维才能做乘法
scores = Q @ K.transpose(-2, -1) # (B, H, T, D_head) × (B, H, D_head, T) → (B, H, T, T)
print(scores.shape) # torch.Size([8, 12, 512, 512])
# 注意 transpose(-2, -1) 是转置最后两个维度,等价于 transpose(2, 3)
# 使用负数索引更通用,不需要知道总共有几维
# 广播的例子:如果 K 只有 (H, T, D_head),没有 batch 维
K_no_batch = torch.randn(12, 512, 64) # (H, T, D_head)
# matmul 会自动广播:(8, 12, 512, 64) × (12, 64, 512) → (8, 12, 512, 512)
scores2 = Q @ K_no_batch.transpose(-2, -1)
print(scores2.shape) # torch.Size([8, 12, 512, 512])
einsum:最强大也最难读懂的操作
einsum 是 Einstein Summation(爱因斯坦求和约定)的实现,可以用一个字符串公式表达几乎任何张量操作。在大型模型的源码(尤其是研究代码)里非常常见。
语法规则:
torch.einsum("输入下标->输出下标", tensor1, tensor2, ...)
- 每个字母代表一个维度
- 出现在左边但不出现在右边的维度会被求和(收缩)
- 相同字母在不同 tensor 里代表相同的维度,做逐元素运算
从简单到复杂,逐步讲:
# 1. 向量点积:(n,) × (n,) → 标量
a = torch.randn(4)
b = torch.randn(4)
dot = torch.einsum('i,i->', a, b)
print(dot.shape) # torch.Size([]) —— 标量
# 2. 外积:(m,) × (n,) → (m, n)
outer = torch.einsum('i,j->ij', a, b)
print(outer.shape) # torch.Size([4, 4])
# 3. 矩阵乘法:(m, k) × (k, n) → (m, n)
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum('ik,kj->ij', A, B)
print(C.shape) # torch.Size([3, 5])
# k 在右边没有出现 → k 维度被求和(这就是矩阵乘法的"对齐收缩")
# 4. batch 矩阵乘法:(B, m, k) × (B, k, n) → (B, m, n)
A = torch.randn(8, 3, 4)
B = torch.randn(8, 4, 5)
C = torch.einsum('bik,bkj->bij', A, B) # b 是 batch,共享不求和
print(C.shape) # torch.Size([8, 3, 5])
# 5. 转置:(m, n) → (n, m)
A = torch.randn(3, 4)
B = torch.einsum('ij->ji', A)
print(B.shape) # torch.Size([4, 3])
# 6. 逐行求和:(m, n) → (m,)
A = torch.randn(3, 4)
row_sum = torch.einsum('ij->i', A)
print(row_sum.shape) # torch.Size([3])
# 7. trace(矩阵的迹,对角线元素之和):(n, n) → 标量
A = torch.randn(4, 4)
trace = torch.einsum('ii->', A)
print(trace.shape) # torch.Size([])
# 8. attention 的 Q×K^T:(B, H, T, D) × (B, H, T, D) → (B, H, T, T)
Q = torch.randn(8, 12, 512, 64)
K = torch.randn(8, 12, 512, 64)
scores = torch.einsum('bhid,bhjd->bhij', Q, K)
# b: batch,h: head,i: query position,j: key position,d: head dim
# d 在右边没有出现 → d 维度被求和(这正是 Q 和 K 的点积)
print(scores.shape) # torch.Size([8, 12, 512, 512])
# 9. attention output 的加权求和:(B, H, T, T) × (B, H, T, D) → (B, H, T, D)
attn_weights = torch.softmax(scores, dim=-1)
V = torch.randn(8, 12, 512, 64)
out = torch.einsum('bhij,bhjd->bhid', attn_weights, V)
# i: query position,j: key/value position(被求和)
print(out.shape) # torch.Size([8, 12, 512, 64])
einsum 的优势在于,公式本身就是对维度语义的注释——你一眼就能看出哪些维度在参与什么操作,哪些被求和。它的劣势是初次看到很难读懂,需要一些练习。
第十步:其他常用操作
除了上面的主角,模型源码里还有一些频繁出现的操作需要认识。
cat 和 stack:Tensor 的拼接
a = torch.randn(8, 512, 256)
b = torch.randn(8, 512, 256)
# cat:沿已有维度拼接,不增加新维度
# 沿最后一维拼接,把两个 256 维特征合并为 512 维
c = torch.cat([a, b], dim=-1)
print(c.shape) # torch.Size([8, 512, 512])
# 沿第 0 维拼接,把两个 batch 合并
d = torch.cat([a, b], dim=0)
print(d.shape) # torch.Size([16, 512, 256])
# stack:沿新维度堆叠,会增加一个新维度
e = torch.stack([a, b], dim=0) # 在第 0 维新建一个维度
print(e.shape) # torch.Size([2, 8, 512, 256])
e2 = torch.stack([a, b], dim=1) # 在第 1 维新建
print(e2.shape) # torch.Size([8, 2, 512, 256])
cat vs stack 的区别:cat 要求被拼接的 Tensor 在拼接维度之外形状相同;stack 要求所有 Tensor 形状完全相同,结果会多一个维度。
chunk 和 split:Tensor 的切分
x = torch.randn(8, 512, 768)
# chunk(n, dim):沿 dim 维均匀切成 n 份
parts = x.chunk(3, dim=-1) # 沿最后一维切成 3 份(768 / 3 = 256)
print(len(parts)) # 3
print(parts[0].shape) # torch.Size([8, 512, 256])
# split(size, dim):沿 dim 维按指定大小切
# size 是一个整数(等分)或 list(不等分)
parts2 = x.split(256, dim=-1) # 每份大小为 256
print(len(parts2)) # 3
print(parts2[0].shape) # torch.Size([8, 512, 256])
# 不等分
parts3 = x.split([256, 256, 256], dim=-1) # 等价于上面
parts4 = x.split([256, 512], dim=-1) # 切成 256 和 512 两份
print(parts4[0].shape) # torch.Size([8, 512, 256])
print(parts4[1].shape) # torch.Size([8, 512, 512])
在 Transformer 的 QKV 计算里,split 非常常见:
# 一个线性层同时投影 Q、K、V
qkv_proj = torch.nn.Linear(768, 768 * 3)
x = torch.randn(8, 512, 768)
qkv = qkv_proj(x) # (B, T, 3*D) = (8, 512, 2304)
# 切分成 Q、K、V
q, k, v = qkv.split(768, dim=-1)
# q: (8, 512, 768), k: (8, 512, 768), v: (8, 512, 768)
gather 和 index_select:按索引选取元素
x = torch.randn(8, 512, 768)
# index_select(dim, index):沿指定维度选取指定索引的切片
indices = torch.tensor([0, 5, 10, 100]) # 选第 1、6、11、101 个 token
selected = x.index_select(1, indices) # (B, 4, D)
print(selected.shape) # torch.Size([8, 4, 768])
# gather(dim, index):更灵活的按索引选取,index 和 input 形状相同
# 常用于按照模型输出的类别 index 选取对应的 logit
logits = torch.randn(8, 512, 30000) # (B, T, vocab_size)
label_ids = torch.randint(0, 30000, (8, 512, 1)) # (B, T, 1)
selected_logits = logits.gather(-1, label_ids) # (B, T, 1),选出每个位置真实 label 的 logit
print(selected_logits.shape) # torch.Size([8, 512, 1])
masked_fill:按条件填充值
这在 attention mask 里极为常见:
x = torch.randn(8, 12, 512, 512) # attention scores
# 创建一个上三角 mask(用于 causal attention,防止 token 看到未来)
# triu 返回上三角矩阵
causal_mask = torch.triu(torch.ones(512, 512), diagonal=1).bool()
# causal_mask[i, j] = True 表示位置 i 不应该 attend 到位置 j(j > i)
# masked_fill(mask, value):mask 为 True 的位置填充 value
x_masked = x.masked_fill(causal_mask, float('-inf'))
# softmax 之后 -inf 位置的权重会变成 0
print(x_masked.shape) # torch.Size([8, 12, 512, 512])
flip 和 roll:翻转和滚动
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# flip:沿指定维度翻转顺序
print(torch.flip(x, dims=[0])) # 沿第 0 维翻转(行顺序颠倒)
# tensor([[4, 5, 6],
# [1, 2, 3]])
print(torch.flip(x, dims=[1])) # 沿第 1 维翻转(列顺序颠倒)
# tensor([[3, 2, 1],
# [6, 5, 4]])
# roll:滚动,把末尾元素移到开头(或反之)
print(torch.roll(x, shifts=1, dims=1))
# tensor([[3, 1, 2],
# [6, 4, 5]]) —— 每行向右滚动 1 位
repeat_interleave:交错重复
x = torch.tensor([1, 2, 3])
print(torch.repeat_interleave(x, 3))
# tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]) —— 每个元素重复 3 次
# 在 GQA(Grouped Query Attention)里常用:
# K、V 的 head 数量比 Q 少,需要把 K、V 的 head 复制几份
k = torch.randn(8, 4, 512, 64) # (B, num_kv_heads=4, T, D)
k_expanded = k.repeat_interleave(3, dim=1) # 每个 kv_head 重复 3 次
print(k_expanded.shape) # torch.Size([8, 12, 512, 64]) —— 和 Q 的 12 个 head 对应
第十一步:einops——让维度操作变得可读
如果你读过 ViT、DiT 或者一些研究型代码,会见到 einops 这个库。它提供了一套用字符串描述维度操作的接口,可读性极高。
# pip install einops
from einops import rearrange, repeat, reduce
x = torch.randn(8, 512, 768) # (B, T, D)
# rearrange:等价于 permute + reshape 的组合,但更易读
# 例 1:把 (B, T, H*D) 拆成 (B, H, T, D)
y = rearrange(x, 'b t (h d) -> b h t d', h=12)
print(y.shape) # torch.Size([8, 12, 512, 64])
# 例 2:把 (B, H, T, D) 合并回 (B, T, H*D)
z = rearrange(y, 'b h t d -> b t (h d)')
print(z.shape) # torch.Size([8, 512, 768])
# 例 3:把图像的 (B, C, H, W) 转成 patch 序列 (B, N, P*P*C)
img = torch.randn(8, 3, 224, 224) # (B, C, H, W)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
print(patches.shape) # torch.Size([8, 196, 768]) —— 224/16=14, 14*14=196, 16*16*3=768
# repeat:带命名的扩展,比 expand 更清晰
pos_emb = torch.randn(1, 512, 768) # (1, T, D)
pos_batch = repeat(pos_emb, '1 t d -> b t d', b=8)
print(pos_batch.shape) # torch.Size([8, 512, 768])
# reduce:带命名的 reduction(mean、max、sum 等)
x = torch.randn(8, 512, 768)
x_mean = reduce(x, 'b t d -> b d', 'mean') # 沿 T 维做平均池化
print(x_mean.shape) # torch.Size([8, 768])
einops 的公式是自文档的:'b t (h d) -> b h t d' 直接告诉你输入是 (B, T, H*D),输出是 (B, H, T, D)。这比 x.reshape(B, T, 12, 64).permute(0, 2, 1, 3) 可读性高很多,在合作项目和代码审查中特别有价值。
最终大战:Multi-Head Attention 的完整维度变换流程
好,所有工具都介绍完了。现在我们用它们来完整走一遍 Multi-Head Attention 的维度变换,这是大型模型里维度操作最密集、也最考验理解的地方。
用的设置:B=8(batch),T=512(sequence length),D=768(embedding dim),H=12(heads),D_h=64(head dim,768/12=64)。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 输入
B, T, D, H = 8, 512, 768, 12
D_h = D // H # 64,每个 head 的维度
x = torch.randn(B, T, D) # (8, 512, 768)
print(f"输入: {x.shape}") # (B, T, D)
# ------- 步骤 1:线性投影,生成 Q、K、V -------
W_q = nn.Linear(D, D, bias=False)
W_k = nn.Linear(D, D, bias=False)
W_v = nn.Linear(D, D, bias=False)
Q = W_q(x) # (B, T, D) = (8, 512, 768)
K = W_k(x) # (B, T, D)
V = W_v(x) # (B, T, D)
print(f"Q K V 投影后: {Q.shape}") # (B, T, D)
# ------- 步骤 2:把 D 维拆成 H 个 head -------
# (B, T, D) → (B, T, H, D_h)
Q = Q.reshape(B, T, H, D_h) # (8, 512, 12, 64)
K = K.reshape(B, T, H, D_h)
V = V.reshape(B, T, H, D_h)
print(f"拆 head 后: {Q.shape}") # (B, T, H, D_h)
# ------- 步骤 3:把 H 维挪到 T 前面,方便做注意力 -------
# 注意力的计算是"每个 head,对序列做 attention"
# 需要形状 (B, H, T, D_h),让 H 和 T 分离
# (B, T, H, D_h) → (B, H, T, D_h)
Q = Q.permute(0, 2, 1, 3) # 等价于 Q.transpose(1, 2),但 permute 更显式
K = K.permute(0, 2, 1, 3)
V = V.permute(0, 2, 1, 3)
print(f"permute 后: {Q.shape}") # (B, H, T, D_h)
# 注意:permute 之后内存不连续,如果后续要用 view,需要先 contiguous()
# .contiguous() 在这里不是必须的,因为下面的 matmul 可以处理不连续 Tensor
# 但在某些框架实现里你会看到显式加上
# ------- 步骤 4:计算注意力分数 Q × K^T -------
# (B, H, T, D_h) × (B, H, D_h, T) → (B, H, T, T)
# K.transpose(-2, -1):交换最后两个维度 (B, H, T, D_h) → (B, H, D_h, T)
scale = math.sqrt(D_h)
attn_scores = Q @ K.transpose(-2, -1) / scale
print(f"attention scores: {attn_scores.shape}") # (B, H, T, T)
# ------- 步骤 5:加入 causal mask(decoder 中防止看到未来)-------
# 上三角 mask:位置 i 不能看到 j > i 的位置
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool() # (T, T)
# 需要广播到 (B, H, T, T),用 unsqueeze 或者直接利用广播
# (T, T) → 广播匹配 (B, H, T, T),前面的维度自动补 1
attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
print(f"加 mask 后: {attn_scores.shape}") # (B, H, T, T)
# ------- 步骤 6:softmax + dropout -------
attn_weights = F.softmax(attn_scores, dim=-1) # (B, H, T, T)
attn_weights = F.dropout(attn_weights, p=0.1)
print(f"attention weights: {attn_weights.shape}") # (B, H, T, T)
# ------- 步骤 7:加权求和 weights × V -------
# (B, H, T, T) × (B, H, T, D_h) → (B, H, T, D_h)
context = attn_weights @ V
print(f"context 向量: {context.shape}") # (B, H, T, D_h)
# ------- 步骤 8:把 H 个 head 的结果合并回来 -------
# (B, H, T, D_h) → (B, T, H, D_h) → (B, T, H*D_h) = (B, T, D)
# 先把 H 维换回到 T 后面
context = context.permute(0, 2, 1, 3) # (B, T, H, D_h)
print(f"permute 回来: {context.shape}") # (B, T, H, D_h)
# 关键:permute 之后不连续,必须 contiguous() 才能 view()
context = context.contiguous()
print(f"is_contiguous: {context.is_contiguous()}") # True
# 合并 H 和 D_h 维度
context = context.view(B, T, D) # (B, T, H*D_h) = (B, T, 768)
print(f"合并 head 后: {context.shape}") # (B, T, D)
# ------- 步骤 9:最终输出线性层 -------
W_o = nn.Linear(D, D)
output = W_o(context) # (B, T, D)
print(f"最终输出: {output.shape}") # (B, T, D)
把全部形状变换汇总成一张清单:
输入 (B, T, D) = (8, 512, 768)
↓ W_q/W_k/W_v 线性投影
Q/K/V (B, T, D) = (8, 512, 768)
↓ reshape(B, T, H, D_h)
Q/K/V 拆 head (B, T, H, D_h) = (8, 512, 12, 64)
↓ permute(0, 2, 1, 3)
Q/K/V H 移前 (B, H, T, D_h) = (8, 12, 512, 64)
↓ Q @ K.transpose(-2,-1)
attention scores (B, H, T, T) = (8, 12, 512, 512)
↓ masked_fill + softmax
attention weights (B, H, T, T) = (8, 12, 512, 512)
↓ weights @ V
context (B, H, T, D_h) = (8, 12, 512, 64)
↓ permute(0, 2, 1, 3)
context H 移后 (B, T, H, D_h) = (8, 512, 12, 64)
↓ .contiguous().view(B, T, D)
context 合并 head (B, T, D) = (8, 512, 768)
↓ W_o 线性投影
输出 (B, T, D) = (8, 512, 768)
整个流程里,没有任何一行是在做复杂的数学运算——全是形状的折叠、展开、调序。这就是为什么说"理解维度变换就是理解 Transformer 的一半"。
常见错误速查
最后汇总几个最容易出现的维度相关错误:
错误 1:transpose 后直接 view
# 错误
x = torch.randn(4, 6).transpose(0, 1).view(24) # RuntimeError
# 正确
x = torch.randn(4, 6).transpose(0, 1).contiguous().view(24) # OK
# 或者
x = torch.randn(4, 6).transpose(0, 1).reshape(24) # OK
错误 2:矩阵乘法维度不匹配
A = torch.randn(3, 4)
B = torch.randn(3, 4)
# A @ B 会报错,因为 4 ≠ 3(第一个 tensor 的最后一维必须等于第二个的倒数第二维)
C = A @ B.T # 正确:(3,4) × (4,3) → (3,3)
错误 3:广播方向搞反
a = torch.randn(3, 1)
b = torch.randn(1, 4)
# 广播:(3,1) 和 (1,4) → (3,4),OK
a = torch.randn(3)
b = torch.randn(4, 3)
# a 的形状 (3,) 从尾部对齐 → 相当于 (1, 3),广播为 (4, 3) → OK
a = torch.randn(3)
b = torch.randn(3, 4)
# a 的形状 (3,) 从尾部对齐,第 -1 维:3 vs 4,不能广播 → RuntimeError
错误 4:squeeze 了不应该 squeeze 的维度
x = torch.randn(1, 32, 1, 64)
y = x.squeeze() # 把所有大小为 1 的维度都去了,得到 (32, 64)
# 如果你只想去掉第 0 维,用
y = x.squeeze(0) # (32, 1, 64) —— 只去掉第 0 维
错误 5:期望独立拷贝但拿到了视图
a = torch.randn(4, 4)
b = a.reshape(2, 8) # b 是 a 的视图,共享内存
b[0, 0] = 999
print(a[0, 0]) # 999!a 也变了
# 如果需要独立拷贝
b = a.reshape(2, 8).clone() # clone() 产生真正的独立拷贝
b[0, 0] = 999
print(a[0, 0]) # 不变
小结
这一篇覆盖了所有你在大型模型源码里会遇到的维度变换操作:
- 形状查询:
.shape、.ndim、解包 - 改变形状:
reshape、view(以及contiguous()的必要性) - 增删维度:
unsqueeze、squeeze、None索引 - 重排轴顺序:
transpose、permute(以及内存不连续问题) - 合并与拆分维度:
flatten、unflatten - 复制扩展:
expand(零拷贝)、repeat(真实拷贝)、repeat_interleave - 广播:规则、典型应用(mask)、常见报错
- 矩阵乘法家族:
@、matmul、bmm、einsum - 其他高频操作:
cat、stack、split、chunk、gather、masked_fill - einops:更可读的维度操作库
最后用一个完整的 Multi-Head Attention 把这些操作串联起来——你现在应该能看懂那段维度变换流程里每一行在做什么,以及为什么这么做了。
下一篇,我们进入 PyTorch 最核心的机制:自动微分。我们会用一个最简单的线性回归,把梯度是怎么流动的这件事彻底讲清楚。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)