#!/usr/bin/env python
"""
Fused Attention (Flash Attention v2) - 纯流出协变 · 最终修正版
========================================================================
修正: 梯度输出缓冲区 / 因果掩码全局坐标 / 偏移计算 / backward接线
"""

import torch
import triton
import triton.language as tl
from typing import Optional, Dict

# ============================================================================
# L5: 全局参数池
# ============================================================================
# log2(e). The kernels evaluate softmax with exp2, so natural-log-domain
# scores are multiplied by this factor first.
# Annotated as tl.constexpr because recent Triton releases only let @triton.jit
# code read module globals that are declared constexpr; the runtime value is
# still a plain Python float.
LOG2_E: tl.constexpr = 1.4426950408889634
# ln(2) == 1 / log2(e). Used by the dK kernel to undo the LOG2_E factor that
# was folded into the pre-scaled Q.
LN2: tl.constexpr = 0.6931471805599453
# Largest sequence length AttentionMetadata.validate_resources() accepts.
HW_MAX_SEQLEN = 131072
# Fraction of the free device memory the rough footprint estimate may use.
HW_MEM_BUDGET_RATIO = 0.8
# num_warps used for launches with head_dim <= 64 / head_dim > 64.
HW_WARPS_SMALL_HEAD = 4
HW_WARPS_LARGE_HEAD = 8
# num_stages passed to every kernel launch.
HW_NUM_STAGES = 2

def get_dynamic_block_config(max_seq_q, max_seq_k):
    """Pick kernel tile sizes from the longer of the two sequence lengths.

    Args:
        max_seq_q: Maximum query sequence length.
        max_seq_k: Maximum key/value sequence length.

    Returns:
        A dict with the keys ``BLOCK_M``, ``BLOCK_N`` and ``PRE_LOAD_V``.
    """
    longest = max_seq_q if max_seq_q >= max_seq_k else max_seq_k
    if longest < 512:
        # Short sequences: small square tiles.
        block_m, block_n, preload_v = 64, 64, True
    elif longest > 8192:
        # Very long sequences: narrow key tiles, V loaded late.
        block_m, block_n, preload_v = 128, 32, False
    else:
        block_m, block_n, preload_v = 128, 64, True
    return {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'PRE_LOAD_V': preload_v}

# ============================================================================
# L2: 流态配置池
# ============================================================================
class AttentionMetadata:
    """Per-call configuration shared by the attention kernels and wrappers.

    Stores the softmax scale, the causal/dropout switches, the optional bias
    and ALiBi tensors and, for variable-length batches, the cumulative
    sequence offsets together with the maxima derived from them.
    """

    def __init__(self, sm_scale=1.0, causal=False, dropout_p=0.0,
                 return_encoded_softmax=False, enable_profiling=False):
        self.sm_scale = sm_scale
        self.causal = causal
        self.dropout_p = dropout_p
        self.return_encoded_softmax = return_encoded_softmax
        # Variable-length state; filled in by set_varlen_params().
        self.varlen = False
        self.cu_seqlens_q = None
        self.cu_seqlens_k = None
        self.num_contexts = 0
        self.max_seqlens_q = 0
        self.max_seqlens_k = 0
        # Optional additive terms; filled in by set_bias() / set_alibi().
        self.bias = None
        self.alibi_slopes = None
        # Profiling records are only collected when enable_profiling is set.
        self.enable_profiling = enable_profiling
        self.profile_stats = {}

    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
        """Switch to variable-length mode and derive the per-batch maxima.

        Args:
            cu_seqlens_q: 1-D tensor of cumulative query offsets
                (``num_sequences + 1`` entries).
            cu_seqlens_k: 1-D tensor of cumulative key offsets of the same
                length.

        Raises:
            ValueError: If either tensor has fewer than two entries or the
                two lengths differ. The object is left untouched in that case.
        """
        # Validate before mutating so a rejected call leaves no partial state.
        if len(cu_seqlens_q) < 2 or len(cu_seqlens_q) != len(cu_seqlens_k):
            raise ValueError("cu_seqlens mismatch.")
        self.varlen = True
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_k = cu_seqlens_k
        self.num_contexts = len(cu_seqlens_q) - 1
        # One vectorised difference + max per tensor instead of two .item()
        # calls per sequence.
        self.max_seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
        self.max_seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()

    def set_bias(self, bias):
        """Attach an additive attention bias.

        Raises:
            ValueError: If ``bias`` is not a 4-D CUDA tensor whose first
                dimension is 1.
        """
        if not bias.is_cuda or bias.dim() != 4 or bias.shape[0] != 1:
            raise ValueError("Bias must be a 4D CUDA tensor with batch size 1.")
        self.bias = bias

    def set_alibi(self, alibi_slopes, batch, nheads):
        """Attach ALiBi slopes of shape ``(batch, nheads)``.

        Raises:
            ValueError: If ``alibi_slopes`` is not a 2-D CUDA tensor or its
                shape differs from ``(batch, nheads)``.
        """
        if not alibi_slopes.is_cuda or alibi_slopes.dim() != 2:
            raise ValueError("ALiBi slopes must be a 2D CUDA tensor.")
        if alibi_slopes.shape != (batch, nheads):
            raise ValueError("ALiBi shape mismatch.")
        self.alibi_slopes = alibi_slopes

    def validate_inputs(self, q, k, v, o=None):
        """Check Q/K/V (and optionally O) against the current configuration.

        Raises:
            RuntimeError: If any of Q/K/V has a null data pointer.
            ValueError: On rank, shape, head-count or feature mismatches.
            TypeError: If Q/K/V dtypes differ (non-varlen mode only).
        """
        if q.data_ptr() == 0 or k.data_ptr() == 0 or v.data_ptr() == 0:
            raise RuntimeError("Null pointers.")
        if q.dim() != k.dim() or q.dim() != v.dim():
            raise ValueError("Q/K/V dim mismatch.")
        if self.varlen:
            if q.dim() != 3 or self.cu_seqlens_q is None:
                raise ValueError("varlen requires 3D tensors.")
            # Compare against None explicitly: the truth value of a
            # multi-element tensor is ambiguous and raises in PyTorch.
            if (self.bias is not None or self.dropout_p != 0.0
                    or self.return_encoded_softmax):
                raise ValueError("varlen: bias/dropout/encoded_softmax unsupported.")
        else:
            if q.dim() != 4 or self.max_seqlens_q <= 0:
                raise ValueError("Non-varlen requires 4D tensors.")
            if k.shape != v.shape or q.shape[-1] != k.shape[-1]:
                raise ValueError("K/V or head dim mismatch.")
            if q.dtype != k.dtype or q.dtype != v.dtype:
                raise TypeError("dtype mismatch.")
            if q.shape[-1] > 256:
                raise ValueError("Head dim > 256.")
            if o is not None and o.shape != q.shape:
                raise ValueError("Output shape mismatch.")
            if (q.shape[1] % k.shape[1]) != 0:
                raise ValueError("GQA/MQA head count mismatch.")

    def validate_resources(self, device_props=None):
        """Reject configurations that exceed the memory budget or length cap.

        Args:
            device_props: Optional dict; when given, its ``'free_mem_mb'``
                entry (default 0) is used instead of querying the device.

        Raises:
            RuntimeError: If the rough footprint estimate exceeds
                ``HW_MEM_BUDGET_RATIO`` of the free memory.
            ValueError: If the longest sequence exceeds ``HW_MAX_SEQLEN``.
        """
        if device_props is None:
            # Free device memory as reported by the driver, plus blocks the
            # caching allocator holds but has not handed out. The allocator
            # cache alone is close to zero on a fresh process and would
            # reject every request.
            free_bytes, _total_bytes = torch.cuda.mem_get_info()
            cached_bytes = torch.cuda.memory_reserved() - torch.cuda.memory_allocated()
            free_mem_mb = (free_bytes + cached_bytes) / 1024 / 1024
        else:
            free_mem_mb = device_props.get('free_mem_mb', 0)
        max_seq = max(self.max_seqlens_q, self.max_seqlens_k)
        # Coarse estimate only: sequence length x batch x 3 x 2 bytes.
        est_mem_mb = (max_seq * max(self.num_contexts, 1) * 3 * 2) / 1024 / 1024
        if est_mem_mb > free_mem_mb * HW_MEM_BUDGET_RATIO:
            raise RuntimeError(f"SeqLen {max_seq} exceeds memory budget.")
        if max_seq > HW_MAX_SEQLEN:
            raise ValueError(f"SeqLen {max_seq} exceeds hardware limit.")

    def log_profile(self, **kwargs):
        """Merge ``kwargs`` into ``profile_stats`` when profiling is enabled."""
        if self.enable_profiling:
            self.profile_stats.update(kwargs)

# ============================================================================
# L1: 刚体工具函数
# ============================================================================
@triton.jit
def cdiv_fn(x, y):
    """Integer ceiling division: rounds ``x / y`` up using ``//`` only.

    Intended for non-negative ``x`` and positive ``y``.
    """
    rounded_up = x + (y - 1)
    return rounded_up // y

@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load a 2-D tile through a block pointer with optional bounds checks.

    Args:
        block_ptr: Block pointer created with ``tl.make_block_ptr``.
            ``boundary_check`` / ``padding_option`` are only defined for block
            pointers, not for plain tensors of pointers.
        first: Compile-time flag; bounds-check dimension 0 when true.
        second: Compile-time flag; bounds-check dimension 1 when true.
        pad: Padding option for out-of-bounds elements (e.g. ``"zero"``).

    Returns:
        The loaded tile.
    """
    # boundary_check expects dimension indices, so the two flags are mapped
    # to the matching index tuple instead of being passed through as booleans.
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor

@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
    """Build the ALiBi penalty tile for one (query block, key block) pair.

    The distance is ``offs_m + seqlen_k - seqlen_q - offs_n`` and the penalty
    is ``-alibi_slope * |distance|``. When ``transpose`` is true the tile is
    returned transposed.
    """
    distance = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    penalty = -1.0 * alibi_slope * tl.abs(distance)
    if transpose:
        return penalty.T
    else:
        return penalty

@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m: tl.constexpr, n: tl.constexpr, stride):
    """Return an ``(m, n)`` boolean keep-mask for dropout.

    Each element gets its own Philox counter,
    ``philox_offset + row * stride + col``, so a tile's mask depends only on
    the seed and on where the tile sits in the attention matrix.

    Args:
        philox_seed: RNG seed.
        philox_offset: Counter of the tile's top-left element.
        dropout_p: Drop probability; an element is kept when its uniform
            sample is greater than ``dropout_p``.
        m: Tile height (compile-time constant).
        n: Tile width (compile-time constant).
        stride: Counter distance between consecutive rows.

    Returns:
        Boolean tile that is true for the elements to keep.
    """
    rows = tl.arange(0, m)
    cols = tl.arange(0, n)
    rng_offsets = (philox_offset + rows[:, None] * stride + cols[None, :]).to(tl.uint32)
    # tl.rand yields uniform samples in [0, 1); keeping "sample > p" retains
    # an element with probability 1 - p, which matches the 1 / (1 - p)
    # rescaling done by the caller.
    return tl.rand(philox_seed, rng_offsets) > dropout_p

# ============================================================================
# L1: 前向内循环(未改动,已验证正确)
# ============================================================================
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                     start_m, actual_seqlen_k, actual_seqlen_q,
                     dropout_p, philox_seed, batch_philox_offset,
                     encoded_softmax_block_ptr, block_min, block_max,
                     offs_n_causal, masked_blocks, n_extra_tokens,
                     bias_ptr, alibi_slope, stride_bm,
                     IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
                     BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                     OFFS_M: tl.constexpr, OFFS_N: tl.constexpr,
                     PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
                     ENABLE_DROPOUT: tl.constexpr,
                     RETURN_ENCODED_SOFTMAX: tl.constexpr,
                     PADDED_HEAD: tl.constexpr,
                     BIAS_TYPE: tl.constexpr, USE_ALIBI: tl.constexpr):
    """Online-softmax loop over the key blocks of one query tile.

    Args:
        acc, l_i, m_i: Running output accumulator, softmax denominator and
            row maximum (base-2 domain).
        q: Query tile, already multiplied by ``sm_scale * LOG2_E``.
        K_block_ptr, V_block_ptr: Block pointers (``tl.make_block_ptr``) of
            shape ``(BLOCK_N, BLOCK_DMODEL)`` positioned at key row 0.
        start_m: Index of the query tile.
        block_min, block_max: Key range to visit, in steps of ``BLOCK_N``.
        offs_n_causal: ``arange(BLOCK_N) + (seqlen_q - seqlen_k)``; a key
            column ``start_n + offs_n_causal`` is visible to query row ``m``
            when ``m`` is greater than or equal to it.
        bias_ptr: Block pointer of shape ``(BLOCK_M, BLOCK_N)`` into the bias,
            or None when ``BIAS_TYPE == 0``.
        OFFS_M, OFFS_N: ``arange(BLOCK_M)`` / ``arange(BLOCK_N)``.

    ``encoded_softmax_block_ptr``, ``masked_blocks``, ``n_extra_tokens``,
    ``stride_bm``, ``MASK_STEPS`` and ``PADDED_HEAD`` are accepted for
    signature compatibility and are not read: every load is bounds-checked in
    both dimensions and every tile is masked.

    Returns:
        ``(acc, l_i, m_i)`` with ``acc`` already divided by ``l_i``.
    """
    offs_m = start_m * BLOCK_M + OFFS_M
    for start_n in range(block_min, block_max, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Advance by whole key rows; out-of-range rows/columns read as zero.
        k = tl.load(tl.advance(K_block_ptr, (start_n, 0)),
                    boundary_check=(0, 1), padding_option="zero")
        if PRE_LOAD_V:
            v = tl.load(tl.advance(V_block_ptr, (start_n, 0)),
                        boundary_check=(0, 1), padding_option="zero")
        qk = tl.dot(q, tl.trans(k))
        if USE_ALIBI:
            alibi_block = compute_alibi_block(
                alibi_slope, actual_seqlen_q, actual_seqlen_k,
                offs_m, start_n + OFFS_N)
            qk += alibi_block * LOG2_E
        if BIAS_TYPE != 0:
            bias_block = tl.load(tl.advance(bias_ptr, (0, start_n)),
                                 boundary_check=(0, 1), padding_option="zero")
            qk += bias_block * LOG2_E
        # Zero-padded keys would otherwise contribute exp2(0 - m) to the
        # softmax, so columns past the sequence end are masked explicitly.
        in_range = (start_n + OFFS_N)[None, :] < actual_seqlen_k
        if IS_CAUSAL:
            visible = offs_m[:, None] >= (start_n + offs_n_causal)[None, :]
            qk = tl.where(in_range & visible, qk, float("-inf"))
        else:
            qk = tl.where(in_range, qk, float("-inf"))
        # Online softmax: the new maximum includes the running one, so p and
        # the rescaled accumulator share the same reference point.
        m_new = tl.maximum(m_i, tl.max(qk, 1))
        # Rows with no visible key so far keep m == -inf; subtracting a finite
        # stand-in avoids (-inf) - (-inf) = NaN and still yields p == 0.
        m_safe = tl.where(m_new == float("-inf"), 0.0, m_new)
        p = tl.math.exp2(qk - m_safe[:, None])
        alpha = tl.math.exp2(m_i - m_safe)
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None]
        if ENABLE_DROPOUT:
            philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n
            keep = dropout_mask(philox_seed, philox_offset, dropout_p,
                                BLOCK_M, BLOCK_N, actual_seqlen_k)
            p = tl.where(keep, p, 0.0) / (1.0 - dropout_p)
        if not PRE_LOAD_V:
            v = tl.load(tl.advance(V_block_ptr, (start_n, 0)),
                        boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_new
    # Rows that saw no key have l_i == 0; divide by 1 so they come out as 0.
    acc = acc / tl.where(l_i == 0.0, 1.0, l_i)[:, None]
    return acc, l_i, m_i

# ============================================================================
# L1: forward outer scheduling
# ============================================================================
@triton.jit
def _attn_fwd_kernel(
    Q, K, V, bias, sm_scale, L, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    stride_bz, stride_bh, stride_bm, stride_bn,
    stride_az, stride_ah,
    cu_seqlens_q, cu_seqlens_k,
    dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
    USE_ALIBI: tl.constexpr, BATCH_SIZE: tl.constexpr,
):
    """Forward attention for one (query tile, query head, batch entry).

    Grid axes: 0 = query tile, 1 = query head, 2 = batch entry / sequence.
    Writes the normalised output tile to ``Out`` and, when ``L`` is given,
    the base-2 log-sum-exp of every row to ``L`` laid out as
    ``(batch, HQ, MAX_SEQLENS_Q)``.

    With ``VARLEN`` the per-sequence row ranges come from ``cu_seqlens_q`` /
    ``cu_seqlens_k`` and rows are addressed as ``start + local_row``; the
    caller is expected to pass a batch stride that suits that packed layout.
    The bias is indexed by head only (its batch dimension is not used).
    ``encoded_softmax`` is forwarded untouched and never written here.
    """
    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    # Grouped-query attention: several query heads share one K/V head.
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        actual_seqlen_q = end_q - start_q
        actual_seqlen_k = end_k - start_k
    else:
        start_q = 0
        start_k = 0
        actual_seqlen_q = MAX_SEQLENS_Q
        actual_seqlen_k = MAX_SEQLENS_K
    if start_m * BLOCK_M >= actual_seqlen_q:
        return
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    # Block pointers carry the logical (rows, head_dim) shape, so loads can
    # bounds-check both the sequence tail and a padded head dimension.
    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_z * stride_qz + off_h * stride_qh + start_q * stride_qm,
        shape=(actual_seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0))
    K_block_ptr = tl.make_block_ptr(
        base=K + off_z * stride_kz + off_h_k * stride_kh + start_k * stride_kn,
        shape=(actual_seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0))
    V_block_ptr = tl.make_block_ptr(
        base=V + off_z * stride_vz + off_h_k * stride_vh + start_k * stride_vk,
        shape=(actual_seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0))

    q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
    # Fold the softmax scale and the natural-log -> log2 conversion into Q.
    q = (q * sm_scale * LOG2_E).to(q.dtype)
    acc = tl.zeros((BLOCK_M, BLOCK_DMODEL), dtype=tl.float32)
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)

    block_min = 0
    block_max = cdiv_fn(actual_seqlen_k, BLOCK_N) * BLOCK_N
    offs_n_causal = offs_n + (actual_seqlen_q - actual_seqlen_k)
    if IS_CAUSAL:
        # Last row of this tile sees keys up to
        # (start_m + 1) * BLOCK_M - 1 + (seqlen_k - seqlen_q); round the
        # exclusive end up to a whole key block.
        visible_end = (start_m + 1) * BLOCK_M + (actual_seqlen_k - actual_seqlen_q)
        block_max = min(block_max, cdiv_fn(max(visible_end, 0), BLOCK_N) * BLOCK_N)

    if USE_ALIBI:
        alibi_slope = tl.load(alibi_slopes + off_z * stride_az + off_h * stride_ah)
    else:
        alibi_slope = None
    if BIAS_TYPE != 0:
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h * stride_bh,
            shape=(actual_seqlen_q, actual_seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    else:
        bias_ptr = None
    batch_philox_offset = (philox_offset_base + off_z * HQ * actual_seqlen_q * actual_seqlen_k
                           + off_h * actual_seqlen_q * actual_seqlen_k)

    acc, l_i, m_i = _attn_fwd_inner(
        acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m,
        actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed,
        batch_philox_offset,
        # Passed through as-is (it may be None); doing pointer arithmetic on
        # it here would fail whenever no buffer is supplied.
        encoded_softmax,
        block_min, block_max, offs_n_causal, 0, 0,
        bias_ptr, alibi_slope, stride_bm,
        IS_CAUSAL=IS_CAUSAL, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_N=BLOCK_N, OFFS_M=tl.arange(0, BLOCK_M), OFFS_N=tl.arange(0, BLOCK_N),
        PRE_LOAD_V=PRE_LOAD_V, MASK_STEPS=IS_CAUSAL, ENABLE_DROPOUT=ENABLE_DROPOUT,
        RETURN_ENCODED_SOFTMAX=RETURN_ENCODED_SOFTMAX,
        PADDED_HEAD=BLOCK_DMODEL > ACTUAL_BLOCK_DMODEL,
        BIAS_TYPE=BIAS_TYPE, USE_ALIBI=USE_ALIBI)

    row_ok = offs_m < actual_seqlen_q
    if L is not None:
        # offs_m already contains start_m * BLOCK_M; adding it again would
        # shift every tile after the first.
        L_ptr = L + off_z * HQ * MAX_SEQLENS_Q + off_h * MAX_SEQLENS_Q + offs_m
        tl.store(L_ptr, m_i + tl.math.log2(l_i), mask=row_ok)
    out_ptrs = (Out + off_z * stride_oz + off_h * stride_oh
                + (start_q + offs_m)[:, None] * stride_om
                + offs_d[None, :] * stride_on)
    # Mask the padded head columns too, otherwise they spill into the next row.
    out_mask = row_ok[:, None] & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
    # Out.dtype is a pointer type; the element type is what acc must be cast to.
    tl.store(out_ptrs, acc.to(Out.dtype.element_ty), mask=out_mask)

# ============================================================================
# L1: 反向共享函数(修正:统一使用全局坐标语义)
# ============================================================================
@triton.jit
def _bwd_load_q_pre(Q, sm_scale, off_z, off_h, start_q, global_row_idx,
                     stride_qz, stride_qh, stride_qm,
                     BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
                     row_mask=None):
    """Load a tile of Q rows and pre-multiply it by ``sm_scale * LOG2_E``.

    Args:
        Q: Base pointer of the query tensor (last dimension contiguous).
        sm_scale: Softmax scale.
        off_z, off_h: Batch and head index.
        start_q: Unused here; ``global_row_idx`` already includes it.
        global_row_idx: Absolute row indices to load (1-D).
        stride_qz, stride_qh, stride_qm: Batch, head and row strides of Q.
        BLOCK_DMODEL: Padded head dimension (compile-time, power of two).
        ACTUAL_BLOCK_DMODEL: Real head dimension; columns at or beyond it are
            read as zero.
        row_mask: Optional 1-D boolean tensor; rows where it is false are read
            as zero. Without it the caller must guarantee that every row index
            is valid.

    Returns:
        The scaled tile in Q's element type.
    """
    offs_d = tl.arange(0, BLOCK_DMODEL)
    q_offset = off_z * stride_qz + off_h * stride_qh + global_row_idx * stride_qm
    q_ptrs = Q + q_offset[:, None] + offs_d[None, :]
    # Q is addressed through a plain tensor of pointers, so bounds are
    # enforced with a mask rather than with block-pointer boundary checks.
    load_mask = offs_d[None, :] < ACTUAL_BLOCK_DMODEL
    if row_mask is not None:
        load_mask = load_mask & row_mask[:, None]
    q = tl.load(q_ptrs, mask=load_mask, other=0.0)
    q = (q * sm_scale * LOG2_E).to(q.dtype)
    return q

@triton.jit
def _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, global_row_idx, actual_seqlen_q):
    """Load the per-row log-sum-exp values saved by the forward pass.

    ``L`` is addressed as ``(batch, HQ, MAX_SEQLENS_Q)``; ``global_row_idx``
    is the row index inside the sequence (it does not include ``start_q``).
    Rows at or beyond ``actual_seqlen_q`` are returned as 0.0 so that no
    undefined value reaches the caller; the caller still has to discard the
    probabilities computed for those rows.
    """
    L_ptr = L + off_z * HQ * MAX_SEQLENS_Q + off_h * MAX_SEQLENS_Q + global_row_idx
    return tl.load(L_ptr, mask=global_row_idx < actual_seqlen_q, other=0.0)

@triton.jit
def _bwd_compute_p(q, k, l_i, global_m, global_n, IS_CAUSAL: tl.constexpr):
    """Recompute the softmax probabilities for one (Q tile, K tile) pair.

    Computes ``q @ k^T``, masks entries whose row coordinate is smaller than
    their column coordinate when ``IS_CAUSAL`` is set, and returns
    ``exp2(scores - l_i)`` with ``l_i`` broadcast over the columns.
    """
    scores = tl.dot(q, k.T)
    if IS_CAUSAL:
        # Keep an entry when its row coordinate is >= its column coordinate.
        allowed = global_m[:, None] >= global_n[None, :]
        scores = tl.where(allowed, scores, float("-inf"))
    return tl.math.exp2(scores - l_i[:, None])

# ============================================================================
# L1: 反向DK内核
# ============================================================================
@triton.jit
def _bwd_kernel_dk(
    Q, K, V, sm_scale, L, D, Out, O_fwd,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_dkz, stride_dkh, stride_dkn, stride_dkk,
    stride_fz, stride_fh, stride_fm, stride_fn,
    cu_seqlens_q, cu_seqlens_k,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Accumulate dK for one (key tile, query head, batch entry).

    Grid axes: 0 = key tile, 1 = query head, 2 = batch entry / sequence.

    Args:
        D: Gradient of the attention output (dO).
        Out: dK buffer. Contributions are added with ``tl.atomic_add`` because
            several query heads can map to the same K head, so the buffer
            must be zero-initialised by the caller.
        O_fwd: Forward output; ``stride_f*`` are its strides and are also
            used to address ``D``.
            NOTE(review): this assumes dO has the same strides as O - confirm
            in the launcher.
        stride_dk*: Strides of ``Out``.

    Dropout, bias and ALiBi are not part of this kernel's inputs, so their
    effect on the gradient is not reproduced here.
    """
    start_n = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        actual_seqlen_k = end_k - start_k
        actual_seqlen_q = end_q - start_q
    else:
        start_k = 0
        start_q = 0
        actual_seqlen_k = MAX_SEQLENS_K
        actual_seqlen_q = MAX_SEQLENS_Q
    if start_n * BLOCK_N >= actual_seqlen_k:
        return
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_local = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    d_ok = offs_d[None, :] < ACTUAL_BLOCK_DMODEL
    col_ok = offs_n < actual_seqlen_k
    kv_mask = col_ok[:, None] & d_ok

    # K and V tiles stay resident for the whole loop over query tiles.
    k_ptrs = (K + off_z * stride_kz + off_h_k * stride_kh
              + (start_k + offs_n)[:, None] * stride_kn + offs_d[None, :] * stride_kk)
    k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
    v_ptrs = (V + off_z * stride_vz + off_h_k * stride_vh
              + (start_k + offs_n)[:, None] * stride_vk + offs_d[None, :] * stride_vn)
    v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

    dk = tl.zeros((BLOCK_N, BLOCK_DMODEL), dtype=tl.float32)
    # Same alignment as the forward mask: query row m sees key n when
    # m + (seqlen_k - seqlen_q) >= n.
    causal_shift = actual_seqlen_k - actual_seqlen_q
    block_min = 0
    block_max = cdiv_fn(actual_seqlen_q, BLOCK_M) * BLOCK_M
    if IS_CAUSAL:
        # First query row that can see this key tile, rounded DOWN to a query
        # tile boundary so rows sharing a tile with it are not skipped.
        first_row = max(start_n * BLOCK_N - causal_shift, 0)
        block_min = (first_row // BLOCK_M) * BLOCK_M

    for start_m in range(block_min, block_max, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m = start_m + offs_m_local  # row index inside the sequence
        row_ok = offs_m < actual_seqlen_q
        qo_mask = row_ok[:, None] & d_ok

        q_ptrs = (Q + off_z * stride_qz + off_h * stride_qh
                  + (start_q + offs_m)[:, None] * stride_qm + offs_d[None, :] * stride_qk)
        q = tl.load(q_ptrs, mask=qo_mask, other=0.0)
        # Same pre-scaling as the forward pass so exp2(qk - lse) reproduces P.
        q = (q * sm_scale * LOG2_E).to(q.dtype)

        l_i = _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, offs_m, actual_seqlen_q)
        # A row that saw no key has LSE == -inf; use a finite stand-in so the
        # masked scores give exp2(-inf) = 0 instead of NaN.
        l_i = tl.where(l_i == float("-inf"), 0.0, l_i)
        p = _bwd_compute_p(q, k, l_i, offs_m + causal_shift, offs_n, IS_CAUSAL)
        # Padded rows/columns carry meaningless probabilities; drop them.
        p = tl.where(row_ok[:, None] & col_ok[None, :], p, 0.0)

        row_off = (off_z * stride_fz + off_h * stride_fh
                   + (start_q + offs_m)[:, None] * stride_fm + offs_d[None, :] * stride_fn)
        o_row = tl.load(O_fwd + row_off, mask=qo_mask, other=0.0)
        do = tl.load(D + row_off, mask=qo_mask, other=0.0)

        # Di = rowsum(dO * O), accumulated in fp32.
        Di = tl.sum(do.to(tl.float32) * o_row.to(tl.float32), 1)
        ds = p * (tl.dot(do, tl.trans(v)) - Di[:, None])
        dk += tl.dot(tl.trans(ds.to(q.dtype)), q)

    # q carried sm_scale * LOG2_E; multiplying by ln(2) leaves sm_scale only.
    dk *= LN2

    dk_ptrs = (Out + off_z * stride_dkz + off_h_k * stride_dkh
               + (start_k + offs_n)[:, None] * stride_dkn + offs_d[None, :] * stride_dkk)
    # Out.dtype is a pointer type; cast to its element type.
    tl.atomic_add(dk_ptrs, dk.to(Out.dtype.element_ty), mask=kv_mask)

# ============================================================================
# L1: 反向DV内核
# ============================================================================
@triton.jit
def _bwd_kernel_dv(
    Q, K, V, sm_scale, L, D, Out, O_fwd,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_dvz, stride_dvh, stride_dvn, stride_dvd,
    stride_fz, stride_fh, stride_fm, stride_fn,
    cu_seqlens_q, cu_seqlens_k,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Accumulate dV = P^T @ dO for one (key tile, query head, batch entry).

    Grid axes: 0 = key tile, 1 = query head, 2 = batch entry / sequence.

    Args:
        D: Gradient of the attention output (dO), addressed with the
            ``stride_f*`` strides of ``O_fwd``.
            NOTE(review): this assumes dO has the same strides as O - confirm
            in the launcher.
        Out: dV buffer. Contributions are added with ``tl.atomic_add`` because
            several query heads can map to the same V head, so the buffer
            must be zero-initialised by the caller.
        O_fwd: Forward output; only its strides are used here.
        stride_dv*: Strides of ``Out``.

    Dropout, bias and ALiBi are not part of this kernel's inputs, so their
    effect on the gradient is not reproduced here.
    """
    start_n = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        actual_seqlen_k = end_k - start_k
        actual_seqlen_q = end_q - start_q
    else:
        start_k = 0
        start_q = 0
        actual_seqlen_k = MAX_SEQLENS_K
        actual_seqlen_q = MAX_SEQLENS_Q
    if start_n * BLOCK_N >= actual_seqlen_k:
        return
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_local = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    d_ok = offs_d[None, :] < ACTUAL_BLOCK_DMODEL
    col_ok = offs_n < actual_seqlen_k
    kv_mask = col_ok[:, None] & d_ok

    # Only K is needed to recompute P; V itself does not enter dV.
    k_ptrs = (K + off_z * stride_kz + off_h_k * stride_kh
              + (start_k + offs_n)[:, None] * stride_kn + offs_d[None, :] * stride_kk)
    k = tl.load(k_ptrs, mask=kv_mask, other=0.0)

    dv = tl.zeros((BLOCK_N, BLOCK_DMODEL), dtype=tl.float32)
    # Same alignment as the forward mask: query row m sees key n when
    # m + (seqlen_k - seqlen_q) >= n.
    causal_shift = actual_seqlen_k - actual_seqlen_q
    block_min = 0
    block_max = cdiv_fn(actual_seqlen_q, BLOCK_M) * BLOCK_M
    if IS_CAUSAL:
        # First query row that can see this key tile, rounded DOWN to a query
        # tile boundary so rows sharing a tile with it are not skipped.
        first_row = max(start_n * BLOCK_N - causal_shift, 0)
        block_min = (first_row // BLOCK_M) * BLOCK_M

    for start_m in range(block_min, block_max, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m = start_m + offs_m_local  # row index inside the sequence
        row_ok = offs_m < actual_seqlen_q
        qo_mask = row_ok[:, None] & d_ok

        q_ptrs = (Q + off_z * stride_qz + off_h * stride_qh
                  + (start_q + offs_m)[:, None] * stride_qm + offs_d[None, :] * stride_qk)
        q = tl.load(q_ptrs, mask=qo_mask, other=0.0)
        # Same pre-scaling as the forward pass so exp2(qk - lse) reproduces P.
        q = (q * sm_scale * LOG2_E).to(q.dtype)

        l_i = _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, offs_m, actual_seqlen_q)
        # A row that saw no key has LSE == -inf; use a finite stand-in so the
        # masked scores give exp2(-inf) = 0 instead of NaN.
        l_i = tl.where(l_i == float("-inf"), 0.0, l_i)
        p = _bwd_compute_p(q, k, l_i, offs_m + causal_shift, offs_n, IS_CAUSAL)
        # Padded rows/columns carry meaningless probabilities; drop them.
        p = tl.where(row_ok[:, None] & col_ok[None, :], p, 0.0)

        do_ptrs = (D + off_z * stride_fz + off_h * stride_fh
                   + (start_q + offs_m)[:, None] * stride_fm + offs_d[None, :] * stride_fn)
        do = tl.load(do_ptrs, mask=qo_mask, other=0.0)

        dv += tl.dot(tl.trans(p.to(do.dtype)), do)

    dv_ptrs = (Out + off_z * stride_dvz + off_h_k * stride_dvh
               + (start_k + offs_n)[:, None] * stride_dvn + offs_d[None, :] * stride_dvd)
    # Out.dtype is a pointer type; cast to its element type.
    tl.atomic_add(dv_ptrs, dv.to(Out.dtype.element_ty), mask=kv_mask)

# ============================================================================
# L1: 反向DQ内核
# ============================================================================
@triton.jit
def _bwd_kernel_dq(
    Q, K, V, sm_scale, L, D, Out, O_fwd,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_dqz, stride_dqh, stride_dqm, stride_dqd,
    stride_fz, stride_fh, stride_fm, stride_fn,
    cu_seqlens_q, cu_seqlens_k,
    HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Compute dQ for one (query tile, query head, batch entry).

    Grid axes: 0 = query tile, 1 = query head, 2 = batch entry / sequence.
    Each program owns its dQ tile exclusively and writes it with a plain
    store.

    Args:
        D: Gradient of the attention output (dO), addressed with the
            ``stride_f*`` strides of ``O_fwd``.
            NOTE(review): this assumes dO has the same strides as O - confirm
            in the launcher.
        Out: dQ buffer; ``stride_dq*`` are its strides.
        O_fwd: Forward output, needed for ``Di = rowsum(dO * O)``.

    Dropout, bias and ALiBi are not part of this kernel's inputs, so their
    effect on the gradient is not reproduced here.
    """
    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    off_h_k = off_h // (HQ // HK)
    if VARLEN:
        start_q = tl.load(cu_seqlens_q + off_z)
        end_q = tl.load(cu_seqlens_q + off_z + 1)
        start_k = tl.load(cu_seqlens_k + off_z)
        end_k = tl.load(cu_seqlens_k + off_z + 1)
        actual_seqlen_q = end_q - start_q
        actual_seqlen_k = end_k - start_k
    else:
        start_q = 0
        start_k = 0
        actual_seqlen_q = MAX_SEQLENS_Q
        actual_seqlen_k = MAX_SEQLENS_K
    if start_m * BLOCK_M >= actual_seqlen_q:
        return
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # row index inside the sequence
    offs_n_local = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    d_ok = offs_d[None, :] < ACTUAL_BLOCK_DMODEL
    row_ok = offs_m < actual_seqlen_q
    qo_mask = row_ok[:, None] & d_ok

    q_ptrs = (Q + off_z * stride_qz + off_h * stride_qh
              + (start_q + offs_m)[:, None] * stride_qm + offs_d[None, :] * stride_qk)
    q = tl.load(q_ptrs, mask=qo_mask, other=0.0)
    # Same pre-scaling as the forward pass so exp2(qk - lse) reproduces P.
    q = (q * sm_scale * LOG2_E).to(q.dtype)

    l_i = _bwd_load_lse(L, off_z, HQ, MAX_SEQLENS_Q, off_h, offs_m, actual_seqlen_q)
    # A row that saw no key has LSE == -inf; use a finite stand-in so the
    # masked scores give exp2(-inf) = 0 instead of NaN.
    l_i = tl.where(l_i == float("-inf"), 0.0, l_i)

    row_off = (off_z * stride_fz + off_h * stride_fh
               + (start_q + offs_m)[:, None] * stride_fm + offs_d[None, :] * stride_fn)
    o_row = tl.load(O_fwd + row_off, mask=qo_mask, other=0.0)
    do = tl.load(D + row_off, mask=qo_mask, other=0.0)

    # Di = rowsum(dO * O), accumulated in fp32.
    Di = tl.sum(do.to(tl.float32) * o_row.to(tl.float32), 1)
    dq = tl.zeros((BLOCK_M, BLOCK_DMODEL), dtype=tl.float32)

    # Same alignment as the forward mask: query row m sees key n when
    # m + (seqlen_k - seqlen_q) >= n.
    causal_shift = actual_seqlen_k - actual_seqlen_q
    block_min = 0
    block_max = cdiv_fn(actual_seqlen_k, BLOCK_N) * BLOCK_N
    if IS_CAUSAL:
        # Exclusive end of the keys visible to the last row of this tile,
        # rounded up to a whole key block.
        visible_end = (start_m + 1) * BLOCK_M + causal_shift
        block_max = min(block_max, cdiv_fn(max(visible_end, 0), BLOCK_N) * BLOCK_N)

    for start_n in range(block_min, block_max, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        offs_n = start_n + offs_n_local  # key index inside the sequence
        col_ok = offs_n < actual_seqlen_k
        kv_mask = col_ok[:, None] & d_ok

        k_ptrs = (K + off_z * stride_kz + off_h_k * stride_kh
                  + (start_k + offs_n)[:, None] * stride_kn + offs_d[None, :] * stride_kk)
        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
        v_ptrs = (V + off_z * stride_vz + off_h_k * stride_vh
                  + (start_k + offs_n)[:, None] * stride_vk + offs_d[None, :] * stride_vn)
        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

        p = _bwd_compute_p(q, k, l_i, offs_m + causal_shift, offs_n, IS_CAUSAL)
        # Padded rows/columns carry meaningless probabilities; drop them.
        p = tl.where(row_ok[:, None] & col_ok[None, :], p, 0.0)

        ds = p * (tl.dot(do, tl.trans(v)) - Di[:, None])
        dq += tl.dot(ds.to(k.dtype), k)

    # K was not pre-scaled, so the softmax scale is applied once here.
    dq *= sm_scale

    dq_ptrs = (Out + off_z * stride_dqz + off_h * stride_dqh
               + (start_q + offs_m)[:, None] * stride_dqm + offs_d[None, :] * stride_dqd)
    # Out.dtype is a pointer type; cast to its element type.
    tl.store(dq_ptrs, dq.to(Out.dtype.element_ty), mask=qo_mask)

# ============================================================================
# L3: PyTorch Autograd 封装
# ============================================================================
_SCALAR_KEYS = ['sm_scale', 'causal', 'dropout_p', 'varlen', 'num_contexts',
                'max_seqlens_q', 'max_seqlens_k', 'return_encoded_softmax']

def _save_metadata(ctx, metadata):
    for key in _SCALAR_KEYS:
        setattr(ctx, key, getattr(metadata, key, None))
    ctx.bias = metadata.bias
    ctx.alibi_slopes = metadata.alibi_slopes
    ctx.cu_seqlens_q = metadata.cu_seqlens_q
    ctx.cu_seqlens_k = metadata.cu_seqlens_k

def _load_metadata(ctx):
    """Rebuild an ``AttentionMetadata`` from the attributes stored on ``ctx``.

    The constructor receives the scale, causal flag, dropout probability and
    encoded-softmax flag; the remaining fields are assigned afterwards.
    Profiling is left at the constructor default.
    """
    restored = AttentionMetadata(
        sm_scale=ctx.sm_scale,
        causal=ctx.causal,
        dropout_p=ctx.dropout_p,
        return_encoded_softmax=ctx.return_encoded_softmax,
    )
    copied_fields = (
        'varlen', 'num_contexts', 'max_seqlens_q', 'max_seqlens_k',
        'bias', 'alibi_slopes', 'cu_seqlens_q', 'cu_seqlens_k',
    )
    for field_name in copied_fields:
        setattr(restored, field_name, getattr(ctx, field_name))
    return restored

class TritonFlashAttention(torch.autograd.Function):
    """Autograd wrapper that launches the Triton forward/backward kernels.

    ``forward(q, k, v, o, metadata)`` returns the attention output and saves
    Q, K, V, O and the per-row log-sum-exp for ``backward``, which returns
    gradients for Q, K and V only.
    """

    @staticmethod
    def _kernel_strides(t, varlen):
        """Return ``(batch, head, row, feature)`` strides for a kernel launch.

        Non-varlen tensors are indexed as 4-D ``(batch, heads, seq, dim)``.
        Varlen tensors are 3-D; they are treated as ``(tokens, heads, dim)``
        with a batch stride of 0 because sequences are located through
        ``cu_seqlens``.
        NOTE(review): the packed layout is inferred from ``q.shape[1]`` being
        used as the head count - confirm against callers.
        """
        if varlen:
            return (0, t.stride(1), t.stride(0), t.stride(2))
        return (t.stride(0), t.stride(1), t.stride(2), t.stride(3))

    @staticmethod
    def forward(ctx, q, k, v, o, metadata):
        """Run the forward kernel.

        Args:
            q, k, v: Input tensors whose last dimension is contiguous.
            o: Optional preallocated output; allocated like ``q`` when None.
            metadata: ``AttentionMetadata`` describing this call.

        Returns:
            The output tensor ``o``.

        Raises:
            RuntimeError: If the last dimension of Q/K/V is not contiguous.
            ValueError: If the bias has 2**31 or more elements, plus anything
                raised by the metadata validators.
        """
        if q.stride(-1) != 1 or k.stride(-1) != 1 or v.stride(-1) != 1:
            raise RuntimeError("Q/K/V must be contiguous.")
        if metadata.bias is not None and metadata.bias.numel() >= 2**31:
            raise ValueError("Bias too large.")
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)
        metadata.validate_inputs(q, k, v, o)
        metadata.validate_resources()
        cfg = get_dynamic_block_config(metadata.max_seqlens_q, metadata.max_seqlens_k)
        batch = q.shape[0] if not metadata.varlen else metadata.num_contexts
        nheads_q, nheads_k, head_dim = q.shape[1], k.shape[1], q.shape[-1]
        grid = (triton.cdiv(metadata.max_seqlens_q, cfg['BLOCK_M']), nheads_q, batch)
        if grid[0] == 0 or grid[1] == 0:
            # Nothing to compute; backward sees None tensors and returns early.
            o.zero_()
            ctx.save_for_backward(None, None, None, None, None)
            _save_metadata(ctx, metadata)
            return o
        L = torch.empty((batch, nheads_q, metadata.max_seqlens_q),
                         device=q.device, dtype=torch.float32)
        strides = TritonFlashAttention._kernel_strides
        bias_strides = (tuple(metadata.bias.stride())
                        if metadata.bias is not None else (0, 0, 0, 0))
        alibi_strides = (tuple(metadata.alibi_slopes.stride())
                         if metadata.alibi_slopes is not None else (0, 0))
        # Draw a seed only when dropout is active; .item() forces a device
        # synchronisation that the no-dropout path does not need.
        if metadata.dropout_p > 0.0:
            philox_seed = torch.randint(0, 2**32, (1,), device=q.device).item()
        else:
            philox_seed = 0
        _attn_fwd_kernel[grid](
            q, k, v, metadata.bias, metadata.sm_scale, L, o,
            *strides(q, metadata.varlen),
            *strides(k, metadata.varlen),
            *strides(v, metadata.varlen),
            *strides(o, metadata.varlen),
            *bias_strides,
            *alibi_strides,
            metadata.cu_seqlens_q, metadata.cu_seqlens_k,
            metadata.dropout_p,
            philox_seed,
            0, None, metadata.alibi_slopes,
            HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_dim,
            MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k,
            VARLEN=metadata.varlen, IS_CAUSAL=metadata.causal,
            BLOCK_M=cfg['BLOCK_M'], BLOCK_DMODEL=triton.next_power_of_2(head_dim),
            BLOCK_N=cfg['BLOCK_N'], PRE_LOAD_V=cfg['PRE_LOAD_V'],
            BIAS_TYPE=1 if metadata.bias is not None else 0,
            ENABLE_DROPOUT=metadata.dropout_p > 0.0,
            RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
            USE_ALIBI=metadata.alibi_slopes is not None,
            BATCH_SIZE=batch,
            num_warps=HW_WARPS_SMALL_HEAD if head_dim <= 64 else HW_WARPS_LARGE_HEAD,
            num_stages=HW_NUM_STAGES)
        metadata.log_profile(kernel="fwd", config=cfg, shape=q.shape, grid=grid)
        ctx.save_for_backward(q, k, v, o, L)
        _save_metadata(ctx, metadata)
        return o

    @staticmethod
    def backward(ctx, do):
        """Run the dK, dV and dQ kernels.

        Returns:
            ``(dq, dk, dv, None, None)`` matching the inputs of ``forward``.

        Raises:
            NotImplementedError: If the forward call used dropout, a bias or
                ALiBi slopes. The backward kernels take none of these as
                inputs, so the gradients would not match the forward pass.
        """
        q, k, v, o, L = ctx.saved_tensors
        if q is None:
            return None, None, None, None, None
        metadata = _load_metadata(ctx)
        if (metadata.dropout_p > 0.0 or metadata.bias is not None
                or metadata.alibi_slopes is not None):
            raise NotImplementedError(
                "backward does not support dropout, bias or ALiBi.")
        batch = q.shape[0] if not metadata.varlen else metadata.num_contexts
        nheads_q, nheads_k, head_dim = q.shape[1], k.shape[1], q.shape[-1]
        cfg = get_dynamic_block_config(metadata.max_seqlens_q, metadata.max_seqlens_k)
        # The kernels address dO with O's strides, so give both the same
        # (contiguous) layout; autograd may hand over a strided view.
        o = o.contiguous()
        do = do.contiguous()
        # dK/dV are accumulated with atomic adds (one contribution per query
        # head sharing a K/V head), so they must start at zero. fp32 buffers
        # keep the accumulation accurate for half-precision inputs.
        dk = torch.zeros_like(k, dtype=torch.float32)
        dv = torch.zeros_like(v, dtype=torch.float32)
        dq = torch.empty_like(q, memory_format=torch.contiguous_format)
        nwarps = HW_WARPS_SMALL_HEAD if head_dim <= 64 else HW_WARPS_LARGE_HEAD
        strides = TritonFlashAttention._kernel_strides

        def launch(kernel, grid, out):
            # All three backward kernels share one argument layout; only the
            # grid and the output buffer differ.
            kernel[grid](
                q, k, v, metadata.sm_scale, L, do, out, o,
                *strides(q, metadata.varlen),
                *strides(k, metadata.varlen),
                *strides(v, metadata.varlen),
                *strides(out, metadata.varlen),
                *strides(o, metadata.varlen),
                metadata.cu_seqlens_q, metadata.cu_seqlens_k,
                HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_dim,
                MAX_SEQLENS_Q=metadata.max_seqlens_q,
                MAX_SEQLENS_K=metadata.max_seqlens_k,
                VARLEN=metadata.varlen, IS_CAUSAL=metadata.causal,
                BLOCK_M=cfg['BLOCK_M'],
                BLOCK_DMODEL=triton.next_power_of_2(head_dim),
                BLOCK_N=cfg['BLOCK_N'], num_warps=nwarps, num_stages=HW_NUM_STAGES)

        grid_kv = (triton.cdiv(metadata.max_seqlens_k, cfg['BLOCK_N']), nheads_q, batch)
        grid_q = (triton.cdiv(metadata.max_seqlens_q, cfg['BLOCK_M']), nheads_q, batch)
        launch(_bwd_kernel_dk, grid_kv, dk)
        launch(_bwd_kernel_dv, grid_kv, dv)
        launch(_bwd_kernel_dq, grid_q, dq)

        metadata.log_profile(kernel="bwd", config=cfg, shape=q.shape)
        return dq, dk.to(k.dtype), dv.to(v.dtype), None, None

# ============================================================================
# L4: 对外接口
# ============================================================================
def _prepare_inputs(q, k, v, is_varlen):
    if is_varlen:
        return tuple(t.contiguous() for t in (q, k, v))
    return tuple(t.transpose(1, 2).contiguous() for t in (q, k, v))

def _prepare_output(o, is_varlen):
    return o if is_varlen else o.transpose(1, 2)

def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                     bias=None, alibi_slopes=None, **kwargs):
    """Fused attention for fixed-length batches.

    Args:
        q, k, v: 4-D tensors; dimensions 1 and 2 are swapped internally
            before the kernels run and swapped back on the output.
        dropout_p: Dropout probability.
        softmax_scale: Scale applied to the scores; defaults to
            ``head_dim ** -0.5`` when None. An explicit 0.0 is honoured.
        causal: Enable causal masking.
        bias: Optional additive bias, validated by ``set_bias``.
        alibi_slopes: Optional ALiBi slopes, validated by ``set_alibi``.
        **kwargs: Forwarded to ``AttentionMetadata``.

    Returns:
        The attention output in the caller's original layout.
    """
    q, k, v = _prepare_inputs(q, k, v, is_varlen=False)
    # "softmax_scale or default" would silently replace a legitimate 0.0.
    sm_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
    metadata = AttentionMetadata(sm_scale=sm_scale, causal=causal,
                                  dropout_p=dropout_p, **kwargs)
    metadata.max_seqlens_q, metadata.max_seqlens_k = q.shape[2], k.shape[2]
    if bias is not None:
        metadata.set_bias(bias)
    if alibi_slopes is not None:
        metadata.set_alibi(alibi_slopes, q.shape[0], q.shape[1])
    out = TritonFlashAttention.apply(q, k, v, None, metadata)
    return _prepare_output(out, is_varlen=False)

def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                            max_seqlen_q, max_seqlen_k, dropout_p=0.0,
                            softmax_scale=None, causal=False, **kwargs):
    """Fused attention for variable-length (packed) batches.

    Args:
        q, k, v: 3-D packed tensors; they are made contiguous, not transposed.
        cu_seqlens_q, cu_seqlens_k: Cumulative sequence offsets, validated by
            ``AttentionMetadata.set_varlen_params``.
        max_seqlen_q, max_seqlen_k: Maximum sequence lengths; these override
            the values derived from the cumulative offsets.
        dropout_p: Dropout probability.
        softmax_scale: Scale applied to the scores; defaults to
            ``head_dim ** -0.5`` when None. An explicit 0.0 is honoured.
        causal: Enable causal masking.
        **kwargs: Forwarded to ``AttentionMetadata``.

    Returns:
        The attention output in the packed layout of ``q``.
    """
    q, k, v = _prepare_inputs(q, k, v, is_varlen=True)
    # "softmax_scale or default" would silently replace a legitimate 0.0.
    sm_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
    metadata = AttentionMetadata(sm_scale=sm_scale, causal=causal,
                                  dropout_p=dropout_p, **kwargs)
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
    metadata.max_seqlens_q, metadata.max_seqlens_k = max_seqlen_q, max_seqlen_k
    out = TritonFlashAttention.apply(q, k, v, None, metadata)
    return _prepare_output(out, is_varlen=True)

这是用于验证 AI 编程原理的验证版多头注意力机制实现,采用协流变规则编写,与现有工程实现有所差别,可用于学习或教学。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐