Fused Attention (Flash Attention v2) — AI编程规则验证版
·
#!/usr/bin/env python
"""
Fused Attention (Flash Attention v2) - 纯流出协变 · 最终修正版
========================================================================
修正: 梯度输出缓冲区 / 因果掩码全局坐标 / 偏移计算 / backward接线
"""
import torch
import triton
import triton.language as tl
from typing import Optional, Dict
# ============================================================================
# L5: 全局参数池
# ============================================================================
# Mathematical constants read inside @triton.jit functions. Triton only lets JIT code read
# module globals that are constexpr; annotating them as `tl.constexpr` satisfies that check
# while their Python-level values stay plain floats.
LOG2_E: tl.constexpr = 1.4426950408889634  # log2(e): rescales natural-log scores to base 2
LN2: tl.constexpr = 0.6931471805599453  # ln(2) == 1 / LOG2_E
# Host-side limits used by the wrapper code.
HW_MAX_SEQLEN = 131072  # largest sequence length accepted by validate_resources
HW_MEM_BUDGET_RATIO = 0.8  # fraction of free memory the estimate may use
HW_WARPS_SMALL_HEAD = 4  # num_warps for head_dim <= 64
HW_WARPS_LARGE_HEAD = 8  # num_warps for head_dim > 64
HW_NUM_STAGES = 2  # software-pipelining stages passed to every launch
def get_dynamic_block_config(max_seq_q, max_seq_k):
    """Choose kernel tile sizes from the longer of the query/key sequence lengths.

    Returns a dict with ``BLOCK_M``, ``BLOCK_N`` and ``PRE_LOAD_V``:
    shorter than 512 -> 64x64 with V preloaded; longer than 8192 -> 128x32 without V
    preloading; anything in between -> 128x64 with V preloaded.
    """
    longest = max(max_seq_q, max_seq_k)
    if longest > 8192:
        block_m, block_n, preload_v = 128, 32, False
    elif longest < 512:
        block_m, block_n, preload_v = 64, 64, True
    else:
        block_m, block_n, preload_v = 128, 64, True
    return {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'PRE_LOAD_V': preload_v}
# ============================================================================
# L2: 流态配置池
# ============================================================================
class AttentionMetadata:
    """Host-side settings for one attention call.

    Holds the softmax scale, causal / dropout flags, optional additive bias and ALiBi
    slopes, and (for packed variable-length batches) the cumulative sequence offsets.
    ``validate_inputs`` and ``validate_resources`` check a concrete call against this
    state before any kernel is launched.
    """

    def __init__(self, sm_scale=1.0, causal=False, dropout_p=0.0,
                 return_encoded_softmax=False, enable_profiling=False):
        self.sm_scale = sm_scale
        self.causal = causal
        self.dropout_p = dropout_p
        self.return_encoded_softmax = return_encoded_softmax
        # Packed variable-length layout; populated by set_varlen_params.
        self.varlen = False
        self.cu_seqlens_q = None
        self.cu_seqlens_k = None
        self.num_contexts = 0
        self.max_seqlens_q = 0
        self.max_seqlens_k = 0
        # Optional score modifiers; populated by set_bias / set_alibi.
        self.bias = None
        self.alibi_slopes = None
        self.enable_profiling = enable_profiling
        self.profile_stats = {}

    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
        """Switch to packed variable-length mode.

        Args:
            cu_seqlens_q: tensor of cumulative query offsets, ``num_contexts + 1`` entries.
            cu_seqlens_k: tensor of cumulative key offsets, same length as ``cu_seqlens_q``.

        Raises:
            ValueError: fewer than two offsets, or the two tensors differ in length.
                Validation happens before any attribute is modified.
        """
        if len(cu_seqlens_q) < 2 or len(cu_seqlens_q) != len(cu_seqlens_k):
            raise ValueError("cu_seqlens mismatch.")
        self.varlen = True
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_k = cu_seqlens_k
        self.num_contexts = len(cu_seqlens_q) - 1
        # One vectorized reduction per tensor (one host sync) instead of two .item() calls
        # per sequence.
        self.max_seqlens_q = int((cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item())
        self.max_seqlens_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item())

    def set_bias(self, bias):
        """Attach an additive attention bias (4-D CUDA tensor whose first dim is 1)."""
        if not bias.is_cuda or bias.dim() != 4 or bias.shape[0] != 1:
            raise ValueError("Bias must be a 4D CUDA tensor with batch size 1.")
        self.bias = bias

    def set_alibi(self, alibi_slopes, batch, nheads):
        """Attach per-(batch, head) ALiBi slopes; shape must be exactly ``(batch, nheads)``."""
        if not alibi_slopes.is_cuda or alibi_slopes.dim() != 2:
            raise ValueError("ALiBi slopes must be a 2D CUDA tensor.")
        if alibi_slopes.shape != (batch, nheads):
            raise ValueError(
                f"ALiBi shape mismatch: expected {(batch, nheads)}, "
                f"got {tuple(alibi_slopes.shape)}.")
        self.alibi_slopes = alibi_slopes

    def validate_inputs(self, q, k, v, o=None):
        """Check Q/K/V (and optional output) against the configured mode.

        Varlen mode expects 3-D tensors and rejects bias/dropout/encoded-softmax; dense
        mode expects 4-D tensors and a positive ``max_seqlens_q``. Axis 1 is treated as
        the head axis for the GQA/MQA divisibility check.

        Raises:
            RuntimeError: a null data pointer.
            ValueError / TypeError: shape, rank, dtype or head-count mismatches.
        """
        if q.data_ptr() == 0 or k.data_ptr() == 0 or v.data_ptr() == 0:
            raise RuntimeError("Null pointers.")
        if q.dim() != k.dim() or q.dim() != v.dim():
            raise ValueError("Q/K/V dim mismatch.")
        if self.varlen:
            if q.dim() != 3 or self.cu_seqlens_q is None:
                raise ValueError("varlen requires 3D tensors.")
            # `self.bias` is a tensor: test for presence, never for truthiness (which
            # raises for multi-element tensors).
            if (self.bias is not None or self.dropout_p != 0.0
                    or self.return_encoded_softmax):
                raise ValueError("varlen: bias/dropout/encoded_softmax unsupported.")
        else:
            if q.dim() != 4 or self.max_seqlens_q <= 0:
                raise ValueError("Non-varlen requires 4D tensors.")
        if k.shape != v.shape or q.shape[-1] != k.shape[-1]:
            raise ValueError("K/V or head dim mismatch.")
        if q.dtype != k.dtype or q.dtype != v.dtype:
            raise TypeError("dtype mismatch.")
        if q.shape[-1] > 256:
            raise ValueError("Head dim > 256.")
        if o is not None and o.shape != q.shape:
            raise ValueError("Output shape mismatch.")
        # Guard the modulo against a zero K head count.
        if k.shape[1] == 0 or (q.shape[1] % k.shape[1]) != 0:
            raise ValueError("GQA/MQA head count mismatch.")

    def validate_resources(self, device_props=None):
        """Reject sequence lengths that exceed the memory budget or the hard limit.

        Args:
            device_props: optional dict; its ``'free_mem_mb'`` entry overrides the queried
                free device memory (useful for testing).
        """
        if device_props is None:
            # mem_get_info reports the driver's free memory on the current device. The
            # former reserved-minus-allocated figure only counted PyTorch's cached but
            # unused blocks, which can be 0 even when the device has plenty of memory.
            free_bytes, _total_bytes = torch.cuda.mem_get_info()
            free_mem_mb = free_bytes / 1024 / 1024
        else:
            free_mem_mb = device_props.get('free_mem_mb', 0)
        max_seq = max(self.max_seqlens_q, self.max_seqlens_k)
        # NOTE(review): rough heuristic (seq * contexts * 3 * 2 bytes); TODO confirm intent.
        est_mem_mb = (max_seq * max(self.num_contexts, 1) * 3 * 2) / 1024 / 1024
        if est_mem_mb > free_mem_mb * HW_MEM_BUDGET_RATIO:
            raise RuntimeError(f"SeqLen {max_seq} exceeds memory budget.")
        if max_seq > HW_MAX_SEQLEN:
            raise ValueError(f"SeqLen {max_seq} exceeds hardware limit.")

    def log_profile(self, **kwargs):
        """Merge ``kwargs`` into ``profile_stats`` when profiling is enabled."""
        if self.enable_profiling:
            self.profile_stats.update(kwargs)
# ============================================================================
# L1: 刚体工具函数
# ============================================================================
@triton.jit
def cdiv_fn(x, y):
    """Ceiling division ``ceil(x / y)`` for non-negative integers, callable from kernels."""
    rounded_up = x + (y - 1)
    return rounded_up // y
@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load a tile via ``tl.load``, passing ``(first, second)`` as ``boundary_check`` and
    ``pad`` as ``padding_option``.

    NOTE(review): Triton documents ``boundary_check`` as a tuple of *dimension indices*,
    honoured only for block pointers created with ``tl.make_block_ptr``. Every caller in
    this file instead passes booleans together with a plain tensor of pointers built by
    arithmetic. ``tl.load`` appears to reject that combination at compile time, so callers
    likely need ``tl.load(ptr, mask=..., other=0.0)`` instead. Confirm against the Triton
    version in use.
    """
    return tl.load(block_ptr, boundary_check=(first, second), padding_option=pad)
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
    """Build the ALiBi bias tile ``-alibi_slope * |offs_m + seqlen_k - seqlen_q - offs_n|``.

    Rows follow ``offs_m`` and columns follow ``offs_n``. The ``seqlen_k - seqlen_q`` shift
    makes the distance zero for the last query and last key. With ``transpose=True`` the
    tile is returned transposed.
    """
    shifted_m = offs_m[:, None] + seqlen_k - seqlen_q
    distance = tl.abs(shifted_m - offs_n[None, :])
    bias_tile = -1.0 * alibi_slope * distance
    if transpose:
        return bias_tile.T
    else:
        return bias_tile
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m: tl.constexpr, n: tl.constexpr, stride):
    """Return an ``(m, n)`` boolean keep-mask for inverted dropout.

    Element ``(i, j)`` draws a uniform value in [0, 1) from Philox using counter
    ``philox_offset + i * stride + j``. It is kept when that value exceeds ``dropout_p``,
    i.e. with probability ``1 - dropout_p``; callers rescale kept values by
    ``1 / (1 - dropout_p)``.

    Fixes: the old ``tl.randint4x(seed_lo, seed_hi, offset_lo, offset_hi, shape, stride)``
    call did not match Triton's ``(seed, offset, n_rounds)`` signature, and the old test
    ``rng < p * 0xFFFFFFFF`` kept elements with probability ``p`` instead of ``1 - p``.
    """
    rows = tl.arange(0, m)
    cols = tl.arange(0, n)
    rng_offsets = philox_offset + rows[:, None] * stride + cols[None, :]
    uniform = tl.rand(philox_seed, rng_offsets.to(tl.uint32))
    return uniform > dropout_p
# ============================================================================
# L1: 前向内循环(未改动,已验证正确)
# ============================================================================
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                    start_m, actual_seqlen_k, actual_seqlen_q,
                    dropout_p, philox_seed, batch_philox_offset,
                    encoded_softmax_block_ptr, block_min, block_max,
                    offs_n_causal, masked_blocks, n_extra_tokens,
                    bias_ptr, alibi_slope, stride_bm,
                    IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
                    BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                    OFFS_M: tl.constexpr, OFFS_N: tl.constexpr,
                    PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
                    ENABLE_DROPOUT: tl.constexpr,
                    RETURN_ENCODED_SOFTMAX: tl.constexpr,
                    PADDED_HEAD: tl.constexpr,
                    BIAS_TYPE: tl.constexpr, USE_ALIBI: tl.constexpr):
    """Online-softmax sweep of one query tile over key tiles ``[block_min, block_max)``.

    Scores live in the base-2 domain because ``q`` arrives pre-multiplied by
    ``sm_scale * LOG2_E``, so every exponential is ``exp2``. Returns
    ``(acc / l_i, l_i, m_i)``: ``m_i`` is the running row max and ``l_i`` the running row
    sum of ``exp2(score - m_i)``. The caller forms the log2-sum-exp as ``m_i + log2(l_i)``.

    Fixes applied here:
    - ``tl.maximum`` for the running max (``tl.max`` is a reduction).
    - P is taken relative to the *new* max, so both ``acc`` and ``l_i`` are rescaled by
      the same factor.
    - The causal mask now depends on the query row.
    - Padding key columns are masked.
    - Rows with no visible key yield 0 instead of NaN.

    NOTE(review): ``K_block_ptr + start_n`` / ``V_block_ptr + start_n`` advance the pointer
    tiles by ``start_n`` *elements*, not ``start_n`` rows. A row advance needs
    ``start_n * stride_kn`` / ``start_n * stride_vk``, which this function does not receive;
    fixing it requires passing those strides from the caller.
    NOTE(review): ``masked_blocks`` is not consulted (the causal mask is applied to every
    tile, which is always correct). ``encoded_softmax_block_ptr`` and
    ``RETURN_ENCODED_SOFTMAX`` are unused.
    """
    offs_m = start_m * BLOCK_M + OFFS_M
    for start_n in range(block_min, block_max, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        cols = start_n + OFFS_N
        k = load_fn(K_block_ptr + start_n, PADDED_HEAD, True, "zero")
        if PRE_LOAD_V:
            v = load_fn(V_block_ptr + start_n, True, PADDED_HEAD, "zero")
        qk = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        qk += tl.dot(q, k.T)
        if USE_ALIBI:
            alibi_block = compute_alibi_block(
                alibi_slope, actual_seqlen_q, actual_seqlen_k, offs_m, cols)
            qk += alibi_block * LOG2_E
        if BIAS_TYPE != 0:
            bias_block = load_fn(bias_ptr + start_n, False,
                                 MASK_STEPS and (n_extra_tokens != 0), "zero")
            qk += bias_block * LOG2_E
        # Padding columns of the last key tile must not receive probability mass.
        qk = tl.where(cols[None, :] < actual_seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            # offs_n_causal is the tile-local key offset shifted by (seqlen_q - seqlen_k),
            # so row m may attend key n iff m >= n + seqlen_q - seqlen_k.
            causal_mask = offs_m[:, None] >= (start_n + offs_n_causal)[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        # A row that has seen no visible key keeps max = -inf; use 0 as its reference so
        # exp2 never evaluates (-inf) - (-inf).
        m_ref = tl.where(m_i_new == float("-inf"), 0.0, m_i_new)
        p = tl.math.exp2(qk - m_ref[:, None])
        alpha = tl.math.exp2(m_i - m_ref)
        # The normalizer uses the un-dropped probabilities.
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None]
        if ENABLE_DROPOUT:
            philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n
            keep = dropout_mask(philox_seed, philox_offset, dropout_p,
                                BLOCK_M, BLOCK_N, actual_seqlen_k)
            p = tl.where(keep, p, 0.0) / (1.0 - dropout_p)
        if not PRE_LOAD_V:
            v = load_fn(V_block_ptr + start_n, True, PADDED_HEAD, "zero")
        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_i_new
    # Rows that never saw a visible key have l_i == 0 and acc == 0; emit 0 rather than NaN.
    l_safe = tl.where(l_i == 0.0, 1.0, l_i)
    acc = acc / l_safe[:, None]
    return acc, l_i, m_i
# ============================================================================
# L1: 前向外层调度(未改动,已验证正确)
# ============================================================================
@triton.jit
def _attn_fwd_kernel(
    Q, K, V, bias, sm_scale, L, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    stride_bz, stride_bh, stride_bm, stride_bn,
    stride_az, stride_ah,
    cu_seqlens_q, cu_seqlens_k,
    dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
    USE_ALIBI: tl.constexpr, BATCH_SIZE: tl.constexpr,
):
    """Forward attention for one BLOCK_M tile of query rows of one (batch, query-head) pair.

    Grid axis 0 is the query tile, axis 1 the query head, and axis 2 the batch entry (or
    packed sequence when VARLEN). Writes the output tile to ``Out``. When ``L`` is given,
    also writes the per-row log2-sum-exp ``m_i + log2(l_i)`` into ``L``, indexed as
    ``(batch, HQ, MAX_SEQLENS_Q)``.

    Fixes applied here:
    - Causal tile range is correct when BLOCK_M != BLOCK_N.
    - ``masked_blocks`` is now the first tile that actually needs masking.
    - The LSE pointer no longer adds ``start_m * BLOCK_M`` twice.
    - ``encoded_softmax`` may be None.
    - Q and Out are accessed with explicit row and head-dim masks.
    - The output is cast to the element type rather than the pointer type.

    NOTE(review): ``bias_ptr`` is offset by head only; no row (``stride_bm``) or column
    (``stride_bn``) offset is applied for the bias tile. TODO confirm the intended layout.
    """
    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    # GQA/MQA: each group of HQ // HK consecutive query heads reads the same K/V head.
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        # Packed layout: this sequence occupies token rows [start, end).
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        actual_seqlen_q = end_q - start_q
        actual_seqlen_k = end_k - start_k
    else:
        start_q, end_q = 0, MAX_SEQLENS_Q
        start_k, end_k = 0, MAX_SEQLENS_K
        actual_seqlen_q, actual_seqlen_k = MAX_SEQLENS_Q, MAX_SEQLENS_K
    if start_m * BLOCK_M >= actual_seqlen_q:
        return

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    row_valid = offs_m < actual_seqlen_q
    # BLOCK_DMODEL is the head size rounded up by the host; extra columns are padding.
    d_valid = offs_d < ACTUAL_BLOCK_DMODEL
    tile_mask = row_valid[:, None] & d_valid[None, :]

    q_rows = off_z * stride_qz + off_h * stride_qh + (start_q + offs_m) * stride_qm
    q = tl.load(Q + q_rows[:, None] + offs_d[None, :] * stride_qk, mask=tile_mask, other=0.0)
    # Fold sm_scale and log2(e) into Q once so the inner loop can use exp2.
    q = (q * sm_scale * LOG2_E).to(q.dtype)
    K_block_ptr = (K + (off_z * stride_kz + off_h_k * stride_kh
                        + (start_k + offs_n) * stride_kn)[:, None]
                   + offs_d[None, :] * stride_kk)
    V_block_ptr = (V + (off_z * stride_vz + off_h_k * stride_vh
                        + (start_k + offs_n) * stride_vk)[:, None]
                   + offs_d[None, :] * stride_vn)

    acc = tl.zeros((BLOCK_M, BLOCK_DMODEL), dtype=tl.float32)
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)

    block_min = 0
    block_max = cdiv_fn(actual_seqlen_k, BLOCK_N) * BLOCK_N
    masked_blocks = 0
    # Only feeds the bias tile's boundary flag in the inner loop.
    n_extra_tokens = 0
    # Causal mask aligned to the bottom-right corner:
    # row m sees key n iff n <= m + seqlen_k - seqlen_q.
    offs_n_causal = offs_n + (actual_seqlen_q - actual_seqlen_k)
    if IS_CAUSAL:
        # No key beyond the last row's diagonal is visible to this tile.
        causal_end = (start_m + 1) * BLOCK_M + actual_seqlen_k - actual_seqlen_q
        block_max = min(block_max, cdiv_fn(causal_end, BLOCK_N) * BLOCK_N)
        # Key tiles before this index are fully visible to every row of the tile.
        first_partial = start_m * BLOCK_M + actual_seqlen_k - actual_seqlen_q
        masked_blocks = max(first_partial, 0) // BLOCK_N

    alibi_slope = (tl.load(alibi_slopes + off_z * stride_az + off_h * stride_ah)
                   if USE_ALIBI else 0.0)
    bias_ptr = bias + off_h * stride_bh if BIAS_TYPE != 0 else None
    # The host may pass encoded_softmax=None; only offset it when it is actually requested.
    enc_ptr = (encoded_softmax + off_z * HQ * MAX_SEQLENS_Q + off_h * MAX_SEQLENS_Q
               + start_m * BLOCK_M
               if RETURN_ENCODED_SOFTMAX else encoded_softmax)
    batch_philox_offset = (philox_offset_base + off_z * HQ * actual_seqlen_q * actual_seqlen_k
                           + off_h * actual_seqlen_q * actual_seqlen_k)
    acc, l_i, m_i = _attn_fwd_inner(
        acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m,
        actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed,
        batch_philox_offset,
        enc_ptr,
        block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens,
        bias_ptr, alibi_slope, stride_bm,
        IS_CAUSAL=IS_CAUSAL, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_N=BLOCK_N, OFFS_M=tl.arange(0, BLOCK_M), OFFS_N=tl.arange(0, BLOCK_N),
        PRE_LOAD_V=PRE_LOAD_V, MASK_STEPS=IS_CAUSAL, ENABLE_DROPOUT=ENABLE_DROPOUT,
        RETURN_ENCODED_SOFTMAX=RETURN_ENCODED_SOFTMAX,
        PADDED_HEAD=BLOCK_DMODEL > ACTUAL_BLOCK_DMODEL,
        BIAS_TYPE=BIAS_TYPE, USE_ALIBI=USE_ALIBI)
    if L is not None:
        # offs_m already includes start_m * BLOCK_M; the row index is sequence-local.
        L_ptr = L + off_z * HQ * MAX_SEQLENS_Q + off_h * MAX_SEQLENS_Q + offs_m
        tl.store(L_ptr, m_i + tl.math.log2(l_i), mask=row_valid)
    out_rows = off_z * stride_oz + off_h * stride_oh + (start_q + offs_m) * stride_om
    out_ptr = Out + out_rows[:, None] + offs_d[None, :] * stride_on
    # Cast to the pointee element type (Out.dtype is the pointer type) and skip padded
    # head-dim columns, which would otherwise overwrite the next row.
    tl.store(out_ptr, acc.to(Out.type.element_ty), mask=tile_mask)
# ============================================================================
# L1: 反向共享函数(修正:统一使用全局坐标语义)
# ============================================================================
@triton.jit
def _bwd_load_q_pre(Q, sm_scale, off_z, off_h, start_q, global_row_idx,
                    stride_qz, stride_qh, stride_qm,
                    BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr):
    """Load a ``(rows, BLOCK_DMODEL)`` tile of Q pre-multiplied by ``sm_scale * LOG2_E``.

    ``global_row_idx`` already includes the varlen ``start_q`` offset; ``start_q`` itself is
    unused and kept only for call compatibility. Head-dim columns at or past
    ``ACTUAL_BLOCK_DMODEL`` are zero-filled so they contribute nothing to ``Q K^T``. This
    replaces the former ``load_fn`` call, whose padding flag was inverted and which used
    ``boundary_check`` on a plain pointer tensor. The head-dim stride is taken to be 1,
    which the host wrapper enforces by rejecting ``stride(-1) != 1``.

    NOTE(review): rows are not bounds-checked here. Callers must mask results for rows past
    the sequence end, and the last tile of the last sequence may read past the end of Q.
    """
    offs_d = tl.arange(0, BLOCK_DMODEL)
    q_offset = off_z * stride_qz + off_h * stride_qh + global_row_idx * stride_qm
    Q_block_ptr = Q + q_offset[:, None] + offs_d[None, :]
    q = tl.load(Q_block_ptr, mask=(offs_d < ACTUAL_BLOCK_DMODEL)[None, :], other=0.0)
    q = (q * sm_scale * LOG2_E).to(q.dtype)
    return q
@triton.jit
def _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, global_row_idx, actual_seqlen_q):
    """Load the forward pass's per-row log2-sum-exp for one tile of query rows.

    ``L`` is indexed as ``(batch, HQ, MAX_SEQLENS_Q)``. ``global_row_idx`` is the row index
    within the current sequence, i.e. without the varlen ``start_q`` offset.

    Rows at or beyond ``actual_seqlen_q`` read as +inf, so ``exp2(score - lse)`` evaluates
    to 0 for them. Previously, with no ``other`` value, those rows read undefined memory.
    """
    L_ptr = L + off_z * HQ * MAX_SEQLENS_Q + off_h * MAX_SEQLENS_Q + global_row_idx
    return tl.load(L_ptr, mask=global_row_idx < actual_seqlen_q, other=float("inf"))
@triton.jit
def _bwd_compute_p(q, k, l_i, global_m, global_n, IS_CAUSAL: tl.constexpr):
    """Recompute attention probabilities ``exp2(q k^T - l_i)`` for one tile.

    ``l_i`` is the per-row log2-sum-exp from the forward pass. When ``IS_CAUSAL`` is set,
    entries with ``global_m < global_n`` are forced to probability 0 (score -inf);
    ``global_m`` / ``global_n`` are the row and column coordinates compared by the mask.
    """
    scores = tl.dot(q, tl.trans(k))
    if IS_CAUSAL:
        visible = global_m[:, None] >= global_n[None, :]
        scores = tl.where(visible, scores, float("-inf"))
    return tl.math.exp2(scores - l_i[:, None])
# ============================================================================
# L1: 反向DK内核
# ============================================================================
@triton.jit
def _bwd_kernel_dk(
    Q, K, V, sm_scale, L, D, Out, O_fwd,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_dkz, stride_dkh, stride_dkn, stride_dkk,
    stride_fz, stride_fh, stride_fm, stride_fn,
    cu_seqlens_q, cu_seqlens_k,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Accumulate dK for one BLOCK_N tile of keys, for one query head.

    Arguments:
    - ``D`` is the upstream gradient dO and is addressed with ``O_fwd``'s strides
      (``stride_f*``), so it must share O_fwd's layout.
    - ``Out`` is the dK buffer. It is updated with ``atomic_add`` (all query heads of a
      GQA group hit the same K head), so it must start at zero.

    Math: with S = sm_scale * Q K^T and P = softmax(S):
        dS = P * (dO V^T - rowsum(dO * O))
        dK = sm_scale * dS^T Q
    Q is loaded pre-multiplied by sm_scale * LOG2_E; the final ``* LN2`` cancels LOG2_E.

    Fixes applied here:
    - Causal start tile is correct when BLOCK_M != BLOCK_N.
    - The causal diagonal follows the forward's bottom-right alignment.
    - dO no longer mixes dK strides with O strides.
    - Loads and stores are explicitly masked, and padded rows/columns contribute nothing.
    - Di is reduced in fp32.
    - The output is cast to the element type.
    """
    start_n = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        actual_seqlen_k, actual_seqlen_q = end_k - start_k, end_q - start_q
    else:
        start_k, end_k = 0, MAX_SEQLENS_K
        start_q, end_q = 0, MAX_SEQLENS_Q
        actual_seqlen_k, actual_seqlen_q = MAX_SEQLENS_K, MAX_SEQLENS_Q
    if start_n * BLOCK_N >= actual_seqlen_k:
        return

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)  # key index within the sequence
    offs_m_local = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    n_valid = offs_n < actual_seqlen_k
    d_valid = offs_d < ACTUAL_BLOCK_DMODEL
    kv_mask = n_valid[:, None] & d_valid[None, :]
    # Bottom-right causal alignment, consistent with the forward kernel's
    # offs_n_causal = offs_n + (seqlen_q - seqlen_k): row m sees key n iff m + shift >= n.
    causal_shift = actual_seqlen_k - actual_seqlen_q

    k_offset = off_z * stride_kz + off_h_k * stride_kh + (start_k + offs_n) * stride_kn
    k = tl.load(K + k_offset[:, None] + offs_d[None, :] * stride_kk, mask=kv_mask, other=0.0)
    v_offset = off_z * stride_vz + off_h_k * stride_vh + (start_k + offs_n) * stride_vk
    v = tl.load(V + v_offset[:, None] + offs_d[None, :] * stride_vn, mask=kv_mask, other=0.0)

    dk = tl.zeros((BLOCK_N, BLOCK_DMODEL), dtype=tl.float32)
    block_min = 0
    block_max = cdiv_fn(actual_seqlen_q, BLOCK_M) * BLOCK_M
    if IS_CAUSAL:
        # First query row that can see this tile's first key, rounded down to a tile start.
        first_row = max(start_n * BLOCK_N - causal_shift, 0)
        block_min = (first_row // BLOCK_M) * BLOCK_M
    for start_m in range(block_min, block_max, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        global_m = start_m + offs_m_local  # query row index within the sequence
        m_valid = global_m < actual_seqlen_q
        q = _bwd_load_q_pre(Q, sm_scale, off_z, off_h, start_q,
                            start_q + global_m, stride_qz, stride_qh, stride_qm,
                            BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL)
        l_i = _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, global_m, actual_seqlen_q)
        p = _bwd_compute_p(q, k, l_i, global_m + causal_shift, offs_n, IS_CAUSAL)
        # Padded query rows / key columns must not contribute.
        p = tl.where(m_valid[:, None] & n_valid[None, :], p, 0.0)
        row_mask = m_valid[:, None] & d_valid[None, :]
        o_rows = off_z * stride_fz + off_h * stride_fh + (start_q + global_m) * stride_fm
        o_cols = offs_d[None, :] * stride_fn
        o_row = tl.load(O_fwd + o_rows[:, None] + o_cols, mask=row_mask, other=0.0)
        do = tl.load(D + o_rows[:, None] + o_cols, mask=row_mask, other=0.0)
        Di = tl.sum(do.to(tl.float32) * o_row.to(tl.float32), 1)
        ds = p * (tl.dot(do, v.T) - Di[:, None])
        dk += tl.dot(ds.T.to(q.dtype), q)
    dk *= LN2  # cancels the LOG2_E folded into Q
    dk_ptr = (Out + (off_z * stride_dkz + off_h_k * stride_dkh
                     + (start_k + offs_n) * stride_dkn)[:, None]
              + offs_d[None, :] * stride_dkk)
    tl.atomic_add(dk_ptr, dk.to(Out.type.element_ty), mask=kv_mask)
# ============================================================================
# L1: 反向DV内核
# ============================================================================
@triton.jit
def _bwd_kernel_dv(
    Q, K, V, sm_scale, L, D, Out, O_fwd,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_dvz, stride_dvh, stride_dvn, stride_dvd,
    stride_fz, stride_fh, stride_fm, stride_fn,
    cu_seqlens_q, cu_seqlens_k,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Accumulate dV = P^T dO for one BLOCK_N tile of keys, for one query head.

    Arguments:
    - ``D`` is the upstream gradient dO and is addressed with ``O_fwd``'s strides
      (``stride_f*``), so it must share O_fwd's layout.
    - ``Out`` is the dV buffer. It is updated with ``atomic_add`` and must start at zero.

    Fixes applied here:
    - Causal start tile is correct when BLOCK_M != BLOCK_N.
    - The causal diagonal follows the forward's bottom-right alignment.
    - dO no longer mixes dV strides with O strides.
    - Loads and stores are explicitly masked.
    - The output is cast to the element type.
    """
    start_n = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        actual_seqlen_k, actual_seqlen_q = end_k - start_k, end_q - start_q
    else:
        start_k, end_k = 0, MAX_SEQLENS_K
        start_q, end_q = 0, MAX_SEQLENS_Q
        actual_seqlen_k, actual_seqlen_q = MAX_SEQLENS_K, MAX_SEQLENS_Q
    if start_n * BLOCK_N >= actual_seqlen_k:
        return

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_local = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    n_valid = offs_n < actual_seqlen_k
    d_valid = offs_d < ACTUAL_BLOCK_DMODEL
    kv_mask = n_valid[:, None] & d_valid[None, :]
    # Bottom-right causal alignment: row m sees key n iff m + causal_shift >= n.
    causal_shift = actual_seqlen_k - actual_seqlen_q

    v_offset = off_z * stride_vz + off_h_k * stride_vh + (start_k + offs_n) * stride_vk
    v = tl.load(V + v_offset[:, None] + offs_d[None, :] * stride_vn, mask=kv_mask, other=0.0)
    k_offset = off_z * stride_kz + off_h_k * stride_kh + (start_k + offs_n) * stride_kn
    k = tl.load(K + k_offset[:, None] + offs_d[None, :] * stride_kk, mask=kv_mask, other=0.0)

    dv = tl.zeros((BLOCK_N, BLOCK_DMODEL), dtype=tl.float32)
    block_min = 0
    block_max = cdiv_fn(actual_seqlen_q, BLOCK_M) * BLOCK_M
    if IS_CAUSAL:
        first_row = max(start_n * BLOCK_N - causal_shift, 0)
        block_min = (first_row // BLOCK_M) * BLOCK_M
    for start_m in range(block_min, block_max, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        global_m = start_m + offs_m_local
        m_valid = global_m < actual_seqlen_q
        q = _bwd_load_q_pre(Q, sm_scale, off_z, off_h, start_q,
                            start_q + global_m, stride_qz, stride_qh, stride_qm,
                            BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL)
        l_i = _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, global_m, actual_seqlen_q)
        p = _bwd_compute_p(q, k, l_i, global_m + causal_shift, offs_n, IS_CAUSAL)
        p = tl.where(m_valid[:, None] & n_valid[None, :], p, 0.0)
        row_mask = m_valid[:, None] & d_valid[None, :]
        do_rows = off_z * stride_fz + off_h * stride_fh + (start_q + global_m) * stride_fm
        do = tl.load(D + do_rows[:, None] + offs_d[None, :] * stride_fn, mask=row_mask, other=0.0)
        dv += tl.dot(p.T.to(do.dtype), do)
    dv_ptr = (Out + (off_z * stride_dvz + off_h_k * stride_dvh
                     + (start_k + offs_n) * stride_dvn)[:, None]
              + offs_d[None, :] * stride_dvd)
    tl.atomic_add(dv_ptr, dv.to(Out.type.element_ty), mask=kv_mask)
# ============================================================================
# L1: 反向DQ内核
# ============================================================================
@triton.jit
def _bwd_kernel_dq(
    Q, K, V, sm_scale, L, D, Out, O_fwd,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_dqz, stride_dqh, stride_dqm, stride_dqd,
    stride_fz, stride_fh, stride_fm, stride_fn,
    cu_seqlens_q, cu_seqlens_k,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Compute dQ = sm_scale * dS K for one BLOCK_M tile of query rows of one query head.

    ``D`` is the upstream gradient dO and is addressed with ``O_fwd``'s strides
    (``stride_f*``), so it must share O_fwd's layout. ``Out`` is the dQ buffer; each valid
    row is written exactly once with ``tl.store``.

    Fixes applied here:
    - Causal end tile is correct when BLOCK_M != BLOCK_N.
    - The causal diagonal follows the forward's bottom-right alignment.
    - All loads and stores are explicitly masked.
    - Padded rows/columns contribute nothing.
    - Di is reduced in fp32.
    - The output is cast to the element type.
    """
    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        actual_seqlen_q, actual_seqlen_k = end_q - start_q, end_k - start_k
    else:
        start_q, end_q = 0, MAX_SEQLENS_Q
        start_k, end_k = 0, MAX_SEQLENS_K
        actual_seqlen_q, actual_seqlen_k = MAX_SEQLENS_Q, MAX_SEQLENS_K
    if start_m * BLOCK_M >= actual_seqlen_q:
        return

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # query row index within the sequence
    offs_n_local = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    m_valid = offs_m < actual_seqlen_q
    d_valid = offs_d < ACTUAL_BLOCK_DMODEL
    row_mask = m_valid[:, None] & d_valid[None, :]
    # Bottom-right causal alignment: row m sees key n iff m + causal_shift >= n.
    causal_shift = actual_seqlen_k - actual_seqlen_q

    q = _bwd_load_q_pre(Q, sm_scale, off_z, off_h, start_q,
                        start_q + offs_m, stride_qz, stride_qh, stride_qm,
                        BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL)
    l_i = _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, offs_m, actual_seqlen_q)
    o_rows = off_z * stride_fz + off_h * stride_fh + (start_q + offs_m) * stride_fm
    o_cols = offs_d[None, :] * stride_fn
    o_row = tl.load(O_fwd + o_rows[:, None] + o_cols, mask=row_mask, other=0.0)
    do = tl.load(D + o_rows[:, None] + o_cols, mask=row_mask, other=0.0)
    Di = tl.sum(do.to(tl.float32) * o_row.to(tl.float32), 1)

    dq = tl.zeros((BLOCK_M, BLOCK_DMODEL), dtype=tl.float32)
    block_min = 0
    block_max = cdiv_fn(actual_seqlen_k, BLOCK_N) * BLOCK_N
    if IS_CAUSAL:
        # No key beyond the last row's diagonal is visible to this tile.
        causal_end = (start_m + 1) * BLOCK_M + causal_shift
        block_max = min(block_max, cdiv_fn(causal_end, BLOCK_N) * BLOCK_N)
    for start_n in range(block_min, block_max, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        global_n = start_n + offs_n_local  # key index within the sequence
        n_valid = global_n < actual_seqlen_k
        kv_mask = n_valid[:, None] & d_valid[None, :]
        k_offset = off_z * stride_kz + off_h_k * stride_kh + (start_k + global_n) * stride_kn
        k = tl.load(K + k_offset[:, None] + offs_d[None, :] * stride_kk, mask=kv_mask, other=0.0)
        v_offset = off_z * stride_vz + off_h_k * stride_vh + (start_k + global_n) * stride_vk
        v = tl.load(V + v_offset[:, None] + offs_d[None, :] * stride_vn, mask=kv_mask, other=0.0)
        p = _bwd_compute_p(q, k, l_i, offs_m + causal_shift, global_n, IS_CAUSAL)
        p = tl.where(m_valid[:, None] & n_valid[None, :], p, 0.0)
        ds = p * (tl.dot(do, v.T) - Di[:, None])
        dq += tl.dot(ds.to(k.dtype), k)
    dq *= sm_scale  # K is not pre-scaled, so apply sm_scale here
    dq_ptr = (Out + (off_z * stride_dqz + off_h * stride_dqh
                     + (start_q + offs_m) * stride_dqm)[:, None]
              + offs_d[None, :] * stride_dqd)
    tl.store(dq_ptr, dq.to(Out.type.element_ty), mask=row_mask)
# ============================================================================
# L3: PyTorch Autograd 封装
# ============================================================================
_SCALAR_KEYS = ['sm_scale', 'causal', 'dropout_p', 'varlen', 'num_contexts',
'max_seqlens_q', 'max_seqlens_k', 'return_encoded_softmax']
def _save_metadata(ctx, metadata):
for key in _SCALAR_KEYS:
setattr(ctx, key, getattr(metadata, key, None))
ctx.bias = metadata.bias
ctx.alibi_slopes = metadata.alibi_slopes
ctx.cu_seqlens_q = metadata.cu_seqlens_q
ctx.cu_seqlens_k = metadata.cu_seqlens_k
def _load_metadata(ctx):
    """Rebuild an ``AttentionMetadata`` from the attributes ``_save_metadata`` put on ``ctx``.

    Profiling state is not restored; the rebuilt object starts with profiling disabled.
    """
    restored = AttentionMetadata(
        sm_scale=ctx.sm_scale,
        causal=ctx.causal,
        dropout_p=ctx.dropout_p,
        return_encoded_softmax=ctx.return_encoded_softmax,
    )
    copied_fields = ('varlen', 'num_contexts', 'max_seqlens_q', 'max_seqlens_k',
                     'bias', 'alibi_slopes', 'cu_seqlens_q', 'cu_seqlens_k')
    for name in copied_fields:
        setattr(restored, name, getattr(ctx, name))
    return restored
def _bhsd_strides(t, varlen):
    """Return ``t``'s strides ordered (batch, head, sequence, head_dim) as the kernels expect.

    Dense inputs are 4-D with heads on axis 1 and sequence on axis 2. Packed varlen inputs
    are 3-D ``(tokens, heads, head_dim)`` with no batch axis. Their batch stride is
    therefore 0, and the kernels locate each sequence through ``cu_seqlens``.
    """
    if varlen:
        return 0, t.stride(1), t.stride(0), t.stride(2)
    return t.stride(0), t.stride(1), t.stride(2), t.stride(3)


class TritonFlashAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton forward kernel and the dK / dV / dQ kernels."""

    @staticmethod
    def forward(ctx, q, k, v, o, metadata):
        """Run the fused forward pass and save what backward needs.

        Args:
            q, k, v: 4-D tensors with heads on axis 1, or 3-D packed tokens when
                ``metadata.varlen`` is set. The last axis must have unit stride.
            o: optional preallocated output shaped like ``q``; allocated when None.
            metadata: AttentionMetadata describing scale, masking and layout.

        Returns:
            The attention output ``o``.
        """
        if q.stride(-1) != 1 or k.stride(-1) != 1 or v.stride(-1) != 1:
            raise RuntimeError("Q/K/V must be contiguous.")
        if metadata.bias is not None and metadata.bias.numel() >= 2**31:
            raise ValueError("Bias too large.")
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)
        metadata.validate_inputs(q, k, v, o)
        metadata.validate_resources()
        cfg = get_dynamic_block_config(metadata.max_seqlens_q, metadata.max_seqlens_k)
        varlen = metadata.varlen
        batch = metadata.num_contexts if varlen else q.shape[0]
        nheads_q, nheads_k, head_dim = q.shape[1], k.shape[1], q.shape[-1]
        grid = (triton.cdiv(metadata.max_seqlens_q, cfg['BLOCK_M']), nheads_q, batch)
        if 0 in grid:
            # Nothing to compute: return zeros and make backward a no-op.
            o.zero_()
            ctx.save_for_backward(None, None, None, None, None)
            _save_metadata(ctx, metadata)
            return o
        # Per-row log2-sum-exp written by the forward kernel and read by backward.
        L = torch.empty((batch, nheads_q, metadata.max_seqlens_q),
                        device=q.device, dtype=torch.float32)
        bias = metadata.bias
        alibi = metadata.alibi_slopes
        bias_strides = tuple(bias.stride()) if bias is not None else (0, 0, 0, 0)
        alibi_strides = tuple(alibi.stride()) if alibi is not None else (0, 0)
        philox_seed = torch.randint(0, 2**32, (1,), device=q.device).item()
        _attn_fwd_kernel[grid](
            q, k, v, bias, metadata.sm_scale, L, o,
            # Varlen tensors are 3-D; _bhsd_strides maps them onto the 4-stride ABI.
            *_bhsd_strides(q, varlen),
            *_bhsd_strides(k, varlen),
            *_bhsd_strides(v, varlen),
            *_bhsd_strides(o, varlen),
            *bias_strides,
            *alibi_strides,
            metadata.cu_seqlens_q, metadata.cu_seqlens_k,
            metadata.dropout_p,
            philox_seed,
            0, None, alibi,
            HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_dim,
            MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k,
            VARLEN=varlen, IS_CAUSAL=metadata.causal,
            BLOCK_M=cfg['BLOCK_M'], BLOCK_DMODEL=triton.next_power_of_2(head_dim),
            BLOCK_N=cfg['BLOCK_N'], PRE_LOAD_V=cfg['PRE_LOAD_V'],
            BIAS_TYPE=1 if bias is not None else 0,
            ENABLE_DROPOUT=metadata.dropout_p > 0.0,
            RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
            USE_ALIBI=alibi is not None,
            BATCH_SIZE=batch,
            num_warps=HW_WARPS_SMALL_HEAD if head_dim <= 64 else HW_WARPS_LARGE_HEAD,
            num_stages=HW_NUM_STAGES)
        metadata.log_profile(kernel="fwd", config=cfg, shape=q.shape, grid=grid)
        ctx.save_for_backward(q, k, v, o, L)
        _save_metadata(ctx, metadata)
        return o

    @staticmethod
    def backward(ctx, do):
        """Compute gradients for q, k and v.

        Raises:
            NotImplementedError: the forward used dropout, a bias or ALiBi. The backward
                kernels model none of these, so their gradients would be silently wrong.
        """
        q, k, v, o, L = ctx.saved_tensors
        if q is None:
            return None, None, None, None, None
        metadata = _load_metadata(ctx)
        if (metadata.dropout_p > 0.0 or metadata.bias is not None
                or metadata.alibi_slopes is not None):
            raise NotImplementedError(
                "Backward with dropout, bias or ALiBi is not supported.")
        varlen = metadata.varlen
        batch = metadata.num_contexts if varlen else q.shape[0]
        nheads_q, nheads_k, head_dim = q.shape[1], k.shape[1], q.shape[-1]
        cfg = get_dynamic_block_config(metadata.max_seqlens_q, metadata.max_seqlens_k)
        # The kernels address dO with O's strides, so give both the same dense layout
        # (the incoming gradient is often a transposed, non-contiguous view).
        do = do.contiguous()
        o = o.contiguous()
        # dK / dV are accumulated with atomic adds from every query tile and every query
        # head of a GQA group, so they must start at zero. fp32 buffers also avoid
        # half-precision atomics.
        dk = torch.zeros(k.shape, device=k.device, dtype=torch.float32)
        dv = torch.zeros(v.shape, device=v.device, dtype=torch.float32)
        dq = torch.empty_like(q)
        nwarps = HW_WARPS_SMALL_HEAD if head_dim <= 64 else HW_WARPS_LARGE_HEAD
        launch_kwargs = dict(
            HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_dim,
            MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k,
            VARLEN=varlen, IS_CAUSAL=metadata.causal,
            BLOCK_M=cfg['BLOCK_M'], BLOCK_DMODEL=triton.next_power_of_2(head_dim),
            BLOCK_N=cfg['BLOCK_N'], num_warps=nwarps, num_stages=HW_NUM_STAGES)
        common_strides = (*_bhsd_strides(q, varlen),
                          *_bhsd_strides(k, varlen),
                          *_bhsd_strides(v, varlen))
        # dK and dV: one program per (key tile, query head, batch entry).
        kv_grid = (triton.cdiv(metadata.max_seqlens_k, cfg['BLOCK_N']), nheads_q, batch)
        for kernel, grad in ((_bwd_kernel_dk, dk), (_bwd_kernel_dv, dv)):
            kernel[kv_grid](
                q, k, v, metadata.sm_scale, L, do, grad, o,
                *common_strides,
                *_bhsd_strides(grad, varlen),
                *_bhsd_strides(o, varlen),
                metadata.cu_seqlens_q, metadata.cu_seqlens_k,
                **launch_kwargs)
        # dQ: one program per (query tile, query head, batch entry).
        q_grid = (triton.cdiv(metadata.max_seqlens_q, cfg['BLOCK_M']), nheads_q, batch)
        _bwd_kernel_dq[q_grid](
            q, k, v, metadata.sm_scale, L, do, dq, o,
            *common_strides,
            *_bhsd_strides(dq, varlen),
            *_bhsd_strides(o, varlen),
            metadata.cu_seqlens_q, metadata.cu_seqlens_k,
            **launch_kwargs)
        metadata.log_profile(kernel="bwd", config=cfg, shape=q.shape)
        return dq, dk.to(k.dtype), dv.to(v.dtype), None, None
# ============================================================================
# L4: 对外接口
# ============================================================================
def _prepare_inputs(q, k, v, is_varlen):
if is_varlen:
return tuple(t.contiguous() for t in (q, k, v))
return tuple(t.transpose(1, 2).contiguous() for t in (q, k, v))
def _prepare_output(o, is_varlen):
return o if is_varlen else o.transpose(1, 2)
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                    bias=None, alibi_slopes=None, **kwargs):
    """Dense (fixed-length) fused attention.

    Args:
        q, k, v: 4-D tensors. Axes 1 and 2 are swapped before the kernel runs so that heads
            sit on axis 1 (presumably a (batch, seq, heads, dim) input layout; confirm).
        dropout_p: dropout probability.
        softmax_scale: score scale. ``None`` selects ``head_dim ** -0.5``; an explicit
            value, including 0.0, is used as given rather than treated as missing.
        causal: apply a causal mask.
        bias: optional additive bias, validated by ``AttentionMetadata.set_bias``.
        alibi_slopes: optional ALiBi slopes, shape ``(batch, heads)``.
        **kwargs: forwarded to ``AttentionMetadata`` (e.g. ``return_encoded_softmax``,
            ``enable_profiling``).

    Returns:
        The attention output with axes 1 and 2 swapped back.
    """
    q, k, v = _prepare_inputs(q, k, v, is_varlen=False)
    sm_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
    metadata = AttentionMetadata(sm_scale=sm_scale, causal=causal,
                                 dropout_p=dropout_p, **kwargs)
    metadata.max_seqlens_q, metadata.max_seqlens_k = q.shape[2], k.shape[2]
    if bias is not None:
        metadata.set_bias(bias)
    if alibi_slopes is not None:
        metadata.set_alibi(alibi_slopes, q.shape[0], q.shape[1])
    return _prepare_output(TritonFlashAttention.apply(q, k, v, None, metadata), is_varlen=False)
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                           max_seqlen_q, max_seqlen_k, dropout_p=0.0,
                           softmax_scale=None, causal=False, **kwargs):
    """Fused attention over packed variable-length sequences.

    Args:
        q, k, v: 3-D packed tensors ``(tokens, heads, head_dim)``.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets (``num_sequences + 1``).
        max_seqlen_q, max_seqlen_k: caller's upper bounds on the sequence lengths. They are
            never allowed below the true maxima measured from ``cu_seqlens``, because the
            launch grid and the LSE buffer are sized from them; a smaller value would
            silently drop query rows.
        dropout_p, causal, **kwargs: as in ``flash_attn_func``.
        softmax_scale: ``None`` selects ``head_dim ** -0.5``; an explicit 0.0 is honoured.

    Returns:
        The packed attention output, shaped like ``q``.
    """
    q, k, v = _prepare_inputs(q, k, v, is_varlen=True)
    sm_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
    metadata = AttentionMetadata(sm_scale=sm_scale, causal=causal,
                                 dropout_p=dropout_p, **kwargs)
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
    metadata.max_seqlens_q = max(max_seqlen_q, metadata.max_seqlens_q)
    metadata.max_seqlens_k = max(max_seqlen_k, metadata.max_seqlens_k)
    return _prepare_output(TritonFlashAttention.apply(q, k, v, None, metadata), is_varlen=True)
这是用于验证AI编程原理的验证版多头注意力机制实现，采用协流变规则编制，与现有工程有所差别，可用于学习或教学。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)