Pliron 方言#

cuda-oxide 并非在一次庞大的转换中将 Rust 降级到 PTX。它使用了三种自定义的 pliron 方言,每种方言建模不同的抽象层次。本章将逐一介绍这三种方言——它们的类型、操作,以及它们如何组合在一起构成编译流水线。

如果你还没有阅读 Pliron -- Pliron IR (MLIR-like) 章节,现在是阅读它的好时机。该章节中的概念(operation、type、attribute、region、Ptr<T>、def-use chain)是构建本页所有内容的基础。


三种方言概览#

方言

用途

抽象层次

dialect-mir

建模 Rust MIR 语义

最高层 -- Rust 类型、元组、枚举、切片、检查算术

dialect-llvm

建模 LLVM IR

中间层 -- 扁平类型、GEP、PHI 就绪的控制流

dialect-nvvm

建模 NVIDIA GPU 内建函数

正交层 -- 线程索引、warp、TMA、WGMMA、tcgen05

dialect-nvvm 是"正交的"而非栈中的一个层级,因为它的操作与 dialect-llvm 操作并存,而非在其之下。一次 warp shuffle 和一个整数加法共存于同一个函数体中。

数据在流水线中的流转如下:

dialect-mir ──(mem2reg)──▶ dialect-mir (SSA) ──(DialectConversion)──▶ dialect-llvm + dialect-nvvm ops ──(export.rs)──▶ 文本 LLVM IR ──(llc)──▶ PTX

每个箭头都是一个明确定义的转换。前两个发生在 pliron 内部;最后一个是 LLVM 的 NVPTX 后端在完成它的本职工作。


dialect-mir -- Rust 层#

dialect-mir 将 Rust 的类型系统和控制流语义保留为 pliron 操作。这是有意为之:我们希望在将 Rust 概念(元组、枚举、检查算术、地址空间)展平为 LLVM 类型系统之前对其进行推理。

类型#

该方言定义了七种自定义类型,反映了 Rust 的复合类型:

类型

示例

描述

mir.tuple

mir.tuple<i32, f32, i64>

异构元组

mir.ptr

mir.ptr<f32, mutable, addrspace: 1>

带 GPU 地址空间的指针

mir.array

mir.array<f32, 256>

固定大小的数组

mir.struct

mir.struct<"Point", [f32, f32]>

带布局信息的命名结构体

mir.slice

mir.slice<f32, addrspace: 1>

胖指针(指针 + 长度)

mir.disjoint_slice

mir.disjoint_slice<f32>

安全检查切片 -- 每个线程访问唯一元素

mir.enum

mir.enum<"Option_i32", [("None", []), ("Some", [i32])]>

Rust 枚举,包含判别值和变体负载

mir.ptrmir.slice 上的地址空间跟踪数据在 GPU 内存层级中的位置:

地址空间

含义

0

全局(运行时解析)

1

通用(设备 DRAM)

3

共享(每 block 的 SRAM)

4

常量(只读缓存)

5

局部(每线程栈,溢出到 DRAM)

6

张量内存(Blackwell TMEM)

操作#

dialect-mir 定义了 54 种操作,分为 11 个类别:

类别

示例

数量

函数

mir.func

1

控制流

mir.gotomir.cond_brmir.returnmir.assertmir.unreachable

5

常量

mir.constantmir.float_constantmir.undef

3

内存

mir.allocamir.loadmir.storemir.assignmir.refmir.ptr_offsetmir.shared_allocmir.global_allocmir.extern_shared

9

算术

mir.addmir.submir.mulmir.divmir.remmir.checked_addmir.checked_submir.checked_mulmir.negmir.notmir.shrmir.shlmir.bitandmir.bitormir.bitxor

15

比较

mir.eqmir.nemir.ltmir.lemir.gtmir.ge

6

聚合

mir.extract_fieldmir.insert_fieldmir.construct_structmir.construct_tuplemir.construct_arraymir.extract_array_elementmir.field_addrmir.array_element_addr

8

枚举

mir.get_discriminantmir.construct_enummir.enum_payload

3

类型转换

mir.cast

1

存储

mir.storage_livemir.storage_dead

2

调用

mir.call

1

操作种类很多,但它们归属于自然的组别。如果你了解 Rust MIR(或者读过 rustc_public 章节),每种操作都直接映射到一个 MIR 概念。

IR 的样子#

以下是 dialect-mir 操作在实际使用中的几个示例。为了可读性有所简化——实际的打印形式包含更多元数据。

检查加法(Rust:let sum = a + b,其中 a, b: i32):

// mir.checked_add 返回一个 (result, overflow_flag) 元组
%checked = mir.checked_add %a, %b : i32
%sum     = mir.extract_field %checked, 0 : mir.tuple<i32, i1>
%overflowed = mir.extract_field %checked, 1 : mir.tuple<i32, i1>
mir.assert %overflowed == false, "attempt to add with overflow" -> bb1

结构体构造和字段访问(Rust:point.x):

%point = mir.construct_struct %x, %y : mir.struct<"Point", [f32, f32]>
%x_val = mir.extract_field %point, 0 : mir.struct<"Point", [f32, f32]>

共享内存分配(GPU 特定的部分):

%shmem = mir.shared_alloc : mir.ptr<f32, mutable, addrspace: 3>
mir.store %value, %shmem : f32

验证#

每个 MIR 操作在构造时都会验证类型一致性。这能在早期捕获导入错误——在它们有机会通过降级流水线传播并在三步之后以神秘的 LLVM 错误形式出现之前。

被检查的内容示例:

  • mir.add 验证两个操作数具有相同的类型。

  • mir.cond_br 验证条件是 i1(布尔值)。

  • mir.extract_field 验证字段索引在边界内,且结果类型与字段类型匹配。

  • mir.store 验证值类型与指针的指向类型匹配。

DisjointSlice 的安全保证("一个线程,一个元素")在类型系统层面通过 ThreadIndex 强制执行——只有硬件派生的线程索引才能访问该切片。没有单独的编译器 pass 用于 disjoint-access 验证;安全性来自 Rust 类型系统和 cuda-device 的 API 设计。


dialect-llvm -- LLVM 层#

dialect-llvm 将 LLVM IR 建模为 pliron 操作。它提供了与文本 .ll 文件近乎 1:1 的映射——每个 LLVM 指令都有一个对应的 pliron 操作,并且类型直接映射到 LLVM 的类型系统。

类型#

类型

示例

描述

整数

i1i8i16i32i64i128

Pliron 内置,直接使用

浮点

halffloatdouble

Pliron 内置(FP16TypeFP32TypeFP64Type

llvm.ptr

ptr addrspace(1)

不透明指针,可带地址空间

llvm.struct

{ i32, float }%MyStruct

命名或匿名,可以是不透明的

llvm.array

[256 x float]

固定大小的数组

llvm.vector

<4 x float>

SIMD 向量

llvm.func

(i32, ptr) -> void

函数签名

llvm.void

void

单元类型

注意这里没有 Rust 特定的类型。当代码到达 dialect-llvm 时,元组已经变成了结构体,枚举已经变成了判别值索引的结构体,切片已经变成了指针-长度对。降级 pass(在降级流水线中介绍)处理所有这些展平工作。

操作#

该方言定义了 62 种操作:

类别

示例

数量

算术

addsubmulfaddfsubfmulfdivfremfneg、...

19

类型转换

zextsexttruncfpextfptruncsitofpuitofpfptosifptouiptrtointinttoptraddrspacecastbitcast

13

控制流

brcond_brswitchreturnunreachable

5

内存

loadstoreallocagep

4

原子操作

atomic_loadatomic_storeatomicrmwcmpxchgfence

5

比较

icmpfcmp

2

聚合

extract_valueinsert_valueextractelement

3

调用

callcall_intrinsic

2

内联汇编

inline_asminline_asm_multi

2

常量

constantzeroundef

3

符号

funcglobaladdressof

3

选择

select

1

如果你之前读过 LLVM IR,这里没有会令你感到意外的地方。操作名称有意保持与对应的 LLVM 操作相同,在 IR 中以 llvm. 作为前缀。

导出引擎#

dialect-llvm 的皇冠明珠是 export.rs——该模块将 pliron IR 模块转换为有效的文本 LLVM IR。这不仅仅是"打印每个操作";在导出过程中会发生若干非平凡的转换:

块参数变成 PHI 节点。 Pliron IR(类 MLIR)将汇合点建模为块参数——基本块之间的函数式调用约定。LLVM IR 则使用 PHI 节点。导出器从分支操作数构建前驱映射,并在每个非入口块的顶部输出 phi 指令。

值命名。 一个预 pass 为每个值分配顺序的 SSA 名称(%v0%v1、...)。常量有特殊处理:llvm.constant 的结果映射到其字面值(而非 %vN 名称),以便 PHI 可以引用出现在输出中较后位置的块中的常量。

NVVM 内建函数名称转换。 Pliron 标识符使用下划线;LLVM 内建函数使用点。导出器通过将下划线替换为点来转换所有以 llvm_ 开头的名称:llvm_nvvm_read_ptx_sreg_tid_x 变为 llvm.nvvm.read.ptx.sreg.tid.x。这是一个机械的转换,而非查表操作。

convergent 属性标记。 barrier、shuffle 和 vote 内建函数必须标记为 convergent,以防止 LLVM 将它们提升出控制流。导出器通过前缀模式匹配(点格式)名称来识别这些操作,并在其调用点附加 #0,同时在模块级别输出 attributes #0 = { convergent }

内核元数据。 标记为 kernel 的函数获得 ptx_kernel 调用约定和一条 !nvvm.annotations 元数据条目。

以下是简单向量加法内核的导出 LLVM IR 的样子:

target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()

define ptx_kernel void @vecadd(ptr addrspace(1) %v0, i64 %v1,
                                ptr addrspace(1) %v2, i64 %v3,
                                ptr addrspace(1) %v4, i64 %v5) {
entry:
    %v6 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
    %v7 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #0
    %v8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
    %v9 = mul i32 %v8, %v7
    %v10 = add i32 %v9, %v6
    ; ... bounds check, load, add, store ...
    ret void
}

!nvvm.annotations = !{!0}
!0 = !{ptr @vecadd, !"kernel", i32 1}
attributes #0 = { convergent }

注意切片已被标量化:每个 Rust &[f32] 变成了一个 ptr addrspace(1) 和一个 i64 长度。这发生在降级 pass 中;当 dialect-llvm 看到它们时,它们已经是扁平的参数了。


dialect-nvvm -- GPU 层#

dialect-nvvm 将 NVIDIA 的 GPU 内建函数包装为类型化的 pliron 操作。这些操作并不构成降级链中的一个"层级"——它们在 dialect-mirdialect-llvm 降级 pass 期间被插入,并与 dialect-llvm 操作共存于同一个函数体中。在导出时,它们变成对 @llvm.nvvm.* 内建函数的 call 指令。

架构覆盖#

该方言按模块组织,每个模块面向一个 GPU 功能集:

模块

描述

操作数

最低 SM

GPU 家族

thread

线程/block 索引、barrier0、threadfence

18

全部

所有 GPU

warp

Lane id、shuffle、vote、match

18

全部

所有 GPU

grid

协作 grid_sync

1

sm_70

Volta+

debug

Clock、trap、breakpoint、vprintf

6

全部

所有 GPU

atomic

Atomic load/store/RMW/cmpxchg

4

sm_70

Volta+

cluster

线程块集群 + DSMEM

11

sm_90

Hopper+

mbarrier

异步 barrier + fence proxy + nanosleep

10

sm_90

Hopper+

tma

张量内存加速器(批量 G2S/S2G)

15

sm_90

Hopper+

wgmma

Warpgroup 矩阵乘累加

5

sm_90

Hopper+

stmatrix

共享内存矩阵存储 + bf16 转换

5

sm_90

Hopper+

tcgen05

Tensor Core Gen 5 + TMEM

24

sm_100

Blackwell+

clc

集群启动控制

6

sm_100

Blackwell+

总计 123 种操作。大多数用户只会用到前三个模块(线程索引、warp shuffle、barrier)。其余的是用于高级 GPU 编程——TMA、矩阵加速器和 Blackwell 的张量内存——在高级 GPU 功能章节中介绍。

从 Rust 到 PTX:一个内建函数的旅程#

每个 NVVM 操作映射到三个命名层次:

Pliron 操作

LLVM 内建函数

PTX 指令

ReadPtxSregTidXOp

llvm.nvvm.read.ptx.sreg.tid.x

mov.u32 %r1, %tid.x

Barrier0Op

llvm.nvvm.barrier0

bar.sync 0

ShflSyncBflyI32Op

llvm.nvvm.shfl.sync.bfly.i32

shfl.sync.bfly.b32

CpAsyncBulkTensorG2sTile2dOp

llvm.nvvm.cp.async.bulk.tensor.2d.tile.g2s.im2col.*

cp.async.bulk.tensor.2d.tile.g2s ...

第一列是 dialect-nvvm 中的 Rust 结构体名称。第二列是 export.rs 输出的内容(经过下划线到点的转换后)。第三列是 llc 产生的内容。你永远不需要手动编写这些内容——它们由 mir-lower 在看到对 cuda-device 内建函数(如 thread::index_x()warp::shfl_sync_bfly())的调用时生成。

验证策略#

NVVM 操作使用最小化的结构验证:每个操作检查其操作数数量与结果数量,少数操作会验证结果类型(线程索引操作要求 i32 结果;tcgen05 加载检查其 32 寄存器和 4 寄存器变体的确切结果数量)。

这是有意为之。NVVM 操作是由 mir-lower 机器生成的——它们从不由用户手写。LLVM 的 NVPTX 后端在下游提供全面的类型验证。为每个 NVVM 操作添加完整的类型检查会使方言的代码量翻倍而不带来实际收益。

备注

GPU 架构要求(sm_70、sm_90、sm_100)已记录,但在 pliron 层面不强制执行。架构验证在后续使用特定的 -mcpu=sm_XX 标志调用 llc 时发生。如果你使用了 Hopper 内建函数但目标为 Volta,llc 会明确地告诉你。


方言如何交互#

以下是单个 Rust 操作在三个抽象层次中的完整生命周期:

Rust 源码:     let sum = a + b;        // a, b: f32

dialect-mir:   %sum = mir.add %a, %b : f32
                ↓  (DialectConversion)
dialect-llvm:  %v5 = fadd float %v3, %v4
                ↓  (export.rs)
LLVM IR:       %v5 = fadd float %v3, %v4
                ↓  (llc --mcpu=sm_80)
PTX:           add.f32 %f3, %f1, %f2;

dialect-mirdialect-llvm 步骤是重要工作发生的地方:mir.addf32 上变为 fadd(浮点加法),而 mir.addi32 上变为 add(整数加法)。检查运算如 mir.checked_add 展开为 llvm.add、一个常量 i1 false 的溢出标志,以及一个 insertvalue 插入结构体——GPU 路径省略溢出检测(因为 GPU 整数算术是回绕的)。降级 pass 处理所有这些翻译。

对于 GPU 特定的操作,dialect-nvvm 进入舞台:

Rust 源码:     let tid = thread::threadIdx_x();

dialect-mir:   %tid = mir.call @cuda_oxide_device_<hash>_thread_index_x()
                ↓  (DialectConversion, 识别该内建函数)
dialect-nvvm:  %v2 = nvvm.read_ptx_sreg_tid_x : i32
                ↓  (export.rs)
LLVM IR:       %v2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
                ↓  (llc)
PTX:           mov.u32 %r1, %tid.x;

降级 pass 通过完全限定名称(FQDN)识别对 cuda_device 内建函数的调用,并将其替换为相应的 dialect-nvvm 操作。不需要通用的"函数调用"机制——内建函数变成直接的硬件指令。

全局概览#

综合来看,编译后的内核体包含 dialect-llvmdialect-nvvm 操作的混合:

llvm.func @vecadd(...) {
  entry:
    %tid    = nvvm.read_ptx_sreg_tid_x     // NVVM: 线程索引
    %ntid   = nvvm.read_ptx_sreg_ntid_x    // NVVM: block 大小
    %ctaid  = nvvm.read_ptx_sreg_ctaid_x   // NVVM: block 索引
    %offset = llvm.mul %ctaid, %ntid        // LLVM: 整数数学
    %idx    = llvm.add %offset, %tid        // LLVM: 整数数学
    %cmp    = llvm.icmp slt %idx, %len      // LLVM: 边界检查
    llvm.cond_br %cmp, bb1, bb2             // LLVM: 分支
  bb1:
    %p_a    = llvm.gep %a, %idx             // LLVM: 指针算术
    %val_a  = llvm.load %p_a                // LLVM: 内存访问
    %p_b    = llvm.gep %b, %idx
    %val_b  = llvm.load %p_b
    %sum    = llvm.fadd %val_a, %val_b      // LLVM: 浮点加法
    %p_c    = llvm.gep %c, %idx
    llvm.store %sum, %p_c
    llvm.br bb2
  bb2:
    llvm.return void
}

顶部的 dialect-nvvm 操作计算全局线程索引。其余的全部是标准的 dialect-llvm——加载、存储、算术、分支。导出引擎将这一切序列化为单个 .ll 文件,llc 将其编译为 PTX。


关于这些方言如何通过降级 pass 连接,参见 降级流水线