安全模型#
一个 GPU 内核会同时运行数千个线程,这些线程在同一时刻看到相同的内存。在 CPU 上,Rust 通过所有权和借用机制来防止数据竞争 —— 一个可变引用、不允许别名、在编译时强制生效。而在 GPU 上,每个 SM 上有 2048 个线程,它们全都从同一个函数启动,全都指向同一个输出缓冲区。borrow checker 并非为此而设计。
cuda-oxide 以分层的方式解决这个问题。常见情况 —— 一个线程写一个元素 —— 在构造上就是安全的,无需 unsafe。不常见的情况 —— shared memory、warp shuffle、硬件 intrinsic —— 需要 unsafe 并附带文档化的契约。而前沿情况 —— TMA、Tensor Core、Cluster 级通信 —— 则是完全手动的,与其所控制硬件的复杂度相匹配。
本章将解释这一模型,逐层讲解,并准确告诉你何时需要 unsafe 以及为什么。
三层架构#
cuda-oxide 根据编译器能够验证的程度,将内核安全性划分为三个层级:
层级 |
描述 |
是否需要 |
|---|---|---|
Tier 1 |
构造上安全 —— 类型系统防止误用 |
否 |
Tier 2 |
显式 |
是,限域 |
Tier 3 |
原始硬件 intrinsic —— 完全由用户负责 |
是,普遍存在 |
大多数应用内核完全处于 Tier 1,或介于 Tier 1 和 Tier 2 之间。Tier 3 是为在 CUTLASS 或 Triton IR 层面进行构建的性能工程师准备的。如果你在写一个 vecadd、一个 GEMM 或一个规约,你很少会离开 Tier 2。
Tier 1:默认安全#
核心思想:DisjointSlice<T> + ThreadIndex#
主要的安全抽象是一对类型,它们共同保证在调用处无需 unsafe 即可进行无竞争的并行写入:
ThreadIndex<'kernel, IndexSpace>—— 一个围绕usize的不透明凭证,没有公开的构造函数。唯一获取它的方式是通过可信函数(index_1d、index_2d::<S>),这些函数从硬件内置变量(threadIdx、blockIdx、blockDim)——运行时在启动时填充的只读特殊寄存器 —— 推导出值。ThreadIndex是!Send + !Sync + !Copy + !Clone,并且'kernel生命周期将其绑定到内核体内的栈局部作用域,因此它无法通过 shared memory 在线程间传递,也无法比内核存活更久。DisjointSlice<T, IndexSpace>—— 一个类似 slice 的类型,其get_mut()方法只接受IndexSpace与其自身匹配的ThreadIndex。返回Option<&mut T>—— 对于越界索引返回None。
将它们组合在一起,你就可以写出零 unsafe 的内核:
use cuda_device::{kernel, DisjointSlice};
#[kernel]
pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
if let Some((c_elem, idx)) = c.get_mut_indexed() {
let i = idx.get();
*c_elem = a[i] + b[i];
}
}
get_mut_indexed 是一体化的调用形式:它在一次调用中生成每个线程的凭证并将其解析为 &mut T。当你需要索引来对多个 slice 进行并行运算时,也可以使用显式的两步形式:
#[kernel]
pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
let idx = thread::index_1d();
if let Some(c_elem) = c.get_mut(idx) {
let i = idx.get();
*c_elem = a[i] + b[i];
}
}
安全性源自四个事实:
index_1d()为每个线程生成唯一的值(硬件保证:threadIdx.x < blockDim.x,因此线性索引blockIdx.x * blockDim.x + threadIdx.x在整个 grid 中是唯一的)。get_mut()进行边界检查 —— 超出范围的线程得到None。IndexSpace参数将每个凭证与其布局绑定:DisjointSlice<T, Index2D<128>>不会接受ThreadIndex<'_, Index2D<256>>—— 混合跨度是编译错误。凭证是
!Send + !Sync + !Copy + !Clone且作用域为'kernel,因此线程无法通过 shared memory 洗白其他线程的索引。
borrow checker 看到每个线程一个 &mut T。硬件保证索引不相交。类型系统将两者绑定在一起。
可信索引函数#
ThreadIndex 只有在其创建函数可信时才可信。以下是 cuda-oxide 提供的构造函数:
函数 |
公式 |
返回类型 |
备注 |
|---|---|---|---|
|
|
|
每个线程无条件唯一 |
|
|
|
常量跨度;混合跨度是编译错误 |
|
|
|
调用者断言每个线程使用了相同的 |
|
|
|
分量访问器,不是凭证构造函数 |
|
|
|
分量访问器,不是凭证构造函数 |
index_2d_row() 和 index_2d_col() 返回普通的 usize —— 它们为你提供了用于算术运算的分量,但不能用来索引 DisjointSlice。只有线性化的结果才能获得 ThreadIndex。
index_2d 如何实现类型安全#
index_2d::<S>() 通过 const 泛型参数化行跨度。凭证以 ThreadIndex<'kernel, Index2D<S>> 类型返回,而 DisjointSlice<T, Index2D<S>> 只接受这个确切的 S。两个线程无法以不同的跨度生成凭证并将其输入同一个 slice —— 类型系统会拒绝:
#[kernel]
pub fn ok(mut out: DisjointSlice<u32, Index2D<128>>) {
if let Some(idx) = thread::index_2d::<128>() { // 匹配
if let Some(slot) = out.get_mut(idx) { *slot = 1; }
}
}
#[kernel]
pub fn rejected(mut out: DisjointSlice<u32, Index2D<128>>) {
if let Some(idx) = thread::index_2d::<256>() { // ⛔ Index2D<256> != Index2D<128>
if let Some(slot) = out.get_mut(idx) { *slot = 1; }
}
}
凭证还是 !Send + !Sync + !Copy + !Clone,并且其 'kernel 生命周期是从宏注入的栈局部作用域借用的 —— 因此线程不能将其 ThreadIndex 存放在 shared memory 中让相邻线程稍后拾取,且凭证不能比内核体存活更久。
真正的运行时跨度:unsafe index_2d_runtime#
有些内核确实在启动时接收跨度(例如只在 host 端知道的矩阵维度)。对于这些情况,有一个对应的 Runtime2DIndex 凭证和一个 unsafe 构造函数:
let idx = unsafe { thread::index_2d_runtime(n)? };
unsafe 就是契约:向同一个 DisjointSlice<T, Runtime2DIndex> 提供 Runtime2DIndex 的每个内核线程都必须使用了相同的 n。类型系统无法证明这一点 —— 在不同运行时跨度下生成的两个 ThreadIndex<'_, Runtime2DIndex> 值具有相同的类型。如果你能在编译时确定跨度,就优先使用 index_2d::<S>()。如果不能,index_2d_runtime 上的 unsafe 关键字就是标记,表示安全义务由你承担。
GEMM 模式#
对于常量跨度的 2D 内核(常见情况 —— 分块 GEMM、stencil 内核、具有固定通道数的图像内核),const 泛型形式是安全的默认选择:
const STRIDE: usize = 1024; // C 是 M x STRIDE
#[kernel]
pub fn gemm(a: &[f32], b: &[f32], mut c: DisjointSlice<f32, Index2D<STRIDE>>, m: u32) {
let row = thread::index_2d_row();
if let Some((c_elem, _)) = c.get_mut_indexed() {
// col < STRIDE 由 `Some` 保证 —— 无需手动检查
if row < m as usize {
// ... 计算点积放入局部累加器 `sum` ...
*c_elem = alpha * sum + beta * (*c_elem);
}
}
}
对于运行时跨度,相同的形式可以工作,但线性化步骤是显式的且需要 unsafe:
#[kernel]
pub fn gemm_runtime(a: &[f32], b: &[f32], mut c: DisjointSlice<f32, Runtime2DIndex>, m: u32, n: u32) {
let n = n as usize; // 一个绑定,一个跨度值
let row = thread::index_2d_row();
// SAFETY: 每个内核线程看到相同的 `n`(内核参数)。
if let Some(c_idx) = unsafe { thread::index_2d_runtime(n) } {
if row < m as usize {
// ... 计算点积 ...
if let Some(c_elem) = c.get_mut(c_idx) {
*c_elem = alpha * sum + beta * (*c_elem);
}
}
}
}
index_2d_runtime 的 if let Some 替代了你在 CUDA C++ 中需要手动写的 col < n 守卫。row < m 检查仍然保留,因为它防止从输入矩阵读取垃圾数据。
什么使内核成为 Tier 1#
当满足以下条件时,内核是完全安全的 —— Tier 1:
所有可变输出访问都通过
DisjointSlice::get_mut(ThreadIndex)进行所有输入都是共享不可变引用(
&[T])没有 shared memory,没有原始指针,没有超出线程索引范围的 intrinsic
此层级的示例包括 vecadd、helper_fn、generic、host_closure 以及 gemm 和 async_mlp 示例中的朴素 GEMM 内核。
Tier 2:限域的 unsafe#
并非每个内核都符合“一个线程,一个输出元素”的模式。当线程需要协作 —— 通过快速片上内存共享数据、在 warp 内的 lane 之间通信或执行原子更新 —— 你就需要 unsafe。Tier 2 的关键属性是 unsafe 是限域的和可审计的:每个块都有文档化的安全契约,而内核的其余部分保持安全。
Warp Intrinsic#
Warp 级原语让 warp 内的线程可以在完全不接触内存的情况下交换数据 —— 寄存器到寄存器传输,由硬件协调。它们是 unsafe 的,因为硬件不检查线程收敛:如果你传入一个包含已分歧线程的掩码,你会得到 undefined behavior(通常是静默挂起,这比崩溃更糟)。
API |
安全义务 |
|---|---|
|
源 lane 必须活跃;掩码必须包含调用线程 |
|
掩码中的所有线程必须已收敛 |
|
结果仅在调用点有意义 |
参见
Warp 级编程 了解使用 warp intrinsic 的 shuffle 模式、规约和前缀和。
屏障与生命周期#
ManagedBarrier 的类型状态 API 将屏障生命周期(Uninit -> Ready -> Invalidated)编码在类型系统中,因此你不会在从未初始化的屏障上等待,或使用已被无效化的屏障。init() 和 inval() 转换仍然需要 unsafe,因为它们与硬件交互,但类型状态在编译时防止了最常见的错误。
API |
安全义务 |
|---|---|
|
必须由恰好一个线程调用;屏障必须在 shared memory 中 |
|
屏障必须已初始化;token 必须匹配 |
|
|
原子操作#
一旦你拥有有效的原子引用,原子操作本身是安全的。unsafe 的表面在于构造 —— 从原始指针创建 DeviceAtomicU32 需要调用者保证指针有效且正确对齐:
let atom = unsafe { DeviceAtomicU32::new(ptr) };
atom.fetch_add(1, Ordering::Relaxed); // 安全调用
未经检查的 Slice 访问#
当“一个线程,一个元素”模型不适用时 —— 例如,在 warp 级规约中只有 lane 0 写入结果 —— DisjointSlice::get_unchecked_mut(usize) 提供了一个逃生舱口:
if warp::lane_id() == 0 {
let warp_idx = gid.get() / 32;
// SAFETY: 每个 warp 只有 lane 0 写入;warp 索引是唯一的
unsafe { *out.get_unchecked_mut(warp_idx) = sum; }
}
安全义务与 ThreadIndex 系统自动强制执行的相同:索引在范围内,没有两个线程共享相同的索引。区别在于你自己证明它,而不是让类型系统为你做这件事。
Tier 3:原始硬件#
栈的底层是原始硬件 intrinsic —— 直接与特定 GPU 架构上的特定功能单元对话的 API。每个调用都是 unsafe 的,安全契约复杂且依赖于架构,文档更多地存在于 PTX ISA 手册而非 Rust 文档注释中。
特性 |
关键 API |
架构 |
|---|---|---|
TMA(Tensor Memory Accelerator) |
|
sm_90+ (Hopper) |
tcgen05(Tensor Core Gen 5) |
|
sm_120 (Blackwell) |
WGMMA(Warpgroup MMA) |
|
sm_90+ (Hopper) |
Cluster |
|
sm_90+ (Hopper) |
CLC(Cluster Launch Control) |
|
sm_120 (Blackwell) |
TMEM(Tensor Memory) |
|
sm_120 (Blackwell) |
如果你在编写应用级内核,你不应该需要 Tier 3 API。它们是为那些构建下一个 CUTLASS 的人准备的 —— 对于这些人,cuda-oxide 提供了与 CUDA C++ 中内联 PTX 相同的硬件访问能力,同时 Rust 的类型系统可作为(但不强制执行的)护栏。
参见
Tensor Memory Accelerator、 矩阵乘法加速器、 以及 Cluster 编程 了解 Tier 3 特性的详细内容。
Borrow Checker 为你带来了什么#
cuda-oxide 不是 DSL 或宏系统 —— 它在你的内核代码上运行真正的 rustc 前端。这意味着 Rust 在 CPU 上提供的每项安全保证在 GPU 上也同样强制执行:
保证 |
工作原理 |
|---|---|
所有权与借用 |
生命周期错误、use-after-free 和混叠违规在编译时捕获 |
安全并行写入 |
|
显式 |
原始指针访问需要 |
收敛属性强制执行 |
同步原语(barrier、fence、shuffle)在 IR 中标记为 |
前三项是标准 Rust。第四项是 GPU 特有的:CUDA 的 bar.sync、fence 和 warp shuffle 指令不能被编译器复制或重新排序。cuda-oxide 在 IR 中将它们标记为 convergent,以便 LLVM 的优化 pass 不会动它们。
难题#
Rust 的 borrow checker 是为单线程所有权设计的,并以 Send/Sync 用于 CPU 并发。SIMT 执行引入了 borrow checker 从未被教导去推理的模式。以下是 cuda-oxide 目前不强制执行的诚实的统计 —— 以及这些问题是可解决的原因。
线程分歧控制流#
Rustc 的 JumpThreading MIR 优化会将函数调用复制到 if 语句的两个分支中 —— 在 CPU 上这是一个完全正确的优化,但它会破坏 GPU 屏障语义,因为 block 中的所有线程必须在同一个 bar.sync 指令处收敛。cuda-oxide 目前对 device 代码禁用 JumpThreading(-Z mir-enable-passes=-JumpThreading)。一个正确的解决方案是教会编译器理解收敛要求,使其可以围绕这些要求进行优化,而不是完全禁用该 pass。
Shared Memory 访问模式#
Borrow checker 无法推理线程 0 写 smem[0] 而线程 1 写 smem[1] 是否安全 —— 它看到 &mut smem 并拒绝它。DisjointSlice 解决了唯一索引写入模式,但不解决协作模式,如规约、扫描或生产者/消费者流水线,其中多个线程有意地在阶段之间同步后访问重叠区域。
Warp 级收敛#
shfl_sync 和 ballot_sync 等操作要求参与掩码中命名的所有线程在调用点确实已收敛。类型系统目前无法强制执行这一点。如果线程已经分歧而你又传入了一个完整掩码,你会得到静默挂起 —— 最糟糕的一类 bug,因为没有崩溃也没有错误消息,只有一个永远无法完成的内核。
内存空间感知#
GPU 内存具有不同的地址空间 —— global、shared、local、TMEM。一个指向 shared memory 的 &mut 对 block 中的每个线程可见;一个指向 local memory 的 &mut 仅对一个线程私有。Borrow checker 将它们视为相同。这是保守的(它拒绝了一些安全程序),但绝不是不健全的(它不会接受不安全的程序)。尽管如此,一个内存空间感知的 borrow checker 可以接受更多无需 unsafe 的程序。
为什么这些问题是可解决的#
构建块已经存在于 Rust 的类型系统中。它们需要被扩展,而不是被重新发明:
想法 |
它解决的问题 |
|---|---|
执行资源感知的类型 |
函数标注其执行层级(grid / block / warp / thread)。在分歧分支中的 barrier 调用成为编译错误。 |
内存视图 |
泛化的并行访问模式 —— 类似 |
针对同步的扩展 borrow checking |
静态强制 barrier 不能被遗忘、放置在分歧控制流中或被优化器复制。类型系统中的收敛。 |
所有这些都是编译时分析。生成的 PTX 与你手写的内容完全相同 —— 安全网在代码生成时消失。零运行时开销。
cuda-oxide 处于有利位置,可以逐步实现这一点。真正的 rustc borrow checker 已经在 device 代码上运行。IR 基础设施(pliron dialect)支持 GPU 感知的分析 pass。从 MIR 到 PTX 的完整编译流水线在我们的控制之下。而且每项新的安全检查都是叠加的 —— 现有内核继续编译,而新内核获得更强的保证。
编写安全内核:速查表#
默认路径#
对于大多数内核,从这里开始:
#[kernel]
pub fn my_kernel(input: &[f32], mut output: DisjointSlice<f32>) {
if let Some((out, idx)) = output.get_mut_indexed() {
*out = transform(input[idx.get()]);
}
}
对于常量跨度的 2D,参数化 slice 并请求常量索引:
#[kernel]
pub fn tile_kernel(mut output: DisjointSlice<f32, Index2D<1024>>) {
if let Some((out, _idx)) = output.get_mut_indexed() {
*out = ...;
}
}
规则:
对所有可变输出使用
DisjointSlice。对所有只读输入使用
&[T]。对于 1D grid,默认使用
get_mut_indexed()。如果你需要索引来对多个 slice 进行算术运算,回退到显式配对:let idx = thread::index_1d(); slice.get_mut(idx)。对于常量跨度的 2D grid,将 slice 参数化为
DisjointSlice<T, Index2D<S>>并使用get_mut_indexed()或thread::index_2d::<S>()。不匹配的跨度是编译错误。对于运行时跨度,使用带有
Runtime2DIndex标记的 slice 的unsafe { thread::index_2d_runtime(n) }。unsafe是每个线程都使用了相同的n的契约。始终通过
get_mut()/get_mut_indexed()进行边界检查(两者都返回Option)。
如果你的内核在编译时没有 unsafe,那么它在构造上就是无竞争的。
何时需要 unsafe#
模式 |
原因 |
缓解措施 |
|---|---|---|
Shared memory |
多个线程访问同一个 |
在跨线程读取之前使用 |
Warp shuffle |
线程收敛不被编译器检查 |
对于完整 warp 操作使用 |
原子操作 |
从原始指针构造 |
包装在辅助函数中;原子操作本身是安全的 |
非均匀写入 |
并非每个线程都写入自己的索引 |
使用 |
硬件 intrinsic |
复杂、架构特定的契约 |
遵循 PTX ISA 文档;在目标硬件上测试 |
SAFETY 注释#
对于每个 unsafe 块,文档化为什么不变量成立。不是代码做什么 —— 代码已经说明了 —— 而是为什么这个特定用法是安全的:
// SAFETY: 每个 warp 只有 lane 0 执行此分支。
// Warp 索引 (gid / 32) 在各个 warp 之间是唯一的,因此没有两个
// 线程写入同一个输出元素。
if warp::lane_id() == 0 {
let warp_idx = gid.get() / 32;
unsafe { *partial_sums.get_unchecked_mut(warp_idx) = warp_sum; }
}
这不是形式主义。当一个内核在凌晨两点发生数据竞争,而你正盯着一份 compute-sanitizer 日志时,过去的你留下的 safety 注释是通往 bug 的最快路径。
小技巧
如果你无法为某个 unsafe 块写出有说服力的 SAFETY 注释,这是一个信号,表明不变量实际上并未被维护。重构代码直到论证变得显而易见,或者改用安全 API。
总结#
属性 |
状态 |
|---|---|
Device 代码上的 borrow checker |
已强制(真正的 |
安全 1D 并行写入( |
已强制 |
安全 2D 并行写入 —— 常量跨度 |
已强制( |
安全 2D 并行写入 —— 运行时跨度 |
调用者通过 |
|
已强制( |
|
未强制 —— 将任何 |
Shared memory、intrinsic 的显式 |
已强制(Rust 语言规则) |
同步原语的 convergent 属性 |
已强制(IR 级别的 |
Warp 操作的线程收敛 |
未强制(运行时义务) |
内存空间感知(shared 与 global) |
未强制(未来工作) |
&mut [T] 作为内核参数是下一个突出的缺口:允许宏今天接受该类型,但运行时布局(每个线程看到相同的底层指针)意味着通过 &mut data[i] 从两个不同的线程写入是同一种别名,这正是 DisjointSlice 存在所要防止的。在宏彻底拒绝 &mut [T](或将其重写为 DisjointSlice)之前,将任何接受它的内核视为其中每一行都是 unsafe。
安全模型旨在让常见情况默认安全,同时为其他所有情况提供显式的逃生舱口。编写你的内核,让类型系统捕获竞争,把 unsafe 留给那些你真的知道编译器不知道的东西的部分。