从零开始写Qwen3(五-其一)使用Triton实现自注意力
概述
自注意力是大模型的核心组件,也是计算最密集的两个部分之一(另一个是注意力后的MLP,这也是参数最多的部分)
第五章将依次用Triton实现基础版本和FlashAttn,本节先来实现基础版本
自注意力原理
用公式表示就是
O = softmax ( mask ( Q K ⊤ / D ) ) V O=\text{softmax}\left(\text{mask}(QK^\top/\sqrt{D})\right)V O=softmax(mask(QK⊤/D))V
根据公式可以很容易写出torch版本
batch_size, head_q, seq_len_q, head_dim = q.shape
head_kv = k.shape[1]
# 计算缩放因子
scale = self.n_head_dim ** -0.5
# GQA: 将q的head分组,每组对应一个kv head
assert head_q % head_kv == 0, f"head_q ({head_q}) must be divisible by head_kv ({head_kv})"
n_groups = head_q // head_kv
# 重塑q为 [batch, head_kv, n_groups, seq_len_q, head_dim]
q_reshaped = q.reshape(batch_size, head_kv, n_groups, seq_len_q, head_dim)
# 扩展k和v的维度以匹配q的分组 [batch, head_kv, 1, seq_len_k, head_dim]
k_expanded = k.unsqueeze(2)
v_expanded = v.unsqueeze(2)
# 计算注意力分数: [batch, head_kv, n_groups, seq_len_q, seq_len_k]
scores = torch.matmul(q_reshaped, k_expanded.transpose(-2, -1)) * scale
# 应用因果掩码
seq_len_k = k.shape[2]
if seq_len_q > 1 and is_causal:
# 创建因果掩码
causal_mask = torch.tril(
torch.ones(seq_len_q, seq_len_k, device=q.device, dtype=torch.bool)
)
scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), float('-inf'))
# Softmax
attn_weights = torch.softmax(scores, dim=-1)
# 计算输出: [batch, head_kv, n_groups, seq_len_q, head_dim]
out = torch.matmul(attn_weights, v_expanded)
# 重塑回原始形状 [batch, head_q, seq_len_q, head_dim]
out = out.reshape(batch_size, head_q, seq_len_q, head_dim)
要实现它的GPU算子,最简单的就是像上面这样分成多步:
- QK矩阵乘法,除以sqrt(D)可以放在一起
- 遮罩和softmax得到注意力权重矩阵
- 注意力权重和V矩阵乘法得到结果
矩阵乘法的Triton实现
基础版
二维分块,(y,x)块负责计算(y,x)的结果,然后对D维度进行循环乘积
问题:
- 乘法中的其中一个是行顺序,一个是列顺序,列顺序这个无法合并读取
- 存在大量重复读取,具体分析如下
考虑
O i j = ∑ k A i k B k j O_{ij}=\sum_k A_{ik} B_{kj} Oij=k∑AikBkj
有
O 00 = ∑ k A 0 k B k 0 , O 01 = ∑ k A 0 k B k 1 , O 10 = ∑ k A 1 k B k 0 O_{00}=\sum_k A_{0k}B_{k0}, O_{01}=\sum_k A_{0k}B_{k1}, O_{10}=\sum_k A_{1k}B_{k0} O00=k∑A0kBk0,O01=k∑A0kBk1,O10=k∑A1kBk0
也就是相邻之间线程这些读取都是重复的,产生了很多浪费
改进-共享内存
使用共享内存,实现两件事情:
- 把列遍历转为行遍历,实现合并读写
- 复用全局内存读取
现在一个线程块读取A的BMxBK的数据,B的BKxBN的数据,然后沿着K方向遍历进行子矩阵矩阵乘法并求和,最后输出BMxBN的区域。也就是变成分块矩阵乘法
下面是CUTLASS的示意图,这个框架不仅使用分块矩阵乘法,还使用了三层,线程块一层,线程束一层,线程一层
Triton不用写这么多分层代码,简单实现就有
@triton.jit
def matrix_multiplication_kernel(
a,
b,
c,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
scale,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offsets_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_m = offsets_m[:, None] < M
mask_n = offsets_n[None, :] < N
offsets_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a + offsets_m[:, None] * stride_am + offsets_k[None, :] * stride_ak
b_ptrs = b + offsets_n[None, :] * stride_bn + offsets_k[:, None] * stride_bk
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in tl.range(0, K, BLOCK_SIZE_K):
block_a = tl.load(a_ptrs, (offsets_k[None, :] < K - k) & mask_m, 0.0)
block_b = tl.load(b_ptrs, (offsets_k[:, None] < K - k) & mask_n, 0.0)
acc = tl.dot(block_a, block_b, acc, input_precision="ieee")
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
tl.store(
c + offsets_m[:, None] * stride_cm + offsets_n[None, :] * stride_cn,
acc * scale,
mask_m & mask_n,
)
有几个要点说明:
- offsets_m和offsets_n对M和N取余,这是从Triton官网示例看过来的。也可以不取余,改成在mask上增加条件
- 因为最终写入的 c i j = a i k b k j c_{ij}=a_{ik}b_{kj} cij=aikbkj,对ij进行了限制,所以超出限制的ij即使参与了读取,参与了计算,不管它用的是什么值,反正最后不写入,所以可以使用模M、N的方式读取越界值
- tl.dot在支持张量核的显卡上会使用张量核,这是专门用来加速矩阵乘法的机器指令,效率很高
- 至少在3060的安培架构上,tl.dot会使用TF这种浮点数格式来计算,可能产生些精度损失,如果不想精度损失,可以指定
input_precision="ieee",但是速度会慢很多
- 至少在3060的安培架构上,tl.dot会使用TF这种浮点数格式来计算,可能产生些精度损失,如果不想精度损失,可以指定
免转置矩阵乘法
在自注意力中可以看到,矩阵乘法中K是要转置的,QKV输入过来都是BxSxD的格式,不转置不能使用上面的计算
虽然torch的转置并不会真的改动物理内存,只会更改stride,但是修改后K的行就不再是主方向,此时原本假设的行优先的合并读取的优化全部失效,反而会产生错误,所以通常许多地方都会要求张量内存是连续的,x.stride(-1)==1
可以不转置,单独写一个矩阵乘法算子适配这种情况
@triton.jit
def matrix_multiplication_kernel_with_transpose(
a,
b,
c,
M,
N,
K,
stride_am,
stride_ak,
stride_bn,
stride_bk,
stride_cm,
stride_cn,
scale,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offsets_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None]
ori_offsets_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offsets_n = ori_offsets_n[:, None]
mask_m = offsets_m < M
mask_n = offsets_n < N
offsets_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a + offsets_m * stride_am + offsets_k[None, :] * stride_ak
b_ptrs = b + offsets_n * stride_bn + offsets_k[None, :] * stride_bk
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in tl.range(0, K, BLOCK_SIZE_K):
block_a = tl.load(a_ptrs, (offsets_k[None, :] < K - k) & mask_m, 0.0)
block_b = tl.load(b_ptrs, (offsets_k[None, :] < K - k) & mask_n, 0.0)
acc = tl.dot(block_a, block_b.T, acc, input_precision="ieee")
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
tl.store(
c + offsets_m * stride_cm + ori_offsets_n[None, :] * stride_cn,
acc * scale,
mask_m & mask_n,
)
在不转置的情况下,a读取BMxBK的大小,b读取BNxBK的大小,矩阵乘法的时候还是需要转置一下,不过这里使用共享内存处理好了,张量核能直接使用
softmax的triton实现
softmax可以让一个线程块处理一行,做这些事情
- 找到本行的最大值
- 所有值减去最大值
- 取指数
- 求和得到分母
- 除以分母
一个线程块处理完整一行正好,如果多个线程块处理一行还需要启动多次核函数,开销就大了
使用triton实现
@triton.jit
def softmax(A, row_stride, col_stride, M, N, scale, BLOCK_SIZE: tl.constexpr):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, M, row_step):
row_start_ptr = A + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptr = row_start_ptr + col_offsets * col_stride
mask = col_offsets < N
row = tl.load(input_ptr, mask=mask, other=-float("inf"))
# row = tl.load(input_ptr, mask=mask & col_offsets <= row_idx,
# other=-float("inf"))
max_val = tl.max(row)
row = row - max_val
row = tl.exp(row)
sum_val = tl.sum(row)
row = row / sum_val
tl.store(input_ptr, row, mask=mask)
这里让每个线程块多处理了几行
求最大值、求和这些线程块内同步triton都帮我们做好了,不需要自己写
softmax中需要求最大值,然后分子分母的指数中同减去最大值,结果和减之前一样。从数学角度上看没有减的必要,但是在计算机实际计算中如果不减,某个激活值可能较大,指数后就爆炸了,实际计算需要防止这种情况发生,非常小的值则不用管,指数后也只会变成0
然后如果要使用因果遮罩,可以在条件中加上列号小于等于行号,这样把对角线以上的部分全部填充为负无穷,在指数后变成0
row = tl.load(input_ptr, mask=mask & col_offsets <= row_idx,
other=-float("inf"))
串起来
三个核函数,其中第一个免转置矩阵乘法核函数需要开辟 M × N M\times N M×N 的注意力权重矩阵,第二个softmax原地修改注意力权重并增加mask,第三个普通矩阵乘法
多头自注意力
多头自注意力就是把QKV原本的D维度向量拆分为多个头部,然后每个头部分别进行自注意力计算,最后拼接起来
Q = [ q 1 q 2 ⋮ q m ] , K = [ k 1 k 2 ⋮ k n ] , V = [ v 1 v 2 ⋮ v n ] Q=\left[\begin{matrix}q_1\\q_2 \\\vdots\\q_m\end{matrix}\right] ,K=\left[\begin{matrix}k_1\\k_2\\\vdots\\k_n \end{matrix}\right], V=\left[\begin{matrix}v_1\\v_2\\\vdots\\v_n \end{matrix}\right] Q= q1q2⋮qm ,K= k1k2⋮kn ,V= v1v2⋮vn
得到结果
O = [ o 1 o 2 ⋮ o m ] , o i = softmax ( mask ( q i k i ⊤ ) ) v i O=\left[\begin{matrix}o_1\\o_2 \\\vdots\\o_m\end{matrix}\right] , o_i =\text{softmax}(\text{mask}(q_i k_i^\top))v_i O= o1o2⋮om ,oi=softmax(mask(qiki⊤))vi
多头可以完全并行,可以和Batch合并到一起并行
于是多头自注意力的计算相比普通版本只需要给input和output增加一个偏移
pid_h = tl.program_id(0) # BxH ,放到一个编号中
input_ptr = input + pid_h * stride_ih
output_ptr = output + pid_h * stride_oh
BH放到同一个编号中,是因为矩阵乘法要用到两个网格号,CUDA上最多能用xyz三个网格号,就只能把BH压缩到一起,反正也是完全并行的
stride的作用
实际上多头注意力的计算是要经过转置的:
N × D → H × N × D H N\times D\to H\times N\times \frac{D}{H} N×D→H×N×HD
但计算并不用特殊处理这一点,可以直接在计算偏移的时候乘上stride就行,普通的计算也需要乘以这个stride
具体来说,原来的 ( n , d ) (n,d) (n,d) 对应的内存地址偏移是
n S n + d S d n S_n + d S_d nSn+dSd
一般 S n = D , S d = 1 S_n=D,S_d=1 Sn=D,Sd=1
现在在转置之前,把D切分为H份,是
( n , h , d ) → n S n + h S d + d S d (n,h,d)\to nS_n + hS_d + dS_d (n,h,d)→nSn+hSd+dSd
此时有
S n = D , S h = D H , S d = 1 S_n = D,S_h = \frac{D}{H},S_d=1 Sn=D,Sh=HD,Sd=1
现在转置只不过改变了坐标的顺序,偏移是完全不变的
( h , n , d ) → h S d + n S n + d S d (h,n,d)\to hS_d +nS_n + dS_d (h,n,d)→hSd+nSn+dSd
总结一点就是,只要不是最后一个维度发生转置,就可以直接调换stride,不用做任何其他处理
在torch中,transpose其实就是通过调换stride来实现的,所以调用这个算子之前先要把数据转置到BHSD的形式,这个过程没有大量计算和访存,只修改了几个stride
另外 expand 增加维度也是通过增加stride 实现的,除非在最后一维增加,否则都是 stride=0,在这个维度上的遍历实际上是访问完全相同过的内容
a = torch.arange(0, 12).view(3,4)
b = torch.expand(2, -1, -1) # b.stride() = 0,4,1
比如要访问b[1,2,3],地址是 1 *stride(0) + 2 *stride(1) + 3 * stride(2) = 1 * 0 +2 * 4 + 3 * 1 = 11,这个地址和 [0,2,3]完全一样。
另外如果是 exapnd(1,-1,-1),会发现 stride(0)实际上不是0,而是12,但这个维度只有索引0,所以不管stride是多少,地址偏移都是0
GQA
Qwen3内部使用的是GQA,把注意力头部分成多个组,在每个组内,KV一个头部共享Q的多个头部
比如Qwen3-0.6B,输入维度是1024,Q投影为2048维,KV投影为1024维,Q头部是16,KV头部是8,所以每个头部的维度都是相同的128,每个KV头被2个Q头部共享
实现这一点就在计算KV偏移的时候把tid_h整除共享数就行
pid_qh = tl.program_id(0) # BxH ,放到一个编号中
pid_kh = pid_qh // groups # groups = 16/8 = 2
KV 的编号向下整除 组数,不用分离 B和H,自然和Q在同一个B中
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐




所有评论(0)