安全模型#

一个 GPU 内核会同时运行数千个线程,这些线程在同一时刻看到相同的内存。在 CPU 上,Rust 通过所有权和借用机制来防止数据竞争 —— 一个可变引用、不允许别名、在编译时强制生效。而在 GPU 上,每个 SM 上有 2048 个线程,它们全都从同一个函数启动,全都指向同一个输出缓冲区。borrow checker 并非为此而设计。

cuda-oxide 以分层的方式解决这个问题。常见情况 —— 一个线程写一个元素 —— 在构造上就是安全的,无需 unsafe。不常见的情况 —— shared memory、warp shuffle、硬件 intrinsic —— 需要 unsafe 并附带文档化的契约。而前沿情况 —— TMA、Tensor Core、Cluster 级通信 —— 则是完全手动的,与其所控制硬件的复杂度相匹配。

本章将解释这一模型,逐层讲解,并准确告诉你何时需要 unsafe 以及为什么。


三层架构#

cuda-oxide 根据编译器能够验证的程度,将内核安全性划分为三个层级:

层级

描述

是否需要 unsafe

Tier 1

构造上安全 —— 类型系统防止误用

Tier 2

显式 unsafe 并附带明确的安全契约

是,限域

Tier 3

原始硬件 intrinsic —— 完全由用户负责

是,普遍存在

大多数应用内核完全处于 Tier 1,或介于 Tier 1 和 Tier 2 之间。Tier 3 是为在 CUTLASS 或 Triton IR 层面进行构建的性能工程师准备的。如果你在写一个 vecadd、一个 GEMM 或一个规约,你很少会离开 Tier 2。


Tier 1:默认安全#

核心思想:DisjointSlice<T> + ThreadIndex#

主要的安全抽象是一对类型,它们共同保证在调用处无需 unsafe 即可进行无竞争的并行写入:

  • ThreadIndex<'kernel, IndexSpace> —— 一个围绕 usize 的不透明凭证,没有公开的构造函数。唯一获取它的方式是通过可信函数(index_1dindex_2d::<S>),这些函数从硬件内置变量threadIdxblockIdxblockDim)——运行时在启动时填充的只读特殊寄存器 —— 推导出值。ThreadIndex!Send + !Sync + !Copy + !Clone,并且 'kernel 生命周期将其绑定到内核体内的栈局部作用域,因此它无法通过 shared memory 在线程间传递,也无法比内核存活更久。

  • DisjointSlice<T, IndexSpace> —— 一个类似 slice 的类型,其 get_mut() 方法只接受 IndexSpace 与其自身匹配的 ThreadIndex。返回 Option<&mut T> —— 对于越界索引返回 None

将它们组合在一起,你就可以写出零 unsafe 的内核:

use cuda_device::{kernel, DisjointSlice};

#[kernel]
pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
    if let Some((c_elem, idx)) = c.get_mut_indexed() {
        let i = idx.get();
        *c_elem = a[i] + b[i];
    }
}

get_mut_indexed 是一体化的调用形式:它在一次调用中生成每个线程的凭证并将其解析为 &mut T。当你需要索引来对多个 slice 进行并行运算时,也可以使用显式的两步形式:

#[kernel]
pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
    let idx = thread::index_1d();
    if let Some(c_elem) = c.get_mut(idx) {
        let i = idx.get();
        *c_elem = a[i] + b[i];
    }
}

安全性源自四个事实:

  1. index_1d() 为每个线程生成唯一的值(硬件保证:threadIdx.x < blockDim.x,因此线性索引 blockIdx.x * blockDim.x + threadIdx.x 在整个 grid 中是唯一的)。

  2. get_mut() 进行边界检查 —— 超出范围的线程得到 None

  3. IndexSpace 参数将每个凭证与其布局绑定:DisjointSlice<T, Index2D<128>> 不会接受 ThreadIndex<'_, Index2D<256>> —— 混合跨度是编译错误。

  4. 凭证是 !Send + !Sync + !Copy + !Clone 且作用域为 'kernel,因此线程无法通过 shared memory 洗白其他线程的索引。

borrow checker 看到每个线程一个 &mut T。硬件保证索引不相交。类型系统将两者绑定在一起。

可信索引函数#

ThreadIndex 只有在其创建函数可信时才可信。以下是 cuda-oxide 提供的构造函数:

函数

公式

返回类型

备注

index_1d()

blockIdx.x * blockDim.x + threadIdx.x

ThreadIndex<'kernel, Index1D>

每个线程无条件唯一

index_2d::<S>()

row * S + col

Option<ThreadIndex<'kernel, Index2D<S>>>

常量跨度;混合跨度是编译错误

unsafe index_2d_runtime(s)

row * s + col

Option<ThreadIndex<'kernel, Runtime2DIndex>>

调用者断言每个线程使用了相同的 s

index_2d_row()

blockIdx.y * blockDim.y + threadIdx.y

usize

分量访问器,不是凭证构造函数

index_2d_col()

blockIdx.x * blockDim.x + threadIdx.x

usize

分量访问器,不是凭证构造函数

index_2d_row()index_2d_col() 返回普通的 usize —— 它们为你提供了用于算术运算的分量,但不能用来索引 DisjointSlice。只有线性化的结果才能获得 ThreadIndex

index_2d 如何实现类型安全#

index_2d::<S>() 通过 const 泛型参数化行跨度。凭证以 ThreadIndex<'kernel, Index2D<S>> 类型返回,而 DisjointSlice<T, Index2D<S>> 只接受这个确切的 S。两个线程无法以不同的跨度生成凭证并将其输入同一个 slice —— 类型系统会拒绝:

#[kernel]
pub fn ok(mut out: DisjointSlice<u32, Index2D<128>>) {
    if let Some(idx) = thread::index_2d::<128>() {  // 匹配
        if let Some(slot) = out.get_mut(idx) { *slot = 1; }
    }
}

#[kernel]
pub fn rejected(mut out: DisjointSlice<u32, Index2D<128>>) {
    if let Some(idx) = thread::index_2d::<256>() { // ⛔ Index2D<256> != Index2D<128>
        if let Some(slot) = out.get_mut(idx) { *slot = 1; }
    }
}

凭证还是 !Send + !Sync + !Copy + !Clone,并且其 'kernel 生命周期是从宏注入的栈局部作用域借用的 —— 因此线程不能将其 ThreadIndex 存放在 shared memory 中让相邻线程稍后拾取,且凭证不能比内核体存活更久。

真正的运行时跨度:unsafe index_2d_runtime#

有些内核确实在启动时接收跨度(例如只在 host 端知道的矩阵维度)。对于这些情况,有一个对应的 Runtime2DIndex 凭证和一个 unsafe 构造函数:

let idx = unsafe { thread::index_2d_runtime(n)? };

unsafe 就是契约:向同一个 DisjointSlice<T, Runtime2DIndex> 提供 Runtime2DIndex 的每个内核线程都必须使用了相同的 n。类型系统无法证明这一点 —— 在不同运行时跨度下生成的两个 ThreadIndex<'_, Runtime2DIndex> 值具有相同的类型。如果你能在编译时确定跨度,就优先使用 index_2d::<S>()。如果不能,index_2d_runtime 上的 unsafe 关键字就是标记,表示安全义务由你承担。

GEMM 模式#

对于常量跨度的 2D 内核(常见情况 —— 分块 GEMM、stencil 内核、具有固定通道数的图像内核),const 泛型形式是安全的默认选择:

const STRIDE: usize = 1024;     // C 是 M x STRIDE

#[kernel]
pub fn gemm(a: &[f32], b: &[f32], mut c: DisjointSlice<f32, Index2D<STRIDE>>, m: u32) {
    let row = thread::index_2d_row();
    if let Some((c_elem, _)) = c.get_mut_indexed() {
        // col < STRIDE 由 `Some` 保证 —— 无需手动检查
        if row < m as usize {
            // ... 计算点积放入局部累加器 `sum` ...
            *c_elem = alpha * sum + beta * (*c_elem);
        }
    }
}

对于运行时跨度,相同的形式可以工作,但线性化步骤是显式的且需要 unsafe

#[kernel]
pub fn gemm_runtime(a: &[f32], b: &[f32], mut c: DisjointSlice<f32, Runtime2DIndex>, m: u32, n: u32) {
    let n = n as usize;             // 一个绑定,一个跨度值
    let row = thread::index_2d_row();

    // SAFETY: 每个内核线程看到相同的 `n`(内核参数)。
    if let Some(c_idx) = unsafe { thread::index_2d_runtime(n) } {
        if row < m as usize {
            // ... 计算点积 ...
            if let Some(c_elem) = c.get_mut(c_idx) {
                *c_elem = alpha * sum + beta * (*c_elem);
            }
        }
    }
}

index_2d_runtimeif let Some 替代了你在 CUDA C++ 中需要手动写的 col < n 守卫。row < m 检查仍然保留,因为它防止从输入矩阵读取垃圾数据。

什么使内核成为 Tier 1#

当满足以下条件时,内核是完全安全的 —— Tier 1:

  1. 所有可变输出访问都通过 DisjointSlice::get_mut(ThreadIndex) 进行

  2. 所有输入都是共享不可变引用(&[T]

  3. 没有 shared memory,没有原始指针,没有超出线程索引范围的 intrinsic

此层级的示例包括 vecaddhelper_fngenerichost_closure 以及 gemmasync_mlp 示例中的朴素 GEMM 内核。


Tier 2:限域的 unsafe#

并非每个内核都符合“一个线程,一个输出元素”的模式。当线程需要协作 —— 通过快速片上内存共享数据、在 warp 内的 lane 之间通信或执行原子更新 —— 你就需要 unsafe。Tier 2 的关键属性是 unsafe限域的可审计的:每个块都有文档化的安全契约,而内核的其余部分保持安全。

Shared Memory#

Shared memory 是快速的、片上的,并且对一个 block 中的每个线程都是可见的。最后一个属性正是它需要 unsafe 的原因 —— borrow checker 无法推理 256 个线程向同一个 static mut 数组写入:

static mut TILE: SharedArray<f32, 256> = SharedArray::UNINIT;

unsafe { TILE[ty * TILE_SIZE + tx] = value; }

thread::sync_threads();

let neighbor = unsafe { TILE[other_idx] };

契约:确保没有来自并发线程的冲突写入,且没有适当的同步。sync_threads() 屏障是使这项工作成立的工具 —— 它保证所有线程在任意线程读取之前都已完成写入。

API

安全义务

SharedArray<T, N>

通过 static mut 访问。没有同步时不允许冲突写入。

DynamicSharedArray<T>

相同规则,但大小在启动时通过 LaunchConfig::shared_mem_bytes 设置。

参见

Shared Memory 与同步 获取完整讲解:分块、bank conflict、动态分配和双缓冲流水线。

Warp Intrinsic#

Warp 级原语让 warp 内的线程可以在完全不接触内存的情况下交换数据 —— 寄存器到寄存器传输,由硬件协调。它们是 unsafe 的,因为硬件不检查线程收敛:如果你传入一个包含已分歧线程的掩码,你会得到 undefined behavior(通常是静默挂起,这比崩溃更糟)。

API

安全义务

shfl_syncshfl_up_syncshfl_down_syncshfl_xor_sync

源 lane 必须活跃;掩码必须包含调用线程

ballot_syncany_syncall_sync

掩码中的所有线程必须已收敛

activemask

结果仅在调用点有意义

参见

Warp 级编程 了解使用 warp intrinsic 的 shuffle 模式、规约和前缀和。

屏障与生命周期#

ManagedBarrier 的类型状态 API 将屏障生命周期(Uninit -> Ready -> Invalidated)编码在类型系统中,因此你不会在从未初始化的屏障上等待,或使用已被无效化的屏障。init()inval() 转换仍然需要 unsafe,因为它们与硬件交互,但类型状态在编译时防止了最常见的错误。

API

安全义务

mbarrier_init

必须由恰好一个线程调用;屏障必须在 shared memory 中

mbarrier_arrive / mbarrier_wait

屏障必须已初始化;token 必须匹配

ManagedBarrier(类型状态)

init()inval() 需要 unsafe;状态机在编译时强制执行

原子操作#

一旦你拥有有效的原子引用,原子操作本身是安全的。unsafe 的表面在于构造 —— 从原始指针创建 DeviceAtomicU32 需要调用者保证指针有效且正确对齐:

let atom = unsafe { DeviceAtomicU32::new(ptr) };
atom.fetch_add(1, Ordering::Relaxed);  // 安全调用

未经检查的 Slice 访问#

当“一个线程,一个元素”模型不适用时 —— 例如,在 warp 级规约中只有 lane 0 写入结果 —— DisjointSlice::get_unchecked_mut(usize) 提供了一个逃生舱口:

if warp::lane_id() == 0 {
    let warp_idx = gid.get() / 32;
    // SAFETY: 每个 warp 只有 lane 0 写入;warp 索引是唯一的
    unsafe { *out.get_unchecked_mut(warp_idx) = sum; }
}

安全义务与 ThreadIndex 系统自动强制执行的相同:索引在范围内,没有两个线程共享相同的索引。区别在于你自己证明它,而不是让类型系统为你做这件事。


Tier 3:原始硬件#

栈的底层是原始硬件 intrinsic —— 直接与特定 GPU 架构上的特定功能单元对话的 API。每个调用都是 unsafe 的,安全契约复杂且依赖于架构,文档更多地存在于 PTX ISA 手册而非 Rust 文档注释中。

特性

关键 API

架构

TMA(Tensor Memory Accelerator)

tma_load_2dtma_store_2dTmaDescriptor

sm_90+ (Hopper)

tcgen05(Tensor Core Gen 5)

tcgen05_mmatcgen05_commitTensorMemoryHandle

sm_120 (Blackwell)

WGMMA(Warpgroup MMA)

wgmma_mma_asyncwgmma_commit_groupwgmma_wait_group

sm_90+ (Hopper)

Cluster

cluster_rankmap_shared_rankcluster_barrier_arrive

sm_90+ (Hopper)

CLC(Cluster Launch Control)

clc_prefetchclc_query_channel

sm_120 (Blackwell)

TMEM(Tensor Memory)

TmemGuard(类型状态)、tmem_alloctmem_dealloc

sm_120 (Blackwell)

如果你在编写应用级内核,你不应该需要 Tier 3 API。它们是为那些构建下一个 CUTLASS 的人准备的 —— 对于这些人,cuda-oxide 提供了与 CUDA C++ 中内联 PTX 相同的硬件访问能力,同时 Rust 的类型系统可作为(但不强制执行的)护栏。

参见

Tensor Memory Accelerator矩阵乘法加速器、 以及 Cluster 编程 了解 Tier 3 特性的详细内容。


Borrow Checker 为你带来了什么#

cuda-oxide 不是 DSL 或宏系统 —— 它在你的内核代码上运行真正的 rustc 前端。这意味着 Rust 在 CPU 上提供的每项安全保证在 GPU 上也同样强制执行:

保证

工作原理

所有权与借用

生命周期错误、use-after-free 和混叠违规在编译时捕获

安全并行写入

DisjointSlice<T> + ThreadIndex —— 类型级别的证明,证明写入不会竞争

显式 unsafe 限域

原始指针访问需要 unsafe,使义务可见且可审计

收敛属性强制执行

同步原语(barrier、fence、shuffle)在 IR 中标记为 convergent,阻止优化器在控制流中移动或复制它们

前三项是标准 Rust。第四项是 GPU 特有的:CUDA 的 bar.sync、fence 和 warp shuffle 指令不能被编译器复制或重新排序。cuda-oxide 在 IR 中将它们标记为 convergent,以便 LLVM 的优化 pass 不会动它们。


难题#

Rust 的 borrow checker 是为单线程所有权设计的,并以 Send/Sync 用于 CPU 并发。SIMT 执行引入了 borrow checker 从未被教导去推理的模式。以下是 cuda-oxide 目前强制执行的诚实的统计 —— 以及这些问题是可解决的原因。

线程分歧控制流#

Rustc 的 JumpThreading MIR 优化会将函数调用复制到 if 语句的两个分支中 —— 在 CPU 上这是一个完全正确的优化,但它会破坏 GPU 屏障语义,因为 block 中的所有线程必须在同一个 bar.sync 指令处收敛。cuda-oxide 目前对 device 代码禁用 JumpThreading(-Z mir-enable-passes=-JumpThreading)。一个正确的解决方案是教会编译器理解收敛要求,使其可以围绕这些要求进行优化,而不是完全禁用该 pass。

Shared Memory 访问模式#

Borrow checker 无法推理线程 0 写 smem[0] 而线程 1 写 smem[1] 是否安全 —— 它看到 &mut smem 并拒绝它。DisjointSlice 解决了唯一索引写入模式,但不解决协作模式,如规约、扫描或生产者/消费者流水线,其中多个线程有意地在阶段之间同步后访问重叠区域。

Warp 级收敛#

shfl_syncballot_sync 等操作要求参与掩码中命名的所有线程在调用点确实已收敛。类型系统目前无法强制执行这一点。如果线程已经分歧而你又传入了一个完整掩码,你会得到静默挂起 —— 最糟糕的一类 bug,因为没有崩溃也没有错误消息,只有一个永远无法完成的内核。

内存空间感知#

GPU 内存具有不同的地址空间 —— global、shared、local、TMEM。一个指向 shared memory 的 &mut 对 block 中的每个线程可见;一个指向 local memory 的 &mut 仅对一个线程私有。Borrow checker 将它们视为相同。这是保守的(它拒绝了一些安全程序),但绝不是不健全的(它不会接受不安全的程序)。尽管如此,一个内存空间感知的 borrow checker 可以接受更多无需 unsafe 的程序。

为什么这些问题是可解决的#

构建块已经存在于 Rust 的类型系统中。它们需要被扩展,而不是被重新发明:

想法

它解决的问题

执行资源感知的类型

函数标注其执行层级(grid / block / warp / thread)。在分歧分支中的 barrier 调用成为编译错误。

内存视图

泛化的并行访问模式 —— 类似 DisjointSlice,但覆盖分块、条纹、转置和组合布局。类型检查的无竞争大规模写入。

针对同步的扩展 borrow checking

静态强制 barrier 不能被遗忘、放置在分歧控制流中或被优化器复制。类型系统中的收敛。

所有这些都是编译时分析。生成的 PTX 与你手写的内容完全相同 —— 安全网在代码生成时消失。零运行时开销。

cuda-oxide 处于有利位置,可以逐步实现这一点。真正的 rustc borrow checker 已经在 device 代码上运行。IR 基础设施(pliron dialect)支持 GPU 感知的分析 pass。从 MIR 到 PTX 的完整编译流水线在我们的控制之下。而且每项新的安全检查都是叠加的 —— 现有内核继续编译,而新内核获得更强的保证。


编写安全内核:速查表#

默认路径#

对于大多数内核,从这里开始:

#[kernel]
pub fn my_kernel(input: &[f32], mut output: DisjointSlice<f32>) {
    if let Some((out, idx)) = output.get_mut_indexed() {
        *out = transform(input[idx.get()]);
    }
}

对于常量跨度的 2D,参数化 slice 并请求常量索引:

#[kernel]
pub fn tile_kernel(mut output: DisjointSlice<f32, Index2D<1024>>) {
    if let Some((out, _idx)) = output.get_mut_indexed() {
        *out = ...;
    }
}

规则:

  • 对所有可变输出使用 DisjointSlice

  • 对所有只读输入使用 &[T]

  • 对于 1D grid,默认使用 get_mut_indexed()。如果你需要索引来对多个 slice 进行算术运算,回退到显式配对:let idx = thread::index_1d(); slice.get_mut(idx)

  • 对于常量跨度的 2D grid,将 slice 参数化为 DisjointSlice<T, Index2D<S>> 并使用 get_mut_indexed()thread::index_2d::<S>()。不匹配的跨度是编译错误。

  • 对于运行时跨度,使用带有 Runtime2DIndex 标记的 slice 的 unsafe { thread::index_2d_runtime(n) }unsafe 是每个线程都使用了相同的 n 的契约。

  • 始终通过 get_mut() / get_mut_indexed() 进行边界检查(两者都返回 Option)。

如果你的内核在编译时没有 unsafe,那么它在构造上就是无竞争的。

何时需要 unsafe#

模式

原因

缓解措施

Shared memory

多个线程访问同一个 static mut

在跨线程读取之前使用 sync_threads() 进行同步

Warp shuffle

线程收敛不被编译器检查

对于完整 warp 操作使用 FULL_MASK;为部分掩码写文档

原子操作

从原始指针构造

包装在辅助函数中;原子操作本身是安全的

非均匀写入

并非每个线程都写入自己的索引

使用 get_unchecked_mut 并附带文档化的唯一性论证

硬件 intrinsic

复杂、架构特定的契约

遵循 PTX ISA 文档;在目标硬件上测试

SAFETY 注释#

对于每个 unsafe 块,文档化为什么不变量成立。不是代码做什么 —— 代码已经说明了 —— 而是为什么这个特定用法是安全的:

// SAFETY: 每个 warp 只有 lane 0 执行此分支。
// Warp 索引 (gid / 32) 在各个 warp 之间是唯一的,因此没有两个
// 线程写入同一个输出元素。
if warp::lane_id() == 0 {
    let warp_idx = gid.get() / 32;
    unsafe { *partial_sums.get_unchecked_mut(warp_idx) = warp_sum; }
}

这不是形式主义。当一个内核在凌晨两点发生数据竞争,而你正盯着一份 compute-sanitizer 日志时,过去的你留下的 safety 注释是通往 bug 的最快路径。

小技巧

如果你无法为某个 unsafe 块写出有说服力的 SAFETY 注释,这是一个信号,表明不变量实际上并未被维护。重构代码直到论证变得显而易见,或者改用安全 API。


总结#

属性

状态

Device 代码上的 borrow checker

已强制(真正的 rustc 前端)

安全 1D 并行写入(DisjointSlice + index_1d

已强制

安全 2D 并行写入 —— 常量跨度

已强制(Index2D<S> 不匹配是编译错误)

安全 2D 并行写入 —— 运行时跨度

调用者通过 unsafe index_2d_runtime 断言

ThreadIndex 不可跨线程传递(smem 洗白)

已强制(!Send + !Sync + !Copy + !Clone + 'kernel

&mut [T] 内核参数

未强制 —— 将任何 &mut 参数视为 unsafe

Shared memory、intrinsic 的显式 unsafe

已强制(Rust 语言规则)

同步原语的 convergent 属性

已强制(IR 级别的 convergent 标记)

Warp 操作的线程收敛

未强制(运行时义务)

内存空间感知(shared 与 global)

未强制(未来工作)

&mut [T] 作为内核参数是下一个突出的缺口:允许宏今天接受该类型,但运行时布局(每个线程看到相同的底层指针)意味着通过 &mut data[i] 从两个不同的线程写入是同一种别名,这正是 DisjointSlice 存在所要防止的。在宏彻底拒绝 &mut [T](或将其重写为 DisjointSlice)之前,将任何接受它的内核视为其中每一行都是 unsafe

安全模型旨在让常见情况默认安全,同时为其他所有情况提供显式的逃生舱口。编写你的内核,让类型系统捕获竞争,把 unsafe 留给那些你真的知道编译器不知道的东西的部分。