Tensor Memory Accelerator (TMA)#

从 Hopper 架构(SM 90)开始,NVIDIA GPU 包含一个专用的硬件单元——Tensor Memory Accelerator——它在全局内存和共享内存之间传输数据,不占用线程执行资源。block 中的每个线程不需要各自发出 load 指令,而是由单个线程将描述符(descriptor)交给 TMA 引擎,硬件异步执行整个传输。其他 127 个线程可以自由地计算、加载其他数据,或者仅仅在屏障处等待。

cuda-oxide 通过 TmaDescriptorcp_async_bulk_tensor_* 系列函数以及 ManagedBarrier 类型状态 API 暴露 TMA。本章介绍设置方法、拷贝模式,以及 TMA 如何与屏障集成以构建高效的加载/计算流水线。

参见

CUDA Programming Guide — Asynchronous Data Copies using TMA 获取硬件单元、swizzle 模式以及支持的张量维度的完整描述。


TMA 解决的问题#

共享内存章节中,block 中的每个线程都参与从全局内存加载一个分块:

TILE_A[ty * TILE + tx] = a[row * k + tile_offset + tx];

这能工作,但意味着 block 中所有 128(或 256,或 1024)个线程都花费时间计算地址和发出 load 指令。对于大分块,加载阶段可能主导 kernel 的运行时间。

TMA 用单条硬件指令替代了每线程的加载循环。一个线程说"将这个 2D 区域从全局内存拷贝到共享内存",TMA 引擎处理剩下的事情——包括地址计算、步长处理以及无 bank 冲突的 swizzle 写入。

advanced/images/tma-async-pipeline.svg

TMA 异步批量拷贝流水线。Host 构建一个 TmaDescriptor 编码张量布局。在 device 上,一个线程发出拷贝指令,所有线程在 mbarrier 上等待。TMA 引擎异步执行传输,完成后向屏障发出信号,线程继续从共享内存中计算。#


TmaDescriptor —— 张量映射#

TmaDescriptor 是一个 128 字节的不透明结构体,编码了 TMA 引擎需要知道的关于张量的一切:基地址、维度、元素类型、步长和 swizzle 模式。它在 host 端使用 CUDA 驱动 API 构建:

use cuda_device::tma::TmaDescriptor;

// 在 host 端(简化——实际调用有许多参数)
unsafe {
    cuTensorMapEncodeTiled(
        &mut desc as *mut TmaDescriptor as *mut _,
        CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
        2,                       // 2D 张量
        global_ptr,              // 全局内存中的基地址
        &dims,                   // [rows, cols]
        &strides,                // 每个维度的字节步长
        &box_dims,               // 每次拷贝的分块大小
        &elem_strides,           // 盒内元素步长
        CU_TENSOR_MAP_INTERLEAVE_NONE,
        CU_TENSOR_MAP_SWIZZLE_128B, // 无 bank 冲突 swizzle
        CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    );
}

然后描述符作为普通参数传递给 kernel。在 device 端,它被视为 *const TmaDescriptor——不透明、只读,并且在其描述的分配生命周期内始终有效。

小技巧

Swizzle 模式是性能最关键的参数。SWIZZLE_128B 重新排列共享内存中的字节布局,使得来自不同 warp 的 128 字节访问命中不同的 bank。如果不使用 swizzle,朴素的 2D 分块加载通常会产生 bank 冲突。


全局到共享的拷贝(G2S)#

cp_async_bulk_tensor_*_g2s 函数发起一个从全局内存到共享内存目标的 TMA 拷贝。它们可用于 1D 到 5D 张量:

use cuda_device::tma::{cp_async_bulk_tensor_2d_g2s, TmaDescriptor};
use cuda_device::barrier::Barrier;

unsafe fn load_tile(
    smem_dst: *mut u8,
    desc: *const TmaDescriptor,
    tile_x: i32,
    tile_y: i32,
    bar: *mut Barrier,
) {
    cp_async_bulk_tensor_2d_g2s(smem_dst, desc, tile_x, tile_y, bar);
}

坐标(tile_xtile_y)指定要拷贝哪个分块,单位是描述符中编码的盒维度。屏障指针告诉 TMA 引擎在何处发出完成信号。

只有一个线程发出拷贝#

TMA 拷贝不是跨 block 的集合操作。恰好一个线程应当调用 cp_async_bulk_tensor_*_g2s。所有其他线程通过屏障机制参与:

if thread::threadIdx_x() == 0 {
    let token = bar.arrive_expect_tx(tile_bytes as u32);
    unsafe {
        cp_async_bulk_tensor_2d_g2s(dst, desc, x, y, bar.as_ptr() as *mut _);
    }
    bar.wait(token);
} else {
    let token = bar.arrive();
    bar.wait(token);
}

arrive_expect_tx(bytes) 告诉屏障,除了线程到达之外,还预期有 bytes 字节的事务完成。当所有线程都已到达并且 TMA 引擎已交付所有预期字节时,屏障被触发。


mbarrier 之舞#

TMA 完成跟踪依赖于 ManagedBarrier(或原始的 mbarrier_* 函数)。类型状态 API 在类型层面强制执行生命周期:

use cuda_device::barrier::{Barrier, ManagedBarrier, TmaBarrierHandle, Uninit, Ready};
use cuda_device::SharedArray;

#[kernel]
pub fn tma_load_kernel(desc: *const TmaDescriptor) {
    static mut BAR: SharedArray<Barrier, 1, 128> = SharedArray::UNINIT;

    let bar: TmaBarrierHandle<Ready> = unsafe {
        TmaBarrierHandle::<Uninit>::from_static(BAR.as_mut_ptr())
            .init_by(block_size, 0) // 线程 0 初始化,包含 fence + sync
    };

    // 发出 TMA 拷贝
    if thread::threadIdx_x() == 0 {
        let token = bar.arrive_expect_tx(TILE_BYTES as u32);
        unsafe {
            cp_async_bulk_tensor_2d_g2s(
                smem_ptr, desc, tile_x, tile_y, bar.as_ptr() as *mut _
            );
        }
        bar.wait(token);
    } else {
        let token = bar.arrive();
        bar.wait(token);
    }

    // 共享内存现在包含分块——可以安全读取
    // ...

    unsafe { bar.inval(); }
}

生命周期为:Uninitinit()Readyarrive() / wait()inval()Invalidated。类型系统防止在未初始化的屏障上调用 wait,或者在已失效的屏障上调用 arrive

关键屏障规则#

规则

原因

init_by 包含 fence_proxy_async_shared_cta()

TMA 引擎需要在共享内存中看到已初始化的屏障

arrive_expect_tx 在拷贝之前调用,而非之后

屏障必须在引擎开始写入之前知道预期的字节数

所有线程都必须到达(或被计入)

屏障在 expected_count 次到达 + TX 字节完成后触发

完成后调用 inval()(如果不再重用)

释放屏障硬件资源


共享到全局的拷贝(S2G)#

TMA 也支持反向传输。该模式使用提交组(commit group)而非屏障:

use cuda_device::tma::{
    cp_async_bulk_tensor_2d_s2g,
    cp_async_bulk_commit_group,
    cp_async_bulk_wait_group,
};

unsafe {
    cp_async_bulk_tensor_2d_s2g(smem_src, desc, tile_x, tile_y);
    cp_async_bulk_commit_group();
    cp_async_bulk_wait_group(0); // 等待所有未完成的组
}

S2G 拷贝不使用屏障。相反,commit_group 将未完成的拷贝打包成一个组,wait_group(n) 阻塞直到最多有 n 个组仍在传输中。wait_group(0) 等待所有组完成。


对齐要求#

共享内存中的 TMA 目标必须是128 字节对齐的。cuda-oxide 通过 SharedArrayDynamicSharedArray 上的 ALIGN 参数提供此功能:

static mut TILE: SharedArray<f32, 1024, 128> = SharedArray::UNINIT;

// 或者使用动态共享内存:
let dst: *mut u8 = DynamicSharedArray::<u8, 128>::get();

如果目标不是 128 字节对齐的,TMA 引擎将静默地产生不正确的结果或触发故障。没有运行时检查——对齐必须在构造时保证正确。


多播(Multicast)变体#

在启用集群的 Hopper+ 上,TMA 可以将单次全局加载多播到多个 CTA 的共享内存中,同时完成:

use cuda_device::tma::cp_async_bulk_tensor_2d_g2s_multicast;

unsafe {
    cp_async_bulk_tensor_2d_g2s_multicast(
        dst, desc, x, y, bar_ptr,
        cta_mask,  // u16 位掩码,指示集群中目标 CTA
    );
}

这对于 GEMM 风格的内核特别有用,其中 A 或 B 的同一个分块被多个线程 block 需要。不是每个 block 加载自己的副本,而是一次 TMA 操作服务于所有 block。

参见

集群编程了解完整的集群模型和分布式共享内存,它与 TMA 多播自然搭配。


TMA vs. 手动加载#

属性

手动线程加载

TMA 批量拷贝

所需线程数

block 中所有线程

1 个线程发出;硬件执行

地址计算

每线程(索引数学)

编码在描述符中

Bank 冲突的 swizzle

手动填充

描述符级 swizzle 模式

完成跟踪

sync_threads()

带 TX 跟踪的 mbarrier

与计算的重叠

不可能(线程正在加载)

加载下一个分块的同时计算当前分块

最低架构

任意 CUDA GPU

Hopper(SM 90+)

关键优势不是原始带宽——TMA 和手动加载可以达到相似的峰值吞吐量。优势在于线程利用率:当 TMA 加载下一个分块时,所有线程可以计算当前分块。这使得现代 GEMM 实现所依赖的多阶段软件流水线模式成为可能。

参见