添加新的内建函数#

你想教 cuda-oxide 一项新的 GPU 技巧。也许 NVIDIA 刚发布了一条新指令,或者你需要一个现有的 PTX 操作但还没有人接入。好消息:这个过程是机械化的。五个 crate,五个步骤,一旦你做过一次,大约三十分钟就能完成。

本章通过两个真实示例——一个非常简单,一个带有一些转折——逐步展示完整流水线,以便你能看到每个阶段具体发生了什么。


五阶段流水线#

每个 GPU 内建函数在编译器中经历相同的路径:

cuda-device          用户编写:  thread::threadIdx_x()
    │
    ▼
mir-importer         编译器看到调用,发出 `dialect-nvvm` 操作
    │
    ▼
dialect-nvvm         该操作作为已验证的 IR 节点存在于此
    │
    ▼
mir-lower            将 `dialect-nvvm` 操作转换为 `dialect-llvm` 操作
    │
    ▼
dialect-llvm         导出文本 LLVM IR  →  llc 将其转为 PTX

mir-lower 阶段,你选择两种策略之一:

策略

适用场景

示例

LLVM 内建函数调用

LLVM 已有内置的内建函数

threadIdx_x、warp shuffle、barrier

内联 PTX 汇编

不存在 LLVM 内建函数,或你需要对 PTX 指令的精确控制

trapwgmmatcgen05mbarrier

以下演示了这两种策略。


示例 1:threadIdx_x(简单案例)#

threadIdx_x() 是 GPU 内建函数中的"Hello, World":零参数,一个 u32 结果,直接映射到单个 LLVM NVVM 内建函数。如果你能理解这个示例,你就能添加任何简单的内建函数。

阶段 1 -- 在 cuda-device 中声明#

文件: crates/cuda-device/src/thread.rs

#[inline(never)]
pub fn threadIdx_x() -> u32 {
    unreachable!("threadIdx_x called outside CUDA kernel context")
}

两条在你理解其中的技巧之前可能看起来奇怪的规定:

  • #[inline(never)] 保持该函数在 MIR 中作为独立的调用可见。如果 rustc 内联了它,编译器将看到的是 unreachable!() 而非可以拦截的调用。那将是……不太有帮助的。

  • 函数体是 unreachable!(),因为此函数从不会真正执行。编译器在任何 GPU 代码执行之前将整个调用替换为 dialect-nvvm 操作。把这个函数看作一个占位符——一个"亲爱的编译器,请在此处插入 GPU 指令"的便条。

小技巧

在函数上方的注释中记录该内建函数映射到什么 LLVM IR 或 PTX。未来的你(或未来的贡献者)会感谢现在的你。

阶段 2 -- 定义 dialect-nvvm 操作#

文件: crates/dialect-nvvm/src/ops/thread.rs

#[pliron_op(name = "nvvm.read_ptx_sreg_tid_x", dialect = "nvvm", format)]
pub struct ReadPtxSregTidXOp;

impl Verify for ReadPtxSregTidXOp {
    fn verify(&self, ctx: &Context) -> Result<(), Error> {
        let op = &*self.get_operation().deref(ctx);
        if op.get_num_operands() != 0 {
            return verify_err!(op.loc(), "expected 0 operands");
        }
        if op.get_num_results() != 1 {
            return verify_err!(op.loc(), "expected 1 result");
        }
        Ok(())
    }
}

然后注册它,让 pliron 知道该操作存在:

pub(super) fn register(ctx: &mut Context) {
    ReadPtxSregTidXOp::register(ctx, ReadPtxSregTidXOp::parser_fn);
}

Verify trait 在早期捕获结构错误。如果某处意外地用两个操作数创建了这个操作,验证会以明确的错误消息失败,而非在三个阶段之后产生无用的 PTX。

阶段 3 -- 在 mir-importer 中识别#

文件: crates/mir-importer/src/translator/terminator/mod.rs

当翻译器处理 MIR 时,每个函数调用都经过 try_dispatch_intrinsic()。被调用函数的完全限定域名(FQDN)通过 extract_func_info()CrateDef::name() 获取,产生类似 cuda_device::thread::threadIdx_x 的路径。match 检查此 FQDN:

match name {
    "cuda_device::threadIdx_x" | "cuda_device::thread::threadIdx_x" => {
        Ok(Some(helpers::emit_nvvm_intrinsic(
            ctx,
            ReadPtxSregTidXOp::get_concrete_op_info(),
            destination, target, block_ptr, prev_op,
            value_map, block_map, loc,
        )?))
    }
    // ... 数百个其他内建函数 ...
}

我们同时匹配重新导出的名称(cuda_device::threadIdx_x)和完整模块路径(cuda_device::thread::threadIdx_x),因此无论用户如何导入该函数,两者都能工作。FQDN 在匹配中按原样使用——在内建函数检查之前不会进行 ::__ 的转换。

emit_nvvm_intrinsic() 辅助函数是一个适用于任何零参数、单结果 NVVM 内建函数的通用函数。它创建操作,将结果存储在值映射中,并发出到下一个基本块的分支。对于简单的内建函数,你永远不需要编写自定义发射器。

阶段 4 -- 降级到 dialect-llvm#

文件: crates/mir-lower/src/convert/intrinsics/basic.rs

该操作实现了 MirToLlvmConversion 操作接口。在 convert/interface_impls.rs 中,该 impl 分发到转换函数:

impl MirToLlvmConversion for ReadPtxSregTidXOp {
    fn rewrite(ctx, rewriter, op, operands_info) -> Result<()> {
        basic::convert_read_tid_x(ctx, rewriter, op, operands_info)
    }
}

然后 basic::convert_read_tid_x 发出 LLVM 内建函数调用:

pub fn convert_read_tid_x(
    ctx: &mut Context,
    rewriter: &mut DialectConversionRewriter,
    op: Ptr<Operation>,
    _operands_info: &OperandsInfo,
) -> Result<()> {
    let intrinsic_name = "llvm_nvvm_read_ptx_sreg_tid_x";
    let i32_ty = IntegerType::get(ctx, 32, Signedness::Signless);
    let func_ty = FuncType::get(ctx, i32_ty.into(), vec![], false);

    let call_op = call_intrinsic(ctx, rewriter, intrinsic_name, func_ty, vec![])?;
    rewriter.replace_operation_with_values(ctx, op, vec![result]);
    Ok(())
}

备注

LLVM 内建函数名称使用点(llvm.nvvm.read.ptx.sreg.tid.x),但 pliron 标识符不能包含点。内部我们使用下划线(llvm_nvvm_read_ptx_sreg_tid_x)。导出阶段将它们转换回去。

阶段 5 -- 导出(无需更改)#

文件: crates/dialect-llvm/src/export.rs

CallOp 导出器已经处理了下划线到点的转换:

let fixed_name = if name.starts_with("llvm_nvvm") {
    name.replace('_', ".")
} else {
    strip_device_prefix(&name)
};

由于 threadIdx_x 不是 convergent 的(它是每线程寄存器读取,而非集体操作),无需向 is_convergent_intrinsic() 添加任何内容。

最终输出:

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%v5 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()

经过 llc 之后:

mov.u32  %r1, %tid.x;

五个阶段,一条 PTX 指令。还不错。


示例 2:shuffle_xor(复杂案例)#

Warp shuffle 更有趣。用户传递两个参数,但底层的 LLVM 内建函数期望四个参数。编译器在幕后填充额外的参数。该操作也是 convergent 的,意味着 LLVM 不得跨控制流移动、复制或推测它。

阶段 1 -- 在 cuda-device 中声明#

文件: crates/cuda-device/src/warp.rs

#[inline(never)]
pub fn shuffle_xor(var: u32, lane_mask: u32) -> u32 {
    let _ = (var, lane_mask);
    unreachable!("shuffle_xor called outside CUDA kernel context")
}

与之前相同的模式。let _ = (var, lane_mask); 抑制未使用变量的警告——一个不花费任何代价的小礼节。

阶段 2 -- 定义 dialect-nvvm 操作#

文件: crates/dialect-nvvm/src/ops/warp.rs

#[pliron_op(name = "nvvm.shfl_sync_bfly_i32", dialect = "nvvm", format)]
pub struct ShflSyncBflyI32Op;

impl Verify for ShflSyncBflyI32Op {
    fn verify(&self, ctx: &Context) -> Result<(), Error> {
        let op = &*self.get_operation().deref(ctx);
        if op.get_num_operands() != 2 {
            return verify_err!(op.loc(), "expected 2 operands");
        }
        if op.get_num_results() != 1 {
            return verify_err!(op.loc(), "expected 1 result");
        }
        Ok(())
    }
}

这次是两个操作数(值和 lane mask),一个结果。

阶段 3 -- 在 mir-importer 中识别#

文件: crates/mir-importer/src/translator/terminator/mod.rs

threadIdx_x 不同,这次调用专门的发射器,因为它需要处理用户的两个参数:

"cuda_device::warp::shuffle_xor" => Ok(Some(
    intrinsics::warp::emit_warp_shuffle_i32(
        ctx, body,
        ShflSyncBflyI32Op::get_concrete_op_info(),
        args, destination, target, ...
    )?
)),

文件: crates/mir-importer/src/translator/terminator/intrinsics/warp.rs

pub fn emit_warp_shuffle_i32(ctx, body, shuffle_opid, args, ...) {
    if args.len() != 2 { return error; }

    let (val, _) = translate_operand(ctx, body, &args[0], ...);
    let (lane_or_mask, _) = translate_operand(ctx, body, &args[1], ...);

    let shuffle_op = Operation::new(ctx, shuffle_opid,
        vec![u32_type.to_ptr()],          // 结果: [u32]
        vec![val, lane_or_mask],          // 操作数: [value, mask]
        vec![], 0,
    );

}

发射器将用户的 MIR 操作数翻译为 pliron 值——通常是来自每个操作数的 alloca 槽的 mir.load 结果,或者是字面量参数的 mir.constant——并将它们接入 NVVM 操作。(这些还不是 SSA 值;pliron::opts::mem2reg 将在翻译完成后折叠 load/store 链。)对于具有不同参数模式的内建函数,你编写类似的自定义发射器。

阶段 4 -- 降级到 dialect-llvm#

文件: crates/mir-lower/src/convert/intrinsics/warp.rs

这里变得有趣了。butterfly shuffle 的 LLVM 内建函数接受四个参数,而不是两个:

用户 API:       shuffle_xor(value, lane_mask)            →  2 个参数
LLVM 内建函数:  shfl.sync.bfly.i32(mask, value, lane_mask, clamp) →  4 个参数
                                     ^^^^                   ^^^^^
                                     总是 -1                 总是 31

转换器添加编译器提供的常量:

fn convert_shuffle_i32(op, ctx, intrinsic_name, clamp: i32) {
    let operands = get_operands(op, ctx)?;
    let (val, lane_or_delta) = (operands[0], operands[1]);

    let mask_val  = create_i32_const(ctx, -1);   // 所有 32 个 lane 参与
    let clamp_val = create_i32_const(ctx, clamp); // 完整 warp 宽度

    let func_ty = FuncType::get(ctx.ctx, i32_ty.into(),
        vec![i32_ty, i32_ty, i32_ty, i32_ty], false);

    let call_op = call_intrinsic(ctx, intrinsic_name, func_ty,
        vec![mask_val, val, lane_or_delta, clamp_val]);

    map_result(op, call_op, ctx);
}

mask 值 -1(所有位设置)表示 warp 中所有 32 个 lane 参与。clamp 值 31 表示 shuffle 在完整的 warp 宽度处回绕。这些是几乎所有用例的正确默认值,将它们暴露给用户只会是噪音。

阶段 5 -- 导出(添加 Convergent)#

Warp shuffle 是convergent 的——LLVM 不得相对于控制流重新排序它们。导出步骤检查 is_convergent_intrinsic()

fn is_convergent_intrinsic(name: &str) -> bool {
    name == "llvm.nvvm.barrier0"
        || name.starts_with("llvm.nvvm.shfl")      // shuffle
        || name.starts_with("llvm.nvvm.vote")       // vote
        || name.starts_with("llvm.nvvm.mbarrier")   // 异步 barrier
        || name.starts_with("llvm.nvvm.cp.async.bulk")
        // ...
}

如果你新的内建函数是 convergent 的,添加到这里。如果你忘记了,LLVM 可能会将其提升出 if 块,你的 warp 级代码将产生错误结果或死锁。不太好。

最终输出:

declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32)

%v8 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(
    i32 -1, i32 %v3, i32 %v4, i32 31) #0

attributes #0 = { convergent }

经过 llc 之后:

shfl.sync.bfly.b32  %r3, %r1, %r2, 31;

内联 PTX 路径#

某些操作没有 LLVM 内建函数。对于这些,我们直接发出内联 PTX 汇编。辅助函数 inline_asm_convergent() 处理样板代码:

// wgmma fence: 无输入,无输出,仅副作用
inline_asm_convergent(
    ctx, void_ty.into(), vec![],
    "wgmma.fence.sync.aligned;", ""
);

// mbarrier arrive: 一个输入(指针),一个输出(token)
inline_asm_convergent(
    ctx, i64_ty.into(), vec![ptr_val],
    "mbarrier.arrive.shared.b64 $0, [$1];", "=l,r"
);

约束字符串遵循 LLVM 内联汇编语法:

约束

含义

=l

输出:64 位寄存器

=r

输出:32 位寄存器

r

输入:32 位寄存器

l

输入:64 位寄存器

(空)

无输入或输出(仅副作用)

内联汇编上的 sideeffect convergent 标记告诉 LLVM 不要动它——不要移动它,不要删除它,不要复制它。


端到端:完整旅程#

以下是 threadIdx_x 经历的每种表示形式,自上而下:

Rust:         thread::threadIdx_x()
MIR:          _3 = threadIdx_x() -> bb1
Pliron MIR:   %v = nvvm.read_ptx_sreg_tid_x : i32
Pliron LLVM:  %v = call i32 @llvm_nvvm_read_ptx_sreg_tid_x()
LLVM IR:      %v5 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
PTX:          mov.u32 %r1, %tid.x;

以及 shuffle_xor

Rust:         warp::shuffle_xor(val, mask)
MIR:          _5 = shuffle_xor(_3, _4) -> bb2
Pliron MIR:   %v = nvvm.shfl_sync_bfly_i32 %val, %mask : i32
Pliron LLVM:  %v = call i32 @llvm_nvvm_shfl_sync_bfly_i32(i32 -1, %val, %mask, i32 31)
LLVM IR:      %v8 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, %v3, %v4, i32 31) #0
PTX:          shfl.sync.bfly.b32 %r3, %r1, %r2, 31;

六种表示形式。一个 Rust 函数调用变成一条 PTX 指令。中间步骤的存在使每个转换都是小的、可验证的、可独立测试的。


快速参考清单#

你需要接触的每个文件,按顺序:

  1. cuda-device/src/<module>.rs——pub fn 带有 #[inline(never)]unreachable!() 函数体。这是面向用户的 API。

  2. dialect-nvvm/src/ops/<module>.rs——#[pliron_op(name = "nvvm.<name>", ...)] 结构体,Verify impl(检查操作数/结果数量),register() 调用。

  3. mir-importer/src/translator/terminator/mod.rs——try_dispatch_intrinsic() 中的 match 分支。对于零参数内建函数使用 helpers::emit_nvvm_intrinsic(),对于有参数的内建函数编写自定义发射器。

  4. mir-lower/src/convert/interface_impls.rs——新操作的 MirToLlvmConversion impl,分发到转换函数。

  5. mir-lower/src/convert/intrinsics/<module>.rs——转换逻辑。对于 LLVM 内建函数使用 call_intrinsic(),对于内联 PTX 使用 inline_asm_convergent()

  6. dialect-llvm/src/export.rs——仅在 convergent 时:添加到 is_convergent_intrinsic()


并行对照#

阶段

threadIdx_x(简单)

shuffle_xor(复杂)

cuda-device

fn() -> u32

fn(u32, u32) -> u32

dialect-nvvm

0 个操作数,1 个结果

2 个操作数,1 个结果

mir-importer

通用辅助函数

自定义发射器

mir-lower

call @intrinsic()

call @intrinsic(4 args)

Convergent?

是(#0

PTX

mov.u32 %r1, %tid.x

shfl.sync.bfly.b32 ...


以上就是整个过程。五个文件,每个都有清晰而有限的职责。这个模式足够机械化,一旦你做过一次,添加一个新的内建函数应该大约需要三十分钟——其中大部分时间用于阅读 PTX ISA 规范,以精确了解你想要什么指令。