Tensor Memory Accelerator (TMA)#
从 Hopper 架构(SM 90)开始,NVIDIA GPU 包含一个专用的硬件单元——Tensor Memory Accelerator——它在全局内存和共享内存之间传输数据,不占用线程执行资源。block 中的每个线程不需要各自发出 load 指令,而是由单个线程将描述符(descriptor)交给 TMA 引擎,硬件异步执行整个传输。其他 127 个线程可以自由地计算、加载其他数据,或者仅仅在屏障处等待。
cuda-oxide 通过 TmaDescriptor、cp_async_bulk_tensor_* 系列函数以及 ManagedBarrier 类型状态 API 暴露 TMA。本章介绍设置方法、拷贝模式,以及 TMA 如何与屏障集成以构建高效的加载/计算流水线。
参见
CUDA Programming Guide — Asynchronous Data Copies using TMA 获取硬件单元、swizzle 模式以及支持的张量维度的完整描述。
TMA 解决的问题#
在共享内存章节中,block 中的每个线程都参与从全局内存加载一个分块:
TILE_A[ty * TILE + tx] = a[row * k + tile_offset + tx];
这能工作,但意味着 block 中所有 128(或 256,或 1024)个线程都花费时间计算地址和发出 load 指令。对于大分块,加载阶段可能主导 kernel 的运行时间。
TMA 用单条硬件指令替代了每线程的加载循环。一个线程说"将这个 2D 区域从全局内存拷贝到共享内存",TMA 引擎处理剩下的事情——包括地址计算、步长处理以及无 bank 冲突的 swizzle 写入。
TMA 异步批量拷贝流水线。Host 构建一个 TmaDescriptor 编码张量布局。在 device 上,一个线程发出拷贝指令,所有线程在 mbarrier 上等待。TMA 引擎异步执行传输,完成后向屏障发出信号,线程继续从共享内存中计算。#
TmaDescriptor —— 张量映射#
TmaDescriptor 是一个 128 字节的不透明结构体,编码了 TMA 引擎需要知道的关于张量的一切:基地址、维度、元素类型、步长和 swizzle 模式。它在 host 端使用 CUDA 驱动 API 构建:
use cuda_device::tma::TmaDescriptor;
// 在 host 端(简化——实际调用有许多参数)
unsafe {
cuTensorMapEncodeTiled(
&mut desc as *mut TmaDescriptor as *mut _,
CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
2, // 2D 张量
global_ptr, // 全局内存中的基地址
&dims, // [rows, cols]
&strides, // 每个维度的字节步长
&box_dims, // 每次拷贝的分块大小
&elem_strides, // 盒内元素步长
CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_128B, // 无 bank 冲突 swizzle
CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
);
}
然后描述符作为普通参数传递给 kernel。在 device 端,它被视为 *const TmaDescriptor——不透明、只读,并且在其描述的分配生命周期内始终有效。
小技巧
Swizzle 模式是性能最关键的参数。SWIZZLE_128B 重新排列共享内存中的字节布局,使得来自不同 warp 的 128 字节访问命中不同的 bank。如果不使用 swizzle,朴素的 2D 分块加载通常会产生 bank 冲突。
全局到共享的拷贝(G2S)#
cp_async_bulk_tensor_*_g2s 函数发起一个从全局内存到共享内存目标的 TMA 拷贝。它们可用于 1D 到 5D 张量:
use cuda_device::tma::{cp_async_bulk_tensor_2d_g2s, TmaDescriptor};
use cuda_device::barrier::Barrier;
unsafe fn load_tile(
smem_dst: *mut u8,
desc: *const TmaDescriptor,
tile_x: i32,
tile_y: i32,
bar: *mut Barrier,
) {
cp_async_bulk_tensor_2d_g2s(smem_dst, desc, tile_x, tile_y, bar);
}
坐标(tile_x、tile_y)指定要拷贝哪个分块,单位是描述符中编码的盒维度。屏障指针告诉 TMA 引擎在何处发出完成信号。
只有一个线程发出拷贝#
TMA 拷贝不是跨 block 的集合操作。恰好一个线程应当调用 cp_async_bulk_tensor_*_g2s。所有其他线程通过屏障机制参与:
if thread::threadIdx_x() == 0 {
let token = bar.arrive_expect_tx(tile_bytes as u32);
unsafe {
cp_async_bulk_tensor_2d_g2s(dst, desc, x, y, bar.as_ptr() as *mut _);
}
bar.wait(token);
} else {
let token = bar.arrive();
bar.wait(token);
}
arrive_expect_tx(bytes) 告诉屏障,除了线程到达之外,还预期有 bytes 字节的事务完成。当所有线程都已到达并且 TMA 引擎已交付所有预期字节时,屏障被触发。
mbarrier 之舞#
TMA 完成跟踪依赖于 ManagedBarrier(或原始的 mbarrier_* 函数)。类型状态 API 在类型层面强制执行生命周期:
use cuda_device::barrier::{Barrier, ManagedBarrier, TmaBarrierHandle, Uninit, Ready};
use cuda_device::SharedArray;
#[kernel]
pub fn tma_load_kernel(desc: *const TmaDescriptor) {
static mut BAR: SharedArray<Barrier, 1, 128> = SharedArray::UNINIT;
let bar: TmaBarrierHandle<Ready> = unsafe {
TmaBarrierHandle::<Uninit>::from_static(BAR.as_mut_ptr())
.init_by(block_size, 0) // 线程 0 初始化,包含 fence + sync
};
// 发出 TMA 拷贝
if thread::threadIdx_x() == 0 {
let token = bar.arrive_expect_tx(TILE_BYTES as u32);
unsafe {
cp_async_bulk_tensor_2d_g2s(
smem_ptr, desc, tile_x, tile_y, bar.as_ptr() as *mut _
);
}
bar.wait(token);
} else {
let token = bar.arrive();
bar.wait(token);
}
// 共享内存现在包含分块——可以安全读取
// ...
unsafe { bar.inval(); }
}
生命周期为:Uninit → init() → Ready → arrive() / wait() → inval() → Invalidated。类型系统防止在未初始化的屏障上调用 wait,或者在已失效的屏障上调用 arrive。
关键屏障规则#
规则 |
原因 |
|---|---|
|
TMA 引擎需要在共享内存中看到已初始化的屏障 |
|
屏障必须在引擎开始写入之前知道预期的字节数 |
所有线程都必须到达(或被计入) |
屏障在 |
完成后调用 |
释放屏障硬件资源 |
共享到全局的拷贝(S2G)#
TMA 也支持反向传输。该模式使用提交组(commit group)而非屏障:
use cuda_device::tma::{
cp_async_bulk_tensor_2d_s2g,
cp_async_bulk_commit_group,
cp_async_bulk_wait_group,
};
unsafe {
cp_async_bulk_tensor_2d_s2g(smem_src, desc, tile_x, tile_y);
cp_async_bulk_commit_group();
cp_async_bulk_wait_group(0); // 等待所有未完成的组
}
S2G 拷贝不使用屏障。相反,commit_group 将未完成的拷贝打包成一个组,wait_group(n) 阻塞直到最多有 n 个组仍在传输中。wait_group(0) 等待所有组完成。
对齐要求#
共享内存中的 TMA 目标必须是128 字节对齐的。cuda-oxide 通过 SharedArray 和 DynamicSharedArray 上的 ALIGN 参数提供此功能:
static mut TILE: SharedArray<f32, 1024, 128> = SharedArray::UNINIT;
// 或者使用动态共享内存:
let dst: *mut u8 = DynamicSharedArray::<u8, 128>::get();
如果目标不是 128 字节对齐的,TMA 引擎将静默地产生不正确的结果或触发故障。没有运行时检查——对齐必须在构造时保证正确。
多播(Multicast)变体#
在启用集群的 Hopper+ 上,TMA 可以将单次全局加载多播到多个 CTA 的共享内存中,同时完成:
use cuda_device::tma::cp_async_bulk_tensor_2d_g2s_multicast;
unsafe {
cp_async_bulk_tensor_2d_g2s_multicast(
dst, desc, x, y, bar_ptr,
cta_mask, // u16 位掩码,指示集群中目标 CTA
);
}
这对于 GEMM 风格的内核特别有用,其中 A 或 B 的同一个分块被多个线程 block 需要。不是每个 block 加载自己的副本,而是一次 TMA 操作服务于所有 block。
参见
集群编程了解完整的集群模型和分布式共享内存,它与 TMA 多播自然搭配。
TMA vs. 手动加载#
属性 |
手动线程加载 |
TMA 批量拷贝 |
|---|---|---|
所需线程数 |
block 中所有线程 |
1 个线程发出;硬件执行 |
地址计算 |
每线程(索引数学) |
编码在描述符中 |
Bank 冲突的 swizzle |
手动填充 |
描述符级 swizzle 模式 |
完成跟踪 |
|
带 TX 跟踪的 |
与计算的重叠 |
不可能(线程正在加载) |
加载下一个分块的同时计算当前分块 |
最低架构 |
任意 CUDA GPU |
Hopper(SM 90+) |
关键优势不是原始带宽——TMA 和手动加载可以达到相似的峰值吞吐量。优势在于线程利用率:当 TMA 加载下一个分块时,所有线程可以计算当前分块。这使得现代 GEMM 实现所依赖的多阶段软件流水线模式成为可能。