闭包与泛型#

Rust 的零成本抽象——泛型、闭包和 trait 约束——在 GPU 上也能工作。 这是 cuda-oxide 最独特的能力之一:你可以编写适用于任何数值类型的 单个泛型 kernel,或者从主机传递一个闭包来自定义 GPU 行为, 全都没有运行时开销。

泛型 kernel#

Kernel 可以对类型和 trait 约束泛型化,就像任何 Rust 函数一样。 编译器将每个实例化单态化为单独的 PTX 入口点:

use cuda_device::{kernel, thread, DisjointSlice};
use core::ops::Mul;

#[kernel]
pub fn scale<T: Copy + Mul<Output = T>>(
    factor: T,
    input: &[T],
    mut out: DisjointSlice<T>,
) {
    let idx = thread::index_1d();
    let i = idx.get();
    if let Some(out_elem) = out.get_mut(idx) {
        *out_elem = input[i] * factor;
    }
}

PTX 命名#

每次单态化都会产生一个不同的 PTX 入口点。非泛型 kernel 保留其纯函数名。泛型 kernel(包括闭包泛型 kernel) 获得 _TID_<hex32> 后缀,其中 <hex32> 是泛型参数元组的 rustc 稳定 type-id hash,以 32 个小写十六进制字符呈现:

实例化

PTX 入口点名称

vecadd(非泛型)

vecadd

scale::<f32>

scale_TID_<hex32>

scale::<MyType>

scale_TID_<hex32>

map::<f32, _>(闭包)

map_TID_<hex32>

主机启动器和设备后端都会向同一个 rustc 调用请求相同的 hash, 因此字符串逐字节匹配。对元组(而非每个参数独立)进行 hash, 使得通信名称保持固定长度,无论 kernel 接受多少个泛型参数。 借用生命周期在 hash 前被擦除,因此 &'a T&'static T 对于相同的形态 T 产生相同的 hash。

启动泛型 kernel#

启动时,在生成的类型化方法上指定类型参数。这会强制具体的 实例化,并让加载器查找匹配的 PTX 入口点:

use cuda_core::LaunchConfig;

module
    .scale::<f32>(
        &stream,
        LaunchConfig::for_num_elems(N as u32),
        2.0f32,
        &input_dev,
        &mut output_dev,
    )
    .expect("Launch failed");

生成的方法强制 scale::<f32> 的单态化,因此该实例化会出现在 编译后的 PTX 中,即使它在 CPU 上从未被直接调用过。

主机闭包作为 kernel 参数#

cuda-oxide 支持将闭包从主机传递到 GPU。这使得强大的 map 风格模式成为可能,其中 kernel 的行为由函数参数化:

#[kernel]
pub fn map<F: Fn(i32) -> i32>(f: F, input: &[i32], mut out: DisjointSlice<i32>) {
    let idx = thread::index_1d();
    let i = idx.get();
    if let Some(out_elem) = out.get_mut(idx) {
        *out_elem = f(input[i]);
    }
}

用闭包启动:

let factor = 3i32;
module
    .map::<_>(&stream, config, move |x| x * factor, &input_dev, &mut output_dev)
    .expect("Launch failed");

闭包参数如何传递#

闭包作为一个值通过启动——而不是作为捕获字段的列表。 启动器推送单个驱动参数(整个闭包结构体以及所有捕获), kernel 将其接收为一个 byval .param

主机          factor = 3i32; cl = move |x| x * factor
              推送一个驱动参数 ─► 闭包结构体 { factor: i32 }

GPU kernel    .entry map_TID_<hex>(
                 .param .align 4 .b8 f[4],    ; 一个 byval 闭包
                 .param .u64 input_ptr,        ; 切片仍然是 (ptr, len)
                 .param .u64 input_len,
                 ...
               )

切片保持其 (ptr, len) 展开,因为该形态由主机启动辅助函数和 PTX 入口点布局共享。只有聚合按值参数(闭包和按值传递的用户结构体) 在 kernel 边界处作为一个 byval 传入。

没有捕获的闭包是零大小类型——后端完全去掉 .param, 主机启动器知道跳过它,使数据包保持对齐。

闭包的 PTX 命名#

闭包泛型 kernel 获得与其他泛型 kernel 相同的 _TID_<hex32> 后缀。 闭包的匿名类型是 hash 元组中的条目之一,因此两个不同的 闭包字面量——即使具有相同的 Fn 签名——产生两个不同的入口点:

闭包

PTX 入口点

move |x| x * factor(一个捕获)

map_TID_<hex_a>

move |x| x + offset(一个捕获)

map_TID_<hex_b>

Move 闭包与引用闭包#

move 关键字决定了捕获如何被传输到 GPU:

Move 闭包(推荐的默认选择)#

let factor = 3i32;
move |x| x * factor   // `factor` 被复制到闭包结构体中
  • 闭包结构体以值形式持有捕获({ factor: i32 })。

  • Kernel 将 factor 作为 byval 闭包的常规字段读取。

  • 主机变量可以在启动后被释放。

  • 在所有系统上都可以工作——不需要特殊的硬件支持。

引用闭包(HMM)#

let factor = 3i32;
|x| x * factor   // 闭包捕获 &factor
  • 闭包结构体包含指向 factor主机指针{ factor: &i32 })。

  • 整个闭包仍然作为一个 byval 参数传递;kernel 通过 **硬件管理内存(HMM)**解引用该主机指针,在访问时迁移主机页面。

  • 主机变量必须保持存活直到 kernel 完成。

  • 需要 HMM 支持(Turing+ GPU、Linux 6.1.24+、CUDA 12.2+)。

何时使用哪种#

场景

使用方式

小型标量捕获(数字、布尔值)

move(零复制开销)

大型结构体捕获

如果 kernel 多次读取则用 move;如果很少访问则用 HMM

原型制作

都可以;move 更具可移植性

主机与设备之间的共享可变状态

引用(HMM)-- 但要注意同步

小技巧

如果不确定,使用 move 闭包。它们更容易推理,适用于所有环境, 并且避免了共享主机/设备内存的同步风险。

Kernel 内闭包#

在设备代码中完全定义和调用的闭包以正常的 Rust 语义工作—— 不涉及主机/设备 ABI,因为一切都已在 GPU 上:

#[kernel]
pub fn apply_transform(input: &[f32], mut out: DisjointSlice<f32>) {
    let idx = thread::index_1d();

    let transform = |x: f32| -> f32 {
        let clamped = if x < 0.0 { 0.0 } else if x > 1.0 { 1.0 } else { x };
        clamped * clamped
    };

    if let Some(out_elem) = out.get_mut(idx) {
        *out_elem = transform(input[idx.get()]);
    }
}

Kernel 内闭包被编译器内联,零开销。它们对于在 kernel 内部 分解逻辑而不引入单独的设备函数很有用。

跨 crate kernel#

Kernel 可以在库 crate 中定义并从二进制 crate 启动:

// 在 lib crate `my_kernels` 中:
use cuda_device::{cuda_module, kernel, thread, DisjointSlice};

#[cuda_module]
pub mod kernels {
    use super::*;

    #[kernel]
    pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
        let idx = thread::index_1d();
        let i = idx.get();
        if let Some(c_elem) = c.get_mut(idx) {
            *c_elem = a[i] + b[i];
        }
    }
}
// 在二进制 crate 中:
use my_kernels::kernels;

let module = kernels::load(&ctx)?;
module
    .vecadd(&stream, config, &a, &b, &mut c)
    .expect("Launch failed");

编译器通过 #[kernel] 生成的标记 trait 处理跨 crate 的 kernel 发现。 类型化模块在编译时解析 PTX 名称并缓存加载的函数句柄。

小技巧

对于跨 crate 的泛型 kernel,单态化发生在调用方 crate(即知道具体类型的地方), 因此 PTX 作为二进制编译的一部分生成。