Pliron 方言#
cuda-oxide 并非在一次庞大的转换中将 Rust 降级到 PTX。它使用了三种自定义的 pliron 方言,每种方言建模不同的抽象层次。本章将逐一介绍这三种方言——它们的类型、操作,以及它们如何组合在一起构成编译流水线。
如果你还没有阅读 Pliron -- Pliron IR (MLIR-like) 章节,现在是阅读它的好时机。该章节中的概念(operation、type、attribute、region、Ptr<T>、def-use chain)是构建本页所有内容的基础。
三种方言概览#
方言 |
用途 |
抽象层次 |
|---|---|---|
dialect-mir |
建模 Rust MIR 语义 |
最高层 -- Rust 类型、元组、枚举、切片、检查算术 |
dialect-llvm |
建模 LLVM IR |
中间层 -- 扁平类型、GEP、PHI 就绪的控制流 |
dialect-nvvm |
建模 NVIDIA GPU 内建函数 |
正交层 -- 线程索引、warp、TMA、WGMMA、tcgen05 |
dialect-nvvm 是"正交的"而非栈中的一个层级,因为它的操作与 dialect-llvm 操作并存,而非在其之下。一次 warp shuffle 和一个整数加法共存于同一个函数体中。
数据在流水线中的流转如下:
dialect-mir ──(mem2reg)──▶ dialect-mir (SSA) ──(DialectConversion)──▶ dialect-llvm + dialect-nvvm ops ──(export.rs)──▶ 文本 LLVM IR ──(llc)──▶ PTX
每个箭头都是一个明确定义的转换。前两个发生在 pliron 内部;最后一个是 LLVM 的 NVPTX 后端在完成它的本职工作。
dialect-mir -- Rust 层#
dialect-mir 将 Rust 的类型系统和控制流语义保留为 pliron 操作。这是有意为之:我们希望在将 Rust 概念(元组、枚举、检查算术、地址空间)展平为 LLVM 类型系统之前对其进行推理。
类型#
该方言定义了七种自定义类型,反映了 Rust 的复合类型:
类型 |
示例 |
描述 |
|---|---|---|
|
|
异构元组 |
|
|
带 GPU 地址空间的指针 |
|
|
固定大小的数组 |
|
|
带布局信息的命名结构体 |
|
|
胖指针(指针 + 长度) |
|
|
安全检查切片 -- 每个线程访问唯一元素 |
|
|
Rust 枚举,包含判别值和变体负载 |
mir.ptr 和 mir.slice 上的地址空间跟踪数据在 GPU 内存层级中的位置:
地址空间 |
含义 |
|---|---|
0 |
全局(运行时解析) |
1 |
通用(设备 DRAM) |
3 |
共享(每 block 的 SRAM) |
4 |
常量(只读缓存) |
5 |
局部(每线程栈,溢出到 DRAM) |
6 |
张量内存(Blackwell TMEM) |
操作#
dialect-mir 定义了 54 种操作,分为 11 个类别:
类别 |
示例 |
数量 |
|---|---|---|
函数 |
|
1 |
控制流 |
|
5 |
常量 |
|
3 |
内存 |
|
9 |
算术 |
|
15 |
比较 |
|
6 |
聚合 |
|
8 |
枚举 |
|
3 |
类型转换 |
|
1 |
存储 |
|
2 |
调用 |
|
1 |
操作种类很多,但它们归属于自然的组别。如果你了解 Rust MIR(或者读过 rustc_public 章节),每种操作都直接映射到一个 MIR 概念。
IR 的样子#
以下是 dialect-mir 操作在实际使用中的几个示例。为了可读性有所简化——实际的打印形式包含更多元数据。
检查加法(Rust:let sum = a + b,其中 a, b: i32):
// mir.checked_add 返回一个 (result, overflow_flag) 元组
%checked = mir.checked_add %a, %b : i32
%sum = mir.extract_field %checked, 0 : mir.tuple<i32, i1>
%overflowed = mir.extract_field %checked, 1 : mir.tuple<i32, i1>
mir.assert %overflowed == false, "attempt to add with overflow" -> bb1
结构体构造和字段访问(Rust:point.x):
%point = mir.construct_struct %x, %y : mir.struct<"Point", [f32, f32]>
%x_val = mir.extract_field %point, 0 : mir.struct<"Point", [f32, f32]>
共享内存分配(GPU 特定的部分):
%shmem = mir.shared_alloc : mir.ptr<f32, mutable, addrspace: 3>
mir.store %value, %shmem : f32
验证#
每个 MIR 操作在构造时都会验证类型一致性。这能在早期捕获导入错误——在它们有机会通过降级流水线传播并在三步之后以神秘的 LLVM 错误形式出现之前。
被检查的内容示例:
mir.add验证两个操作数具有相同的类型。mir.cond_br验证条件是i1(布尔值)。mir.extract_field验证字段索引在边界内,且结果类型与字段类型匹配。mir.store验证值类型与指针的指向类型匹配。
DisjointSlice 的安全保证("一个线程,一个元素")在类型系统层面通过 ThreadIndex 强制执行——只有硬件派生的线程索引才能访问该切片。没有单独的编译器 pass 用于 disjoint-access 验证;安全性来自 Rust 类型系统和 cuda-device 的 API 设计。
dialect-llvm -- LLVM 层#
dialect-llvm 将 LLVM IR 建模为 pliron 操作。它提供了与文本 .ll 文件近乎 1:1 的映射——每个 LLVM 指令都有一个对应的 pliron 操作,并且类型直接映射到 LLVM 的类型系统。
类型#
类型 |
示例 |
描述 |
|---|---|---|
整数 |
|
Pliron 内置,直接使用 |
浮点 |
|
Pliron 内置( |
|
|
不透明指针,可带地址空间 |
|
|
命名或匿名,可以是不透明的 |
|
|
固定大小的数组 |
|
|
SIMD 向量 |
|
|
函数签名 |
|
|
单元类型 |
注意这里没有 Rust 特定的类型。当代码到达 dialect-llvm 时,元组已经变成了结构体,枚举已经变成了判别值索引的结构体,切片已经变成了指针-长度对。降级 pass(在降级流水线中介绍)处理所有这些展平工作。
操作#
该方言定义了 62 种操作:
类别 |
示例 |
数量 |
|---|---|---|
算术 |
|
19 |
类型转换 |
|
13 |
控制流 |
|
5 |
内存 |
|
4 |
原子操作 |
|
5 |
比较 |
|
2 |
聚合 |
|
3 |
调用 |
|
2 |
内联汇编 |
|
2 |
常量 |
|
3 |
符号 |
|
3 |
选择 |
|
1 |
如果你之前读过 LLVM IR,这里没有会令你感到意外的地方。操作名称有意保持与对应的 LLVM 操作相同,在 IR 中以 llvm. 作为前缀。
导出引擎#
dialect-llvm 的皇冠明珠是 export.rs——该模块将 pliron IR 模块转换为有效的文本 LLVM IR。这不仅仅是"打印每个操作";在导出过程中会发生若干非平凡的转换:
块参数变成 PHI 节点。 Pliron IR(类 MLIR)将汇合点建模为块参数——基本块之间的函数式调用约定。LLVM IR 则使用 PHI 节点。导出器从分支操作数构建前驱映射,并在每个非入口块的顶部输出 phi 指令。
值命名。 一个预 pass 为每个值分配顺序的 SSA 名称(%v0、%v1、...)。常量有特殊处理:llvm.constant 的结果映射到其字面值(而非 %vN 名称),以便 PHI 可以引用出现在输出中较后位置的块中的常量。
NVVM 内建函数名称转换。 Pliron 标识符使用下划线;LLVM 内建函数使用点。导出器通过将下划线替换为点来转换所有以 llvm_ 开头的名称:llvm_nvvm_read_ptx_sreg_tid_x 变为 llvm.nvvm.read.ptx.sreg.tid.x。这是一个机械的转换,而非查表操作。
convergent 属性标记。 barrier、shuffle 和 vote 内建函数必须标记为 convergent,以防止 LLVM 将它们提升出控制流。导出器通过前缀模式匹配(点格式)名称来识别这些操作,并在其调用点附加 #0,同时在模块级别输出 attributes #0 = { convergent }。
内核元数据。 标记为 kernel 的函数获得 ptx_kernel 调用约定和一条 !nvvm.annotations 元数据条目。
以下是简单向量加法内核的导出 LLVM IR 的样子:
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
define ptx_kernel void @vecadd(ptr addrspace(1) %v0, i64 %v1,
ptr addrspace(1) %v2, i64 %v3,
ptr addrspace(1) %v4, i64 %v5) {
entry:
%v6 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
%v7 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #0
%v8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
%v9 = mul i32 %v8, %v7
%v10 = add i32 %v9, %v6
; ... bounds check, load, add, store ...
ret void
}
!nvvm.annotations = !{!0}
!0 = !{ptr @vecadd, !"kernel", i32 1}
attributes #0 = { convergent }
注意切片已被标量化:每个 Rust &[f32] 变成了一个 ptr addrspace(1) 和一个 i64 长度。这发生在降级 pass 中;当 dialect-llvm 看到它们时,它们已经是扁平的参数了。
dialect-nvvm -- GPU 层#
dialect-nvvm 将 NVIDIA 的 GPU 内建函数包装为类型化的 pliron 操作。这些操作并不构成降级链中的一个"层级"——它们在 dialect-mir → dialect-llvm 降级 pass 期间被插入,并与 dialect-llvm 操作共存于同一个函数体中。在导出时,它们变成对 @llvm.nvvm.* 内建函数的 call 指令。
架构覆盖#
该方言按模块组织,每个模块面向一个 GPU 功能集:
模块 |
描述 |
操作数 |
最低 SM |
GPU 家族 |
|---|---|---|---|---|
|
线程/block 索引、 |
18 |
全部 |
所有 GPU |
|
Lane id、shuffle、vote、match |
18 |
全部 |
所有 GPU |
|
协作 |
1 |
sm_70 |
Volta+ |
|
Clock、trap、breakpoint、 |
6 |
全部 |
所有 GPU |
|
Atomic load/store/RMW/cmpxchg |
4 |
sm_70 |
Volta+ |
|
线程块集群 + DSMEM |
11 |
sm_90 |
Hopper+ |
|
异步 barrier + fence proxy + nanosleep |
10 |
sm_90 |
Hopper+ |
|
张量内存加速器(批量 G2S/S2G) |
15 |
sm_90 |
Hopper+ |
|
Warpgroup 矩阵乘累加 |
5 |
sm_90 |
Hopper+ |
|
共享内存矩阵存储 + bf16 转换 |
5 |
sm_90 |
Hopper+ |
|
Tensor Core Gen 5 + TMEM |
24 |
sm_100 |
Blackwell+ |
|
集群启动控制 |
6 |
sm_100 |
Blackwell+ |
总计 123 种操作。大多数用户只会用到前三个模块(线程索引、warp shuffle、barrier)。其余的是用于高级 GPU 编程——TMA、矩阵加速器和 Blackwell 的张量内存——在高级 GPU 功能章节中介绍。
从 Rust 到 PTX:一个内建函数的旅程#
每个 NVVM 操作映射到三个命名层次:
Pliron 操作 |
LLVM 内建函数 |
PTX 指令 |
|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
第一列是 dialect-nvvm 中的 Rust 结构体名称。第二列是 export.rs 输出的内容(经过下划线到点的转换后)。第三列是 llc 产生的内容。你永远不需要手动编写这些内容——它们由 mir-lower 在看到对 cuda-device 内建函数(如 thread::index_x() 或 warp::shfl_sync_bfly())的调用时生成。
验证策略#
NVVM 操作使用最小化的结构验证:每个操作检查其操作数数量与结果数量,少数操作会验证结果类型(线程索引操作要求 i32 结果;tcgen05 加载检查其 32 寄存器和 4 寄存器变体的确切结果数量)。
这是有意为之。NVVM 操作是由 mir-lower 机器生成的——它们从不由用户手写。LLVM 的 NVPTX 后端在下游提供全面的类型验证。为每个 NVVM 操作添加完整的类型检查会使方言的代码量翻倍而不带来实际收益。
备注
GPU 架构要求(sm_70、sm_90、sm_100)已记录,但在 pliron 层面不强制执行。架构验证在后续使用特定的 -mcpu=sm_XX 标志调用 llc 时发生。如果你使用了 Hopper 内建函数但目标为 Volta,llc 会明确地告诉你。
方言如何交互#
以下是单个 Rust 操作在三个抽象层次中的完整生命周期:
Rust 源码: let sum = a + b; // a, b: f32
dialect-mir: %sum = mir.add %a, %b : f32
↓ (DialectConversion)
dialect-llvm: %v5 = fadd float %v3, %v4
↓ (export.rs)
LLVM IR: %v5 = fadd float %v3, %v4
↓ (llc --mcpu=sm_80)
PTX: add.f32 %f3, %f1, %f2;
dialect-mir → dialect-llvm 步骤是重要工作发生的地方:mir.add 在 f32 上变为 fadd(浮点加法),而 mir.add 在 i32 上变为 add(整数加法)。检查运算如 mir.checked_add 展开为 llvm.add、一个常量 i1 false 的溢出标志,以及一个 insertvalue 插入结构体——GPU 路径省略溢出检测(因为 GPU 整数算术是回绕的)。降级 pass 处理所有这些翻译。
对于 GPU 特定的操作,dialect-nvvm 进入舞台:
Rust 源码: let tid = thread::threadIdx_x();
dialect-mir: %tid = mir.call @cuda_oxide_device_<hash>_thread_index_x()
↓ (DialectConversion, 识别该内建函数)
dialect-nvvm: %v2 = nvvm.read_ptx_sreg_tid_x : i32
↓ (export.rs)
LLVM IR: %v2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
↓ (llc)
PTX: mov.u32 %r1, %tid.x;
降级 pass 通过完全限定名称(FQDN)识别对 cuda_device 内建函数的调用,并将其替换为相应的 dialect-nvvm 操作。不需要通用的"函数调用"机制——内建函数变成直接的硬件指令。
全局概览#
综合来看,编译后的内核体包含 dialect-llvm 和 dialect-nvvm 操作的混合:
llvm.func @vecadd(...) {
entry:
%tid = nvvm.read_ptx_sreg_tid_x // NVVM: 线程索引
%ntid = nvvm.read_ptx_sreg_ntid_x // NVVM: block 大小
%ctaid = nvvm.read_ptx_sreg_ctaid_x // NVVM: block 索引
%offset = llvm.mul %ctaid, %ntid // LLVM: 整数数学
%idx = llvm.add %offset, %tid // LLVM: 整数数学
%cmp = llvm.icmp slt %idx, %len // LLVM: 边界检查
llvm.cond_br %cmp, bb1, bb2 // LLVM: 分支
bb1:
%p_a = llvm.gep %a, %idx // LLVM: 指针算术
%val_a = llvm.load %p_a // LLVM: 内存访问
%p_b = llvm.gep %b, %idx
%val_b = llvm.load %p_b
%sum = llvm.fadd %val_a, %val_b // LLVM: 浮点加法
%p_c = llvm.gep %c, %idx
llvm.store %sum, %p_c
llvm.br bb2
bb2:
llvm.return void
}
顶部的 dialect-nvvm 操作计算全局线程索引。其余的全部是标准的 dialect-llvm——加载、存储、算术、分支。导出引擎将这一切序列化为单个 .ll 文件,llc 将其编译为 PTX。
关于这些方言如何通过降级 pass 连接,参见 降级流水线。