添加新的内建函数#
你想教 cuda-oxide 一项新的 GPU 技巧。也许 NVIDIA 刚发布了一条新指令,或者你需要一个现有的 PTX 操作但还没有人接入。好消息:这个过程是机械化的。五个 crate,五个步骤,一旦你做过一次,大约三十分钟就能完成。
本章通过两个真实示例——一个非常简单,一个带有一些转折——逐步展示完整流水线,以便你能看到每个阶段具体发生了什么。
五阶段流水线#
每个 GPU 内建函数在编译器中经历相同的路径:
cuda-device 用户编写: thread::threadIdx_x()
│
▼
mir-importer 编译器看到调用,发出 `dialect-nvvm` 操作
│
▼
dialect-nvvm 该操作作为已验证的 IR 节点存在于此
│
▼
mir-lower 将 `dialect-nvvm` 操作转换为 `dialect-llvm` 操作
│
▼
dialect-llvm 导出文本 LLVM IR → llc 将其转为 PTX
在 mir-lower 阶段,你选择两种策略之一:
策略 |
适用场景 |
示例 |
|---|---|---|
LLVM 内建函数调用 |
LLVM 已有内置的内建函数 |
|
内联 PTX 汇编 |
不存在 LLVM 内建函数,或你需要对 PTX 指令的精确控制 |
|
以下演示了这两种策略。
示例 1:threadIdx_x(简单案例)#
threadIdx_x() 是 GPU 内建函数中的"Hello, World":零参数,一个 u32 结果,直接映射到单个 LLVM NVVM 内建函数。如果你能理解这个示例,你就能添加任何简单的内建函数。
阶段 1 -- 在 cuda-device 中声明#
文件: crates/cuda-device/src/thread.rs
#[inline(never)]
pub fn threadIdx_x() -> u32 {
unreachable!("threadIdx_x called outside CUDA kernel context")
}
两条在你理解其中的技巧之前可能看起来奇怪的规定:
#[inline(never)]保持该函数在 MIR 中作为独立的调用可见。如果 rustc 内联了它,编译器将看到的是unreachable!()而非可以拦截的调用。那将是……不太有帮助的。函数体是
unreachable!(),因为此函数从不会真正执行。编译器在任何 GPU 代码执行之前将整个调用替换为dialect-nvvm操作。把这个函数看作一个占位符——一个"亲爱的编译器,请在此处插入 GPU 指令"的便条。
小技巧
在函数上方的注释中记录该内建函数映射到什么 LLVM IR 或 PTX。未来的你(或未来的贡献者)会感谢现在的你。
阶段 2 -- 定义 dialect-nvvm 操作#
文件: crates/dialect-nvvm/src/ops/thread.rs
#[pliron_op(name = "nvvm.read_ptx_sreg_tid_x", dialect = "nvvm", format)]
pub struct ReadPtxSregTidXOp;
impl Verify for ReadPtxSregTidXOp {
fn verify(&self, ctx: &Context) -> Result<(), Error> {
let op = &*self.get_operation().deref(ctx);
if op.get_num_operands() != 0 {
return verify_err!(op.loc(), "expected 0 operands");
}
if op.get_num_results() != 1 {
return verify_err!(op.loc(), "expected 1 result");
}
Ok(())
}
}
然后注册它,让 pliron 知道该操作存在:
pub(super) fn register(ctx: &mut Context) {
ReadPtxSregTidXOp::register(ctx, ReadPtxSregTidXOp::parser_fn);
}
Verify trait 在早期捕获结构错误。如果某处意外地用两个操作数创建了这个操作,验证会以明确的错误消息失败,而非在三个阶段之后产生无用的 PTX。
阶段 3 -- 在 mir-importer 中识别#
文件: crates/mir-importer/src/translator/terminator/mod.rs
当翻译器处理 MIR 时,每个函数调用都经过 try_dispatch_intrinsic()。被调用函数的完全限定域名(FQDN)通过 extract_func_info() 从 CrateDef::name() 获取,产生类似 cuda_device::thread::threadIdx_x 的路径。match 检查此 FQDN:
match name {
"cuda_device::threadIdx_x" | "cuda_device::thread::threadIdx_x" => {
Ok(Some(helpers::emit_nvvm_intrinsic(
ctx,
ReadPtxSregTidXOp::get_concrete_op_info(),
destination, target, block_ptr, prev_op,
value_map, block_map, loc,
)?))
}
// ... 数百个其他内建函数 ...
}
我们同时匹配重新导出的名称(cuda_device::threadIdx_x)和完整模块路径(cuda_device::thread::threadIdx_x),因此无论用户如何导入该函数,两者都能工作。FQDN 在匹配中按原样使用——在内建函数检查之前不会进行 :: 到 __ 的转换。
emit_nvvm_intrinsic() 辅助函数是一个适用于任何零参数、单结果 NVVM 内建函数的通用函数。它创建操作,将结果存储在值映射中,并发出到下一个基本块的分支。对于简单的内建函数,你永远不需要编写自定义发射器。
阶段 4 -- 降级到 dialect-llvm#
文件: crates/mir-lower/src/convert/intrinsics/basic.rs
该操作实现了 MirToLlvmConversion 操作接口。在 convert/interface_impls.rs 中,该 impl 分发到转换函数:
impl MirToLlvmConversion for ReadPtxSregTidXOp {
fn rewrite(ctx, rewriter, op, operands_info) -> Result<()> {
basic::convert_read_tid_x(ctx, rewriter, op, operands_info)
}
}
然后 basic::convert_read_tid_x 发出 LLVM 内建函数调用:
pub fn convert_read_tid_x(
ctx: &mut Context,
rewriter: &mut DialectConversionRewriter,
op: Ptr<Operation>,
_operands_info: &OperandsInfo,
) -> Result<()> {
let intrinsic_name = "llvm_nvvm_read_ptx_sreg_tid_x";
let i32_ty = IntegerType::get(ctx, 32, Signedness::Signless);
let func_ty = FuncType::get(ctx, i32_ty.into(), vec![], false);
let call_op = call_intrinsic(ctx, rewriter, intrinsic_name, func_ty, vec![])?;
rewriter.replace_operation_with_values(ctx, op, vec![result]);
Ok(())
}
备注
LLVM 内建函数名称使用点(llvm.nvvm.read.ptx.sreg.tid.x),但 pliron 标识符不能包含点。内部我们使用下划线(llvm_nvvm_read_ptx_sreg_tid_x)。导出阶段将它们转换回去。
阶段 5 -- 导出(无需更改)#
文件: crates/dialect-llvm/src/export.rs
CallOp 导出器已经处理了下划线到点的转换:
let fixed_name = if name.starts_with("llvm_nvvm") {
name.replace('_', ".")
} else {
strip_device_prefix(&name)
};
由于 threadIdx_x 不是 convergent 的(它是每线程寄存器读取,而非集体操作),无需向 is_convergent_intrinsic() 添加任何内容。
最终输出:
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%v5 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
经过 llc 之后:
mov.u32 %r1, %tid.x;
五个阶段,一条 PTX 指令。还不错。
示例 2:shuffle_xor(复杂案例)#
Warp shuffle 更有趣。用户传递两个参数,但底层的 LLVM 内建函数期望四个参数。编译器在幕后填充额外的参数。该操作也是 convergent 的,意味着 LLVM 不得跨控制流移动、复制或推测它。
阶段 1 -- 在 cuda-device 中声明#
文件: crates/cuda-device/src/warp.rs
#[inline(never)]
pub fn shuffle_xor(var: u32, lane_mask: u32) -> u32 {
let _ = (var, lane_mask);
unreachable!("shuffle_xor called outside CUDA kernel context")
}
与之前相同的模式。let _ = (var, lane_mask); 抑制未使用变量的警告——一个不花费任何代价的小礼节。
阶段 2 -- 定义 dialect-nvvm 操作#
文件: crates/dialect-nvvm/src/ops/warp.rs
#[pliron_op(name = "nvvm.shfl_sync_bfly_i32", dialect = "nvvm", format)]
pub struct ShflSyncBflyI32Op;
impl Verify for ShflSyncBflyI32Op {
fn verify(&self, ctx: &Context) -> Result<(), Error> {
let op = &*self.get_operation().deref(ctx);
if op.get_num_operands() != 2 {
return verify_err!(op.loc(), "expected 2 operands");
}
if op.get_num_results() != 1 {
return verify_err!(op.loc(), "expected 1 result");
}
Ok(())
}
}
这次是两个操作数(值和 lane mask),一个结果。
阶段 3 -- 在 mir-importer 中识别#
文件: crates/mir-importer/src/translator/terminator/mod.rs
与 threadIdx_x 不同,这次调用专门的发射器,因为它需要处理用户的两个参数:
"cuda_device::warp::shuffle_xor" => Ok(Some(
intrinsics::warp::emit_warp_shuffle_i32(
ctx, body,
ShflSyncBflyI32Op::get_concrete_op_info(),
args, destination, target, ...
)?
)),
文件: crates/mir-importer/src/translator/terminator/intrinsics/warp.rs
pub fn emit_warp_shuffle_i32(ctx, body, shuffle_opid, args, ...) {
if args.len() != 2 { return error; }
let (val, _) = translate_operand(ctx, body, &args[0], ...);
let (lane_or_mask, _) = translate_operand(ctx, body, &args[1], ...);
let shuffle_op = Operation::new(ctx, shuffle_opid,
vec![u32_type.to_ptr()], // 结果: [u32]
vec![val, lane_or_mask], // 操作数: [value, mask]
vec![], 0,
);
}
发射器将用户的 MIR 操作数翻译为 pliron 值——通常是来自每个操作数的 alloca 槽的 mir.load 结果,或者是字面量参数的 mir.constant——并将它们接入 NVVM 操作。(这些还不是 SSA 值;pliron::opts::mem2reg 将在翻译完成后折叠 load/store 链。)对于具有不同参数模式的内建函数,你编写类似的自定义发射器。
阶段 4 -- 降级到 dialect-llvm#
文件: crates/mir-lower/src/convert/intrinsics/warp.rs
这里变得有趣了。butterfly shuffle 的 LLVM 内建函数接受四个参数,而不是两个:
用户 API: shuffle_xor(value, lane_mask) → 2 个参数
LLVM 内建函数: shfl.sync.bfly.i32(mask, value, lane_mask, clamp) → 4 个参数
^^^^ ^^^^^
总是 -1 总是 31
转换器添加编译器提供的常量:
fn convert_shuffle_i32(op, ctx, intrinsic_name, clamp: i32) {
let operands = get_operands(op, ctx)?;
let (val, lane_or_delta) = (operands[0], operands[1]);
let mask_val = create_i32_const(ctx, -1); // 所有 32 个 lane 参与
let clamp_val = create_i32_const(ctx, clamp); // 完整 warp 宽度
let func_ty = FuncType::get(ctx.ctx, i32_ty.into(),
vec![i32_ty, i32_ty, i32_ty, i32_ty], false);
let call_op = call_intrinsic(ctx, intrinsic_name, func_ty,
vec![mask_val, val, lane_or_delta, clamp_val]);
map_result(op, call_op, ctx);
}
mask 值 -1(所有位设置)表示 warp 中所有 32 个 lane 参与。clamp 值 31 表示 shuffle 在完整的 warp 宽度处回绕。这些是几乎所有用例的正确默认值,将它们暴露给用户只会是噪音。
阶段 5 -- 导出(添加 Convergent)#
Warp shuffle 是convergent 的——LLVM 不得相对于控制流重新排序它们。导出步骤检查 is_convergent_intrinsic():
fn is_convergent_intrinsic(name: &str) -> bool {
name == "llvm.nvvm.barrier0"
|| name.starts_with("llvm.nvvm.shfl") // shuffle
|| name.starts_with("llvm.nvvm.vote") // vote
|| name.starts_with("llvm.nvvm.mbarrier") // 异步 barrier
|| name.starts_with("llvm.nvvm.cp.async.bulk")
// ...
}
如果你新的内建函数是 convergent 的,添加到这里。如果你忘记了,LLVM 可能会将其提升出 if 块,你的 warp 级代码将产生错误结果或死锁。不太好。
最终输出:
declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32)
%v8 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(
i32 -1, i32 %v3, i32 %v4, i32 31) #0
attributes #0 = { convergent }
经过 llc 之后:
shfl.sync.bfly.b32 %r3, %r1, %r2, 31;
内联 PTX 路径#
某些操作没有 LLVM 内建函数。对于这些,我们直接发出内联 PTX 汇编。辅助函数 inline_asm_convergent() 处理样板代码:
// wgmma fence: 无输入,无输出,仅副作用
inline_asm_convergent(
ctx, void_ty.into(), vec![],
"wgmma.fence.sync.aligned;", ""
);
// mbarrier arrive: 一个输入(指针),一个输出(token)
inline_asm_convergent(
ctx, i64_ty.into(), vec![ptr_val],
"mbarrier.arrive.shared.b64 $0, [$1];", "=l,r"
);
约束字符串遵循 LLVM 内联汇编语法:
约束 |
含义 |
|---|---|
|
输出:64 位寄存器 |
|
输出:32 位寄存器 |
|
输入:32 位寄存器 |
|
输入:64 位寄存器 |
(空) |
无输入或输出(仅副作用) |
内联汇编上的 sideeffect convergent 标记告诉 LLVM 不要动它——不要移动它,不要删除它,不要复制它。
端到端:完整旅程#
以下是 threadIdx_x 经历的每种表示形式,自上而下:
Rust: thread::threadIdx_x()
MIR: _3 = threadIdx_x() -> bb1
Pliron MIR: %v = nvvm.read_ptx_sreg_tid_x : i32
Pliron LLVM: %v = call i32 @llvm_nvvm_read_ptx_sreg_tid_x()
LLVM IR: %v5 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
PTX: mov.u32 %r1, %tid.x;
以及 shuffle_xor:
Rust: warp::shuffle_xor(val, mask)
MIR: _5 = shuffle_xor(_3, _4) -> bb2
Pliron MIR: %v = nvvm.shfl_sync_bfly_i32 %val, %mask : i32
Pliron LLVM: %v = call i32 @llvm_nvvm_shfl_sync_bfly_i32(i32 -1, %val, %mask, i32 31)
LLVM IR: %v8 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, %v3, %v4, i32 31) #0
PTX: shfl.sync.bfly.b32 %r3, %r1, %r2, 31;
六种表示形式。一个 Rust 函数调用变成一条 PTX 指令。中间步骤的存在使每个转换都是小的、可验证的、可独立测试的。
快速参考清单#
你需要接触的每个文件,按顺序:
cuda-device/src/<module>.rs——pub fn带有#[inline(never)]和unreachable!()函数体。这是面向用户的 API。dialect-nvvm/src/ops/<module>.rs——#[pliron_op(name = "nvvm.<name>", ...)]结构体,Verifyimpl(检查操作数/结果数量),register()调用。mir-importer/src/translator/terminator/mod.rs——try_dispatch_intrinsic()中的match分支。对于零参数内建函数使用helpers::emit_nvvm_intrinsic(),对于有参数的内建函数编写自定义发射器。mir-lower/src/convert/interface_impls.rs——新操作的MirToLlvmConversionimpl,分发到转换函数。mir-lower/src/convert/intrinsics/<module>.rs——转换逻辑。对于 LLVM 内建函数使用call_intrinsic(),对于内联 PTX 使用inline_asm_convergent()。dialect-llvm/src/export.rs——仅在 convergent 时:添加到is_convergent_intrinsic()。
并行对照#
阶段 |
|
|
|---|---|---|
cuda-device |
|
|
dialect-nvvm |
0 个操作数,1 个结果 |
2 个操作数,1 个结果 |
mir-importer |
通用辅助函数 |
自定义发射器 |
mir-lower |
|
|
Convergent? |
否 |
是( |
PTX |
|
|
以上就是整个过程。五个文件,每个都有清晰而有限的职责。这个模式足够机械化,一旦你做过一次,添加一个新的内建函数应该大约需要三十分钟——其中大部分时间用于阅读 PTX ISA 规范,以精确了解你想要什么指令。