矩阵乘法加速器#
现代 NVIDIA GPU 包含用于矩阵乘积累加(MMA)操作的专用硬件——通常称为 tensor core。这些单元在小型矩阵分块上计算 D = A × B + C,吞吐量远超标准浮点 ALU 所能达到的水平。一块 Hopper H100 提供超过 1000 TFLOPS 的 FP16 MMA,而同一芯片的标量 FP16 吞吐量大约为 200 TFLOPS。如果你的工作负载涉及矩阵乘法——而大多数深度学习、HPC 和信号处理工作负载确实涉及——tensor core 就是性能所在的地方。
cuda-oxide 提供了对两代矩阵加速器的访问:Hopper(SM 90)上的 WGMMA 和 Blackwell(SM 100)上的 tcgen05。本章涵盖两者、它们的编程模型以及它们如何连接到前面章节中的 TMA 和屏障机制。
参见
CUDA Programming Guide — Warpgroup Level Matrix Operations 了解 WGMMA 形状、元素类型和同步要求的硬件规格。
全局概览#
Hopper WGMMA 和 Blackwell tcgen05 的数据路径。两者都通过描述符从共享内存读取操作数。WGMMA 将结果累加到每线程寄存器(warpgroup-collective)。tcgen05 将结果累加到专用的 Tensor Memory (TMEM) 中,由单个线程发出。#
跨代演进遵循一个清晰的趋势:操作数更靠近计算单元,发出范围更广,程序员为更大的分块编写更少的指令。cuda-oxide 用按代划分的 API 而非一刀切的抽象来跟踪这一演进。
WGMMA — Hopper (SM 90)#
WGMMA(WarpGroup Matrix Multiply-Accumulate)是一种 warpgroup-collective 操作:4 个 warp(128 个线程)协作计算一个矩阵分块。操作数 A 和 B 通过 SMEM 描述符从共享内存读取,结果累加进每线程寄存器。
编程模型#
TMA 加载 A 和 B 的分块到共享内存(参见 Tensor Memory Accelerator)。
SMEM 描述符编码每个操作数的基地址、步长和 swizzle 模式。
WGMMA 指令消费描述符并产生累加器更新。该指令是异步的——它提交到一个屏障。
屏障等待确保在读取累加器之前 MMA 已完成。
支持的形状#
WGMMA 始终具有 M=64(行),N 和 K 取决于元素类型:
元素类型 |
K |
N 选项 |
|---|---|---|
f16, bf16 |
16 |
64, 128, 256 |
tf32 |
8 |
64, 128, 256 |
每条指令计算一个 64×N×K 的分块。对于更大的 K 维度,你在循环中发出多条 WGMMA 指令,累加到相同的寄存器分块中。
cuda-oxide API 概览#
cuda-oxide 通过直接映射到硬件指令的低级内联函数暴露 WGMMA。典型的使用模式:
use cuda_device::wgmma::{
make_smem_desc, wgmma_fence, wgmma_commit_group, wgmma_wait_group,
wgmma_mma_m64n64k16_f32_f16,
};
// TMA 已将 A 和 B 的分块加载到共享内存之后...
// 为已加载的分块构建 SMEM 描述符
let a_desc = unsafe { make_smem_desc(tile_a_ptr as *const u8) };
let b_desc = unsafe { make_smem_desc(tile_b_ptr as *const u8) };
// 累加器(4 个 warp × 每行 8 个 float = 64×64 分块)
let mut acc = [[0.0f32; 8]; 4];
// Fence + 发出 WGMMA — warpgroup 中的所有 128 个线程都参与
unsafe {
wgmma_fence();
wgmma_mma_m64n64k16_f32_f16(&mut acc, a_desc, b_desc);
wgmma_commit_group();
wgmma_wait_group::<0>(); // 等待所有未完成的组
}
// acc 中的累加器现在是有效的 — 存储、转换或传递给下一阶段
小技巧
WGMMA 通常与多阶段流水线配对使用:当 tensor core 处理第 k 个分块时,TMA 将第 k+1 个分块加载到第二个共享内存缓冲区中。TMA 和 MMA 的屏障是分开的,使得数据搬运和计算可以完全重叠。
tcgen05 — Blackwell (SM 100)#
Blackwell 引入了一种根本不同的矩阵加速器:tcgen05。关键创新包括:
单线程发出。 不是 warpgroup-collective 指令,而是由一个线程发出 MMA。硬件在内部分发工作。
Tensor Memory (TMEM)。 一个专用的片上累加器内存,独立于寄存器文件。TMEM 比寄存器更大,并且具有不同的访问特性。
更大的分块。 tcgen05 支持最大 256×256 的形状,vs WGMMA 的 64×256。
TMEM — Tensor Memory#
TMEM 是内存层次结构中的新层级,位于寄存器文件旁边,但专用于矩阵累加。它必须显式分配和释放:
use cuda_device::tcgen05::{TmemGuard, TmemUninit, TmemReady};
use cuda_device::SharedArray;
static mut TMEM_ADDR: SharedArray<u32, 1> = SharedArray::UNINIT;
// 分配 TMEM(warp-collective)
let tmem: TmemGuard<TmemReady, 128> = unsafe {
TmemGuard::<TmemUninit, 128>::from_static(TMEM_ADDR.as_mut_ptr())
.alloc()
};
// ... 使用 tmem 进行 MMA ...
// 释放(warp-collective,kernel 退出前必须执行)
unsafe { tmem.dealloc(); }
TmemGuard 使用类型状态:TmemUninit → alloc() → TmemReady → dealloc() → TmemDeallocated。类型系统防止在分配之前使用 TMEM 或忘记释放——后者会泄漏硬件资源并导致故障。
N_COLS 常量参数决定 TMEM 分块宽度。常见配置:
|
每 warp 的 TMEM |
用途 |
|---|---|---|
64 |
最小分配 |
窄分块 |
128 |
默认 |
标准 GEMM 分块 |
256 |
最大 |
宽分块,更高吞吐量 |
指令和 SMEM 描述符#
tcgen05 对每条 MMA 指令使用两个描述符:
指令描述符 — 编码 MMA 配置:
use cuda_device::tcgen05::{
Tcgen05InstructionDescriptor,
Tcgen05ElementType,
Tcgen05MmaShape,
};
let idesc = Tcgen05InstructionDescriptor::builder()
.shape(Tcgen05MmaShape::M128_N128)
.a_type(Tcgen05ElementType::F16)
.b_type(Tcgen05ElementType::F16)
.build();
SMEM 描述符 — 指向共享内存中的操作数分块:
use cuda_device::tcgen05::{Tcgen05SmemDescriptor, Tcgen05SwizzleMode};
let a_desc = Tcgen05SmemDescriptor::for_k_major(
smem_a_addr,
m, k,
2, // 每元素字节数(f16)
Tcgen05SwizzleMode::Swizzle128B,
);
发出 MMA#
一个线程发出 MMA 指令,然后所有线程在屏障上等待:
use cuda_device::tcgen05::{tcgen05_mma_f16, tcgen05_commit};
if thread::threadIdx_x() == 0 {
unsafe {
tcgen05_mma_f16(
tmem.raw_address(),
a_desc.raw(),
b_desc.raw(),
idesc.raw(),
true, // 启用累加到 D
);
tcgen05_commit(mma_barrier_ptr);
}
}
// 所有线程等待 MMA 完成
mma_barrier.wait(token);
Epilog — 从 TMEM 读取结果#
MMA 循环之后,累加器驻留在 TMEM 中。要将其移动到共享内存(以便后续 TMA 存储到全局内存),你从 TMEM 加载到寄存器,然后使用 stmatrix 写入共享内存:
use cuda_device::tcgen05::{
tcgen05_ld_16x256b_pure,
tcgen05_load_wait,
stmatrix_m8n8_x4,
TmemF32x4,
};
unsafe {
// 从 TMEM 加载一个 16×256 位的切片到寄存器
let regs: TmemF32x4 = tcgen05_ld_16x256b_pure(tmem.raw_address());
tcgen05_load_wait();
// 从寄存器存储到共享内存(warp-collective)
stmatrix_m8n8_x4(smem_ptr, regs[0], regs[1], regs[2], regs[3]);
}
stmatrix 是一个 warp-collective 操作——所有 32 个线程参与,每个线程贡献其寄存器切片以形成共享内存分块。
CTA-group-2 (CG2)#
Blackwell 还支持 CG2 模式,其中来自一个集群的两个 CTA 发出的 MMA 指令作为一对进行协调。这将有效分块宽度加倍,需要使用 _cg2 API 变体:
use cuda_device::tcgen05::{tcgen05_alloc_cg2, tcgen05_mma_f16_cg2};
CG2 和 CG1(标准)不能在同一个 kernel 中混用。
小技巧
tcgen05 仅在 Blackwell 数据中心 GPU(SM 100a)上可用。消费级 Blackwell(SM 120)使用较旧的 mma.sync 指令集。在接触 tcgen05 API 之前,请检查目标架构。
选择合适的加速器#
特性 |
WGMMA (Hopper) |
tcgen05 (Blackwell DC) |
mma.sync (Ampere/消费级) |
|---|---|---|---|
发出模型 |
Warpgroup(128 线程) |
单线程 |
Warp(32 线程) |
累加器 |
寄存器 |
TMEM(专用) |
寄存器 |
最大分块(M×N) |
64×256 |
256×256 |
16×8 |
异步执行 |
是(commit+barrier) |
是(commit+barrier) |
同步 |
TMA 集成 |
原生 |
原生 + 多播 CG2 |
手动加载 |
集群支持 |
是 |
是 + CG2 |
否 |
最低 SM |
90(Hopper) |
100a(Blackwell DC) |
80(Ampere) |
对于大多数用户来说,选择由目标 GPU 决定。如果你正在编写一个面向多种架构的库,你需要编译时特性开关或运行时架构检测来选择正确的代码路径。
GEMM 演进之路#
这些加速器并非孤立存在。一个高性能的 GEMM kernel 结合了本章及之前各章介绍的每一项技术:
TMA 将分块从全局内存加载到共享内存(无线程加载)
屏障 跟踪每个阶段的 TMA 完成情况
MMA 指令(WGMMA 或 tcgen05)从共享内存消费分块
多阶段流水线 使第 k+1 分块的加载与第 k 分块的 MMA 重叠
Warp 专业化 将一些 warp 专门用于加载,另一些专门用于 MMA
Epilog(tcgen05)从 TMEM 加载、转换精度、通过 TMA 存储
这就是 cuBLAS 和 CUTLASS 内部使用的 kernel 结构。使用 cuda-oxide,你可以在 Rust 中构建同样的结构——使用 SharedArray 管理分块,使用 ManagedBarrier 进行同步,使用 MMA API 进行计算。
参见
共享内存与同步 — MMA 操作数的分块管理
Tensor Memory Accelerator — 向 tensor core 输送分块
集群编程 — 多 CTA MMA 的 CG2 模式和多播