简介
在 Part 2 里,我们已经讨论了 Flash Attention 的基础 CUDA building blocks:Tensor Core mma 与高效内存传输(cp.async、ldmatrix)。这一章进入“组装阶段”:把这些低层组件拼成第一个完整、可运行的 Flash Attention kernel(Kernel 1)。
我们会按 3 个步骤推进:
- 先确定 CTAs / warps / threads 的分工;
- 基于底层指令构建更高层操作;
- 把所有模块拼进完整 kernel。
本文目标配置:
- forward pass only
- non-causal attention
d_head = 128- no dropout / no KV caching
- Q/K/V 序列长度相同
- 序列长度可被 block size 整除(通常
64~128) - 输入输出
bf16/fp16,softmax 用fp32
在该配置下,Kernel 1 在 RTX 3090 上可达到约 49% reference 性能。
Kernel Architecture Overview
Flash Attention kernel 遵循经典三段式:
Kernel Phases
- Prologue(初始化)
- 计算 tensor/warp 地址
- 执行一次性加载(例如
Q: GMEM -> SMEM -> RF) - 初始化 softmax 统计量(
m, l)和输出累加器
- Mainloop(迭代计算)
K: GMEM -> SMEM -> RF- 计算 attention score
- 执行 online softmax 并更新统计量
V: GMEM -> SMEM -> RF- 计算输出增量
- Epilogue(收尾)
- 完成 softmax 归一化
fp32 -> fp16/bf16O: RF -> SMEM -> GMEM
Implementation Challenges
落地时的三大技术点:
- Data Movement:跨层搬运(GMEM -> SMEM -> RF)与不同 layout 协调
- Math Ops:
mma+ warp primitives 的 GEMM / online softmax - Synchronization:线程与 warp/CTA 的同步边界
复杂度主要来自数据搬运:不同 tensor 对布局、访问模式、同步范围的要求并不一致。
Kernel Configuration
从 Kernel 1 到 Kernel 7,固定配置为:
B_r = 64B_c = 64d_head = 128n_warps = 4(每 CTA 128 线程)
这些参数与 m16n8k16 对齐,且能较稳地放进 SMEM / RF 预算。
CTA Work Distribution
输入 tensor 形状:
(batch_size, seq_len, n_heads, d_head)
切块后:
Q/O:按 query 维切成B_r行块K/V:按 key/value 维切成B_c行块
每个 CTA 负责一个 (sample, head, q_block),并遍历该 (sample, head) 对应全部 K/V blocks。
总 CTA 数为 n_samples * n_heads * T_r,其中 T_r = seq_len / B_r。
Kernel Arguments
forward_kernel.cuh
struct ForwardKernelArgs {
using index_t = int64_t;
void *__restrict__ Q;
void *__restrict__ K;
void *__restrict__ V;
void *__restrict__ O;
const index_t batch_stride;
const index_t seq_stride;
const index_t head_stride;
const index_t seq_len;
const index_t n_heads;
const int n_Q_blocks;
const int n_KV_blocks;
};
Grid Mapping
每个 CTA 处理一个 (64,128) query tile,grid 可看成 (sample, query_block, head) 的某种排列。核心问题:怎么映射 CTA 才更利于 cache?
同一 (sample, head) 下,不同 CTA 会读不同 Q / 写不同 O,但会共享同一批 K^{(j)} 与 V^{(j)} blocks。
因此希望这批 CTA 启动时间靠近,便于复用先前 CTA 放入 L2 的 K/V 数据。
CTA 线性启动顺序:
blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z
固定 y,z 时,不同 x 的 CTA 在线性 ID 上连续,通常会“更一起”被调度。所以采用:
(x, y, z) -> (Q_block, head, batch)
forward_kernel.cuh
const int sample = blockIdx.z;
const int head = blockIdx.y;
const int q_seq_block = blockIdx.x;
| GPU | L2 Size | Unoptimized Hit Rate | Optimized Hit Rate | Perf Impact |
|---|---|---|---|---|
| RTX 3090 | 6MB | ~2% | ~98% | ~3% TFLOPs 提升 |
| A100 | 40MB | ~25.6% | ~92.6% | ~1% TFLOPs 提升 |
Warp-Level Work Distribution
进入 CTA 内部:B_r=64 且每 CTA 4 warps,所以每 warp 分到 64/4=16 行,处理子块 (16,128)。
关键分工:
Q/O:warp 内独立处理K/V:跨 warp 协作加载(每 warp 先各自加载 slice 到 SMEM,随后每个 warp 都要读完整块进 RF)
Q/O(Independent per Warp)
- 64 行被分为 4 个独立 slice
- 每个 warp 对其 slice 独立完成加载/存储与 GEMM
K/V(Cooperative across Warps)
- Loading:每个 warp 先搬自己的
(16,128)slice 到 SMEM - Sync:全 CTA 等待完整
(64,128)block 就绪 - Copy:每个 warp 再从 SMEM 读整块到 RF
后续(Kernel 9)会进一步优化到
Q相关加载也能更协作。
Data Movement
数据搬运是本章最复杂部分。实现策略分三层:
- Core ops:GMEM↔SMEM、SMEM->RF 原语
- Addressing:各 tensor / warp 的偏移计算
- Abstraction:
MatrixLDST封装统一接口
Configuration Structs
load_store.cuh
struct TileLayout {
const int row_fragments;
const int col_fragments;
};
struct TensorLDSTConfig {
const TileLayout GSM;
const TileLayout RF;
const bool transposed;
const int block_size;
const int smem_cols;
const int warp_ldst_rows;
const bool compute_over_entire_block;
};
Storage Layout
- SMEM:统一 row-major(便于 GMEM 加载)
- RF:大多 row-major;
V在 RF 侧使用 transpose 语义
LD/ST 操作规格
| From | To | Blocks | PTX Instr. / C++ | Warp-Wide Op Size | Thread Op Size | Thread Mapping | Register Shape | Notes |
|---|---|---|---|---|---|---|---|---|
| GMEM | SMEM | Q, K^{(j)}, V^{(j)} |
cp.async |
$(4,64)$ | $(1,8)$ | row-major | ||
| SMEM | RF | Q, K^{(j)}, V^{(j)} |
ldmatrix.x4 |
$(16,16)$ | $(1,8)$ | column-major | $(2,2)$ | V^{(j)} transpose |
| RF | SMEM | O |
standard (4B) | $(8,8)$ | $(1,2)$ | row-major | $(1,1)$ | |
| SMEM | GMEM | O |
standard (16B) | $(4,64)$ | $(1,8)$ | row-major |
Copying Between GMEM ↔ SMEM
cp.async 每次搬 16B;按 (4,64) warp 宽块组织。
实现做成双向模板(GM2SM 与 SM2GM):
template <typename T>
struct GM2SM_async {
__device__ constexpr void operator()(T *gmem, T *smem) {
cp_async<16>(smem, gmem);
}
};
template <typename T>
struct SM2GM {
__device__ constexpr void operator()(T *gmem, T *smem) {
reinterpret_cast<uint4 *>(gmem)[0] = reinterpret_cast<uint4 *>(smem)[0];
}
};
reinterpret_cast<uint4*> 是常见向量化访问写法,前提是对齐满足。
SMEM → RF
Q/K 与 V 的 RF 布局不同,因此拆成两类 helper:
copy_warp_fragment_SM2RF():普通ldmatrix_x4copy_warp_fragment_transposed_SM2RF():ldmatrix_x4_transpose
RF → SMEM
O 用标准 4B stores 从 RF 写回 SMEM,再由 SMEM 向量化写回 GMEM。
GMEM 地址计算
forward_kernel.cuh
const index_t sample_head_offset = sample * args.batch_stride + head * args.head_stride;
const index_t QO_gmem_block_offset = sample_head_offset + q_seq_block * CFG.B_r * gmem_seq_stride;
const index_t KV_gmem_block_offset = sample_head_offset;
Tensor Abstraction Layer
MatrixLDST 封装 GMEM↔SMEM↔RF 全流程,统一不同 tensor 的地址与 layout 差异。
tensor.cuh
template <TensorLDSTConfig ldst, typename value_t, typename index_t = int64_t>
struct MatrixLDST {
// ...
__forceinline__ __device__ constexpr void copy_GM2SM();
__forceinline__ __device__ constexpr void copy_SM2RF(int stage = 0, int tile_offset = 0);
__forceinline__ __device__ constexpr void copy_RF2SM();
__forceinline__ __device__ constexpr void copy_SM2GM();
};
Computing Operations
GEMM
基于 mma 的核心形式:
实现按 K -> M -> N 三层展开:
template <typename value_t, const int M_fragments, const int N_fragments,
const int K_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void warp_fragment_mma_f32_accum(...) {
// loop over k,m,n
// mma_m16n8k16_f32_accum(...)
}
Softmax(Thread-Level 视角)
softmax 在 fp32 中执行,包含 element-wise 与 row-wise 两类操作。
关键步骤:
- 初始化行统计:
m = -inf,l = 0 S乘1/sqrt(d_head)(dot-product scaling)- 用 warp shuffle 做行最大值归约
- 重标定
l与O_accum exp(S - m)- 更新行和
l(主循环内先不做跨线程归并) - epilogue 再完成
l的最终跨线程归约并归一化O
__shfl_xor_sync 的两次归约(offset 2 与 1)用于同一 row quartet 聚合。
Type Conversion
需要两处 fp32 -> 16-bit:
- 每轮把
S_accum转为P(给PV) - epilogue 把
O_accum转回输出 dtype
使用 paired conversion:__float22half2_rn / __float22bfloat162_rn。
Synchronization
不同 tensor 的同步范围不同:
Q/O多为 warp 内依赖,可用__syncwarp()K/V涉及跨 warp 协作,需__syncthreads()cp.async数据可见前必须cp_async_wait- 迭代间必须确保上一轮
ldmatrix全部完成,才允许覆写 SMEM
Kernel 1 组装
Prologue
- 计算
(sample, head, q_seq_block)与各 tensor gmem 偏移 - 构造
Q/K/V/O的MatrixLDST - 启动
Q.copy_GM2SM(),初始化O_accum/m/l cp_async_wait<0>() + __syncwarp()后把Q装入 RF
Mainloop
每轮 j:
K.copy_GM2SM()并前移 gmem 指针cp_async_wait + __syncthreads(Barrier 1)K.copy_SM2RF()S = QK^T- online softmax:
scale -> row_max -> rescale(l/O) -> exp -> row_sum S_accum -> P_b16V.copy_GM2SM()并前移 gmem 指针cp_async_wait + __syncthreads(Barrier 2)V.copy_SM2RF()O_accum += PV
Epilogue
final_softmax_normalization(O_accum, l)O_accum -> O_b16O_b16.copy_RF2SM()__syncwarp()O_b16.copy_SM2GM()
Occupancy 与资源预算
以 SM_86(RTX 3090)为例:
- 每 SM 最多 1536 resident threads;128 threads/CTA 本身不是限制项
- 编译后约
202 regs/thread - 目标是 2 CTAs/SM,因此需把 SMEM 控制在约
48KB/CTA
通过让访问时序不重叠的 tensor 共享 SMEM 切片,可从更高需求降到可接受范围。
最终得到约 8 resident warps/SM。
性能结果
在该 baseline 配置下:
- Kernel 1:约
33.28 TFLOPS - Reference:约
67.29 TFLOPS - 相对性能:约
49%
小结
Chapter3 的关键意义是搭起第一版完整链路:
从数据搬运、GEMM、online softmax 到同步与写回全部打通。尽管只是 baseline,但它定义了后续优化(swizzling、double buffering、instruction fusion、auto-tuning)的结构基础。