从零开始写Qwen3目录

概述

自注意力是大模型的核心组件,也是计算最密集的两个部分之一(另一个是注意力后的MLP,这也是参数最多的部分)

第五章将依次用Triton实现基础版本和FlashAttn,本节先来实现基础版本

自注意力原理

用公式表示就是

O = softmax ( mask ( Q K ⊤ / D ) ) V O=\text{softmax}\left(\text{mask}(QK^\top/\sqrt{D})\right)V O=softmax(mask(QK/D ))V

根据公式可以很容易写出torch版本

batch_size, head_q, seq_len_q, head_dim = q.shape
head_kv = k.shape[1]

# 计算缩放因子
scale = self.n_head_dim ** -0.5

# GQA: 将q的head分组,每组对应一个kv head
assert head_q % head_kv == 0, f"head_q ({head_q}) must be divisible by head_kv ({head_kv})"
n_groups = head_q // head_kv

# 重塑q为 [batch, head_kv, n_groups, seq_len_q, head_dim]
q_reshaped = q.reshape(batch_size, head_kv, n_groups, seq_len_q, head_dim)

# 扩展k和v的维度以匹配q的分组 [batch, head_kv, 1, seq_len_k, head_dim]
k_expanded = k.unsqueeze(2)
v_expanded = v.unsqueeze(2)

# 计算注意力分数: [batch, head_kv, n_groups, seq_len_q, seq_len_k]
scores = torch.matmul(q_reshaped, k_expanded.transpose(-2, -1)) * scale

# 应用因果掩码
seq_len_k = k.shape[2]
if seq_len_q > 1 and is_causal:
    # 创建因果掩码
    causal_mask = torch.tril(
        torch.ones(seq_len_q, seq_len_k, device=q.device, dtype=torch.bool)
    )
    scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), float('-inf'))

# Softmax
attn_weights = torch.softmax(scores, dim=-1)

# 计算输出: [batch, head_kv, n_groups, seq_len_q, head_dim]
out = torch.matmul(attn_weights, v_expanded)

# 重塑回原始形状 [batch, head_q, seq_len_q, head_dim]
out = out.reshape(batch_size, head_q, seq_len_q, head_dim)

要实现它的GPU算子,最简单的就是像上面这样分成多步:

  • QK矩阵乘法,除以sqrt(D)可以放在一起
  • 遮罩和softmax得到注意力权重矩阵
  • 注意力权重和V矩阵乘法得到结果

矩阵乘法的Triton实现

基础版

二维分块,(y,x)块负责计算(y,x)的结果,然后对D维度进行循环乘积

问题:

  1. 乘法中的其中一个是行顺序,一个是列顺序,列顺序这个无法合并读取
  2. 存在大量重复读取,具体分析如下

考虑
O i j = ∑ k A i k B k j O_{ij}=\sum_k A_{ik} B_{kj} Oij=kAikBkj

O 00 = ∑ k A 0 k B k 0 , O 01 = ∑ k A 0 k B k 1 , O 10 = ∑ k A 1 k B k 0 O_{00}=\sum_k A_{0k}B_{k0}, O_{01}=\sum_k A_{0k}B_{k1}, O_{10}=\sum_k A_{1k}B_{k0} O00=kA0kBk0,O01=kA0kBk1,O10=kA1kBk0

也就是相邻之间线程这些读取都是重复的,产生了很多浪费

改进-共享内存

使用共享内存,实现两件事情:

  1. 把列遍历转为行遍历,实现合并读写
  2. 复用全局内存读取

现在一个线程块读取A的BMxBK的数据,B的BKxBN的数据,然后沿着K方向遍历进行子矩阵矩阵乘法并求和,最后输出BMxBN的区域。也就是变成分块矩阵乘法

下面是CUTLASS的示意图,这个框架不仅使用分块矩阵乘法,还使用了三层,线程块一层,线程束一层,线程一层
在这里插入图片描述
Triton不用写这么多分层代码,简单实现就有

@triton.jit
def matrix_multiplication_kernel(
    a,
    b,
    c,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    scale,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offsets_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_m = offsets_m[:, None] < M
    mask_n = offsets_n[None, :] < N
    offsets_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a + offsets_m[:, None] * stride_am + offsets_k[None, :] * stride_ak
    b_ptrs = b + offsets_n[None, :] * stride_bn + offsets_k[:, None] * stride_bk
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in tl.range(0, K, BLOCK_SIZE_K):
        block_a = tl.load(a_ptrs, (offsets_k[None, :] < K - k) & mask_m, 0.0)
        block_b = tl.load(b_ptrs, (offsets_k[:, None] < K - k) & mask_n, 0.0)
        acc = tl.dot(block_a, block_b, acc, input_precision="ieee")
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    tl.store(
        c + offsets_m[:, None] * stride_cm + offsets_n[None, :] * stride_cn,
        acc * scale,
        mask_m & mask_n,
    )

有几个要点说明:

  1. offsets_m和offsets_n对M和N取余,这是从Triton官网示例看过来的。也可以不取余,改成在mask上增加条件
    1. 因为最终写入的 c i j = a i k b k j c_{ij}=a_{ik}b_{kj} cij=aikbkj,对ij进行了限制,所以超出限制的ij即使参与了读取,参与了计算,不管它用的是什么值,反正最后不写入,所以可以使用模M、N的方式读取越界值
  2. tl.dot在支持张量核的显卡上会使用张量核,这是专门用来加速矩阵乘法的机器指令,效率很高
    1. 至少在3060的安培架构上,tl.dot会使用TF这种浮点数格式来计算,可能产生些精度损失,如果不想精度损失,可以指定input_precision="ieee" ,但是速度会慢很多

免转置矩阵乘法

在自注意力中可以看到,矩阵乘法中K是要转置的,QKV输入过来都是BxSxD的格式,不转置不能使用上面的计算

虽然torch的转置并不会真的改动物理内存,只会更改stride,但是修改后K的行就不再是主方向,此时原本假设的行优先的合并读取的优化全部失效,反而会产生错误,所以通常许多地方都会要求张量内存是连续的,x.stride(-1)==1

可以不转置,单独写一个矩阵乘法算子适配这种情况

@triton.jit
def matrix_multiplication_kernel_with_transpose(
    a,
    b,
    c,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    scale,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offsets_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None]
    ori_offsets_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offsets_n = ori_offsets_n[:, None]
    mask_m = offsets_m < M
    mask_n = offsets_n < N
    offsets_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a + offsets_m * stride_am + offsets_k[None, :] * stride_ak
    b_ptrs = b + offsets_n * stride_bn + offsets_k[None, :] * stride_bk
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in tl.range(0, K, BLOCK_SIZE_K):
        block_a = tl.load(a_ptrs, (offsets_k[None, :] < K - k) & mask_m, 0.0)
        block_b = tl.load(b_ptrs, (offsets_k[None, :] < K - k) & mask_n, 0.0)
        acc = tl.dot(block_a, block_b.T, acc, input_precision="ieee")
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    tl.store(
        c + offsets_m * stride_cm + ori_offsets_n[None, :] * stride_cn,
        acc * scale,
        mask_m & mask_n,
    )

在不转置的情况下,a读取BMxBK的大小,b读取BNxBK的大小,矩阵乘法的时候还是需要转置一下,不过这里使用共享内存处理好了,张量核能直接使用

softmax的triton实现

softmax可以让一个线程块处理一行,做这些事情

  • 找到本行的最大值
  • 所有值减去最大值
  • 取指数
  • 求和得到分母
  • 除以分母

一个线程块处理完整一行正好,如果多个线程块处理一行还需要启动多次核函数,开销就大了

使用triton实现

@triton.jit
def softmax(A, row_stride, col_stride, M, N, scale, BLOCK_SIZE: tl.constexpr):
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, M, row_step):
        row_start_ptr = A + row_idx * row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptr = row_start_ptr + col_offsets * col_stride
        mask = col_offsets < N
        row = tl.load(input_ptr, mask=mask, other=-float("inf"))
				# row = tl.load(input_ptr, mask=mask & col_offsets <= row_idx, 
				#	other=-float("inf"))
        max_val = tl.max(row)
        row = row - max_val
        row = tl.exp(row)
        sum_val = tl.sum(row)
        row = row / sum_val
        tl.store(input_ptr, row, mask=mask)

这里让每个线程块多处理了几行

求最大值、求和这些线程块内同步triton都帮我们做好了,不需要自己写

softmax中需要求最大值,然后分子分母的指数中同减去最大值,结果和减之前一样。从数学角度上看没有减的必要,但是在计算机实际计算中如果不减,某个激活值可能较大,指数后就爆炸了,实际计算需要防止这种情况发生,非常小的值则不用管,指数后也只会变成0

然后如果要使用因果遮罩,可以在条件中加上列号小于等于行号,这样把对角线以上的部分全部填充为负无穷,在指数后变成0

row = tl.load(input_ptr, mask=mask & col_offsets <= row_idx, 
	other=-float("inf"))

串起来

三个核函数,其中第一个免转置矩阵乘法核函数需要开辟 M × N M\times N M×N 的注意力权重矩阵,第二个softmax原地修改注意力权重并增加mask,第三个普通矩阵乘法

多头自注意力

多头自注意力就是把QKV原本的D维度向量拆分为多个头部,然后每个头部分别进行自注意力计算,最后拼接起来

Q = [ q 1 q 2 ⋮ q m ] , K = [ k 1 k 2 ⋮ k n ] , V = [ v 1 v 2 ⋮ v n ] Q=\left[\begin{matrix}q_1\\q_2 \\\vdots\\q_m\end{matrix}\right] ,K=\left[\begin{matrix}k_1\\k_2\\\vdots\\k_n \end{matrix}\right], V=\left[\begin{matrix}v_1\\v_2\\\vdots\\v_n \end{matrix}\right] Q= q1q2qm ,K= k1k2kn ,V= v1v2vn

得到结果

O = [ o 1 o 2 ⋮ o m ] , o i = softmax ( mask ( q i k i ⊤ ) ) v i O=\left[\begin{matrix}o_1\\o_2 \\\vdots\\o_m\end{matrix}\right] , o_i =\text{softmax}(\text{mask}(q_i k_i^\top))v_i O= o1o2om ,oi=softmax(mask(qiki))vi

多头可以完全并行,可以和Batch合并到一起并行

于是多头自注意力的计算相比普通版本只需要给input和output增加一个偏移

pid_h = tl.program_id(0)  # BxH ,放到一个编号中
input_ptr = input + pid_h * stride_ih
output_ptr = output + pid_h * stride_oh

BH放到同一个编号中,是因为矩阵乘法要用到两个网格号,CUDA上最多能用xyz三个网格号,就只能把BH压缩到一起,反正也是完全并行的

stride的作用

实际上多头注意力的计算是要经过转置的:

N × D → H × N × D H N\times D\to H\times N\times \frac{D}{H} N×DH×N×HD

但计算并不用特殊处理这一点,可以直接在计算偏移的时候乘上stride就行,普通的计算也需要乘以这个stride

具体来说,原来的 ( n , d ) (n,d) (n,d) 对应的内存地址偏移是

n S n + d S d n S_n + d S_d nSn+dSd

一般 S n = D , S d = 1 S_n=D,S_d=1 Sn=D,Sd=1

现在在转置之前,把D切分为H份,是

( n , h , d ) → n S n + h S d + d S d (n,h,d)\to nS_n + hS_d + dS_d (n,h,d)nSn+hSd+dSd

此时有

S n = D , S h = D H , S d = 1 S_n = D,S_h = \frac{D}{H},S_d=1 Sn=D,Sh=HD,Sd=1

现在转置只不过改变了坐标的顺序,偏移是完全不变的

( h , n , d ) → h S d + n S n + d S d (h,n,d)\to hS_d +nS_n + dS_d (h,n,d)hSd+nSn+dSd

总结一点就是,只要不是最后一个维度发生转置,就可以直接调换stride,不用做任何其他处理

在torch中,transpose其实就是通过调换stride来实现的,所以调用这个算子之前先要把数据转置到BHSD的形式,这个过程没有大量计算和访存,只修改了几个stride

另外 expand 增加维度也是通过增加stride 实现的,除非在最后一维增加,否则都是 stride=0,在这个维度上的遍历实际上是访问完全相同过的内容

a = torch.arange(0, 12).view(3,4)
b = torch.expand(2, -1, -1) # b.stride() = 0,4,1

比如要访问b[1,2,3],地址是 1 *stride(0) + 2 *stride(1) + 3 * stride(2) = 1 * 0 +2 * 4 + 3 * 1 = 11,这个地址和 [0,2,3]完全一样。
另外如果是 exapnd(1,-1,-1),会发现 stride(0)实际上不是0,而是12,但这个维度只有索引0,所以不管stride是多少,地址偏移都是0

GQA

Qwen3内部使用的是GQA,把注意力头部分成多个组,在每个组内,KV一个头部共享Q的多个头部

比如Qwen3-0.6B,输入维度是1024,Q投影为2048维,KV投影为1024维,Q头部是16,KV头部是8,所以每个头部的维度都是相同的128,每个KV头被2个Q头部共享

实现这一点就在计算KV偏移的时候把tid_h整除共享数就行

    pid_qh = tl.program_id(0)  # BxH ,放到一个编号中
    pid_kh = pid_qh // groups  # groups = 16/8 = 2

KV 的编号向下整除 组数,不用分离 B和H,自然和Q在同一个B中

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐