Pliron——Pliron IR(类似 MLIR)#

cuda-oxide 不会一步将 Rust 直接编译为 PTX。它将代码通过一系列中间表示进行降级,每个表示捕捉不同的抽象层次。使之成为可能的框架就是 pliron——一个完全用 Rust 编写的可扩展编译器 IR 框架,受 LLVM 的 MLIR 启发。我们将在该框架上构建的 IR 称为 Pliron IR

本章解释了什么是上游 MLIR、为什么 pliron 作为替代方案存在,以及其核心数据结构如何工作。如果你从未构建过编译器,不用担心——这里不需要任何编译器经验,只需要对 Rust 的基本理解。

什么是 MLIR(以及为什么你应该关心)?#

LLVM IR 是一个固定的指令集。它大约有 70 个操作码(addloadbrgetelementptr 等等),如果你的领域不能清晰地映射到这些操作码上,那只能硬来——你将高层的概念压扁为低层的指令,然后寄希望于优化器能够重建你本来的意图。

MLIR(多级中间表示,Multi-Level Intermediate Representation)采取了不同的方法。它不提供一个固定的指令集,而是给你一个框架,让你可以定义多种指令集——称为 dialect——每种都为一个特定的领域量身定制。一个 dialect 是一个操作、类型和属性的集合,用于建模你真正关心的概念。

核心思想很简单:

  1. 定义一个 dialect,其操作匹配你的领域。

  2. 编写 pass 将一个 dialect 转换为另一个(降级)。

  3. 将 pass 链接起来,直到到达某个后端可以消费的东西。

考虑 Triton(PyTorch 背后的 GPU 编译器)如何使用 MLIR。Python GPU 代码首先降级为 TTIR,一个张量级别的 IR,其中像"将标量广播到张量上"这样的操作是一等公民。在那个级别,Triton 可以应用特定领域的优化——例如,将 splat + mul(广播标量,然后逐元素相乘)替换为一个单一的本地向量-标量乘法。当 IR 中有张量操作时,这个优化表达起来很简单,而在压扁为 LLVM IR 之后几乎不可能恢复。

cuda-oxide 面临类似的挑战。我们需要表示三个非常不同的抽象层次:

层次

建模内容

示例操作

dialect-mir

Rust 语义——元组、枚举、切片、带检查的算术运算

mir.extract_fieldmir.get_discriminant

dialect-llvm

接近机器的操作——整数数学、内存、控制流

llvm.addllvm.loadllvm.br

dialect-nvvm

GPU 内部函数——线程索引、warp shuffle、TMA、WGMMA

nvvm.read_ptx_sreg_tid_xnvvm.shfl_sync

没有一个可扩展的 IR,我们要么将 Rust 枚举强行塞入 LLVM IR(丢失语义信息),要么为管道的每个层次构建单独的 IR 框架。MLIR 让我们在单一系统中将所有三个定义为 dialect,并通过类型安全的 pass 在它们之间降级。

pliron 的登场#

MLIR 的 dialect 和降级模型是 cuda-oxide 的正确抽象:在 IR 中保留高层语义的同时,逐步降级到 GPU 后端可以消费的代码。问题在于实现的匹配度。上游 MLIR 是 LLVM C++ 生态系统的一部分,而 cuda-oxide 是一个围绕 Rust crate、Rust 类型和 Cargo 工作流构建的 Rust 编译器项目。

Pliron 是一个受 MLIR 启发的可扩展编译器 IR 框架,完全用 Rust 编写。它遵循相同的概念模型——operation、region、basic block、type、attribute、dialect、pass——同时自然地融入 Rust 原生的编译器栈。

方面

pliron

上游 MLIR

语言

纯 Rust

C++ 配合 Python 绑定

构建系统

cargo build

CMake + 完整 LLVM 构建

所需的 DSL

无——只需 Rust 宏

TableGen、ODS

调试

标准 Rust 工具(dbg!、rust-gdb)

对 C++ 模板使用 gdb

可扩展性

将 dialect 作为 Rust crate 添加

基于 LLVM 头文件的 C++ 扩展

依赖重量

一个 crate(git 依赖)

数 GB 的 LLVM 构建产物

对于 cuda-oxide,这意味着整个编译器——从 MIR 导入到 LLVM IR 导出——都停留在与其他项目相同的 Rust 构建和调试环境中。Pliron 为我们提供了可扩展的 IR 结构,而无需将编译器移出 Rust 生态系统。

核心数据结构#

Pliron 的 IR 由少量核心类型构建。理解这些是阅读(和编写)dialect 代码的关键。

Context#

Context 是拥有所有 IR 数据的中心数据结构——operation、basic block、region、type、attribute 和 dialect 注册。将其视为整个编译单元的 arena 分配器。

// 简化版——真实的 Context 有更多字段
pub struct Context {
    operations: SlotMap<Ptr<Operation>, Operation>,
    basic_blocks: SlotMap<Ptr<BasicBlock>, BasicBlock>,
    regions: SlotMap<Ptr<Region>, Region>,
    dialects: HashMap<DialectName, Dialect>,
    type_store: TypeStore,     // 去重(唯一化)的类型
    // Attribute 不集中存储——见下文"类型与属性"。
}

Pliron 使用分代 arena(来自 slotmap crate)而不是 Box 分配的堆节点。这带来了:

  • O(1) 插入和删除——没有树的重新平衡,没有链表的遍历。

  • 稳定的索引——插入或删除元素不会使其他索引失效。

  • 分代版本控制——每个槽位带有一个代数计数器。如果你删除了一个 operation 并且槽位被重用,任何旧引用将有一个过时的代数,并在运行时失败,而不是静默地读取垃圾数据。

每一块 IR 都存储在 Context 内部。你永远不会在函数边界之间持有直接的 &mut Operation——你持有一个 Ptr<Operation>,并在需要时通过 Context 解引用。

Operations#

一个 operation 是 IR 图中的单个节点。它有 operands(输入)、results(输出)、attributes(编译时元数据),以及可选的 regions(用于函数体和循环嵌套等嵌套结构)。

每个 operation 都属于一个 dialect。命名惯例是 dialect_name.op_name

#[pliron_op(name = "mir.func", dialect = "mir")]
pub struct MirFuncOp;

#[pliron_op(name = "nvvm.read_ptx_sreg_tid_x", dialect = "nvvm")]
pub struct ReadPtxSregTidXOp;

#[pliron_op(name = "llvm.add", dialect = "llvm")]
pub struct AddOp;

#[pliron_op(...)] proc macro(来自 pliron-derive)生成了样板代码,用于将 operation 注册到其 dialect、分配操作码以及连接 Op trait。你通过实现 VerifyPrintableParsable 来定义 operation 的语义。

类型与属性#

类型表示 IR 中的数据类型。Pliron 的内置 dialect 提供整数(i1i32i64)和浮点数(f32f64)。cuda-oxide 的 MIR dialect 用 Rust 特定的类型扩展了这些:

#[pliron_type(name = "mir.tuple", dialect = "mir")]
pub struct MirTupleType {
    pub types: Vec<Ptr<TypeObj>>,
}

#[pliron_type(name = "mir.slice", dialect = "mir")]
pub struct MirSliceType {
    pub element_ty: Ptr<TypeObj>,
}

#[pliron_type(name = "mir.enum", dialect = "mir")]
pub struct MirEnumType {
    pub name: String,
    pub discriminant_ty: Ptr<TypeObj>,
    pub variant_names: Vec<String>,
    // ...
}

属性将编译时元数据附加到 operation——常量值、标志、谓词类型、转换类型等等。

类型和属性的存储方式不同:

  • 类型是唯一化的(去重)。如果你创建了两个 MirTupleType { types: vec![i32, f32] } 实例,pliron 只存储一份并返回相同的指针。类型相等变成指针比较,而不是深层的结构比较。

  • 属性默认不唯一化。 每个 operation 内联携带自己的属性值。这匹配了 MLIR 中"properties"的工作方式——MLIR 最初也对属性去重,发现变更和每个 op 的状态处理很笨拙,于是引入了 properties 作为非存储式的、每个 op 的替代方案。Pliron 跳过了那个弯路,直接从 property 风格的设计开始。

如果你确实想去重某个属性(例如,一个许多 op 引用的 4 KB 常量查找表),可以通过 uniqued_any 工具按属性选择去重。模式是将重的数据包装在 UniquedKey<T> 中,并将 key 存储在属性内部:

struct SomeData { /* 大的或比较昂贵的载荷 */ }

#[pliron_attr(name = "...", dialect = "...")]
pub struct SomeDataAttr(pub UniquedKey<SomeData>);

现在 SomeDataAttr 是一个小的句柄,通过 key 进行比较;底层的 SomeData 在去重表中只存在一份。

Ptr<T>——安全的 arena 引用#

Ptr<T> 是 pliron 中指向 arena 的等价指针。在底层,它由两个字段组成:

pub struct Ptr<T> {
    index: u32,           // arena 中的槽位索引
    version: NonZeroU32,  // 分代版本
    _phantom: PhantomData<T>,
}

version 字段是使其安全的关键。当你删除了一个 operation,pliron 会递增该槽位的代数计数器。如果之后有人试图解引用一个版本不再匹配该槽位当前代数的旧 Ptr<Operation>,访问将失败,而不是读取新的(不相关的)占用者。

PhantomData<T> 确保编译时的类型安全——你不能意外地使用 Ptr<Operation> 来索引到 BasicBlock 的 arena。编译器不会允许。

备注

你通过 Context 与 arena 内容交互:

  • ptr.deref(ctx) 返回一个共享引用(&T)。

  • ptr.deref_mut(ctx) 返回一个排他引用(&mut T)。

这遵循 Rust 的借用模型——Context 是所有者,你通过它借用。

定义-使用链(内存安全)#

在任何编译器 IR 中,你需要不断回答两个问题:

  • 使用-定义:"这个值是在哪里定义的?"(给定一个 use,找到 definition。)

  • 定义-使用:"这个值在哪里被使用?"(给定一个 definition,找到所有的 uses。)

传统编译器用原始指针和手动簿记来实现这些链。删除一个 operation 时忘记更新 use-list?悬垂指针。替换一个值但漏掉了一个 use?过时的引用。欢迎来到你下午调试 opt -O2 段错误的时光。

Pliron 使用 Rust 的类型系统来实现定义-使用链。Pliron 中的 Value 是一个枚举:

pub enum Value {
    OpResult { op: Ptr<Operation>, index: usize },
    BlockArgument { block: Ptr<BasicBlock>, index: usize },
}

每个 value 要么是 operation 的结果,要么是 basic block 的参数。每个定义跟踪其所有使用的集合,每个使用都存储一个指向其定义的指针。当你调用 replace_all_uses_with 时,两侧自动更新。

这意味着:

  • 没有悬垂引用——删除一个 operation 会更新所有 use-list。

  • 没有过时的 use——替换一个 value 会传播给每个消费者。

  • IR 转换期间没有段错误——借用检查器和分代 arena 消除了困扰 C++ 编译器框架的整类 bug。

备注

如果你熟悉 LLVM 的 Value / Use / User 系统,pliron 的设计服务于相同的目的。区别在于 LLVM 用侵入式链表和原始 Value* 指针实现,而 pliron 用 arena 索引和 HashSet<UseNode> 实现。语义相似,但所有权模型直接符合 Rust IR 转换。

Op 接口#

大多数你想在整个 IR 上做的事情——验证它、打印它、将其降级到另一个 dialect——自然地表达为:"每个实现了接口 X 的 op 都知道如何做 X"。Pliron 通过 op 接口 使这个模式成为一等公民:小型 Rust trait,用 #[op_interface] 标记,任何 op(或 type、attribute)都可以实现。

你像定义普通的 Rust trait 一样定义接口,在上面加上 #[op_interface] 属性:

use pliron::derive::op_interface;

#[op_interface]
pub trait MirToLlvmConversion {
    fn convert(
        &self,
        ctx: &mut Context,
        rewriter: &mut DialectConversionRewriter,
        operands_info: &OperandsInfo,
    ) -> Result<()>;
}

每个 op 用 #[op_interface_impl] 实现接口:

use pliron::derive::op_interface_impl;

#[op_interface_impl]
impl MirToLlvmConversion for MirAddOp {
    fn convert(
        &self,
        ctx: &mut Context,
        rewriter: &mut DialectConversionRewriter,
        operands_info: &OperandsInfo,
    ) -> Result<()> {
        // ...将 MirAddOp 降级为一个或多个 dialect-llvm op...
    }
}

框架随后基于接口进行分发,而不是针对特定的具体 op 类型。一个将 MIR 降级为 LLVM 的 pass 大致如下所示:

// 在 DialectConversion::rewrite 内部,对于遇到的每个 op:
if let Some(converter) = op_cast::<dyn MirToLlvmConversion>(op) {
    converter.convert(ctx, rewriter, operands_info)?;
}

如果你明天添加了一个新的 MIR op 并编写了它的 #[op_interface_impl],降级 pass 会自动拾取它——不需要更新中心 match 语句,不需要扩展枚举。相同的模式也用于验证(#[op_interface] trait Verify)、打印和任何其他横切行为。

相同的 #[op_interface] / #[op_interface_impl] 机制也适用于类型和属性:一个类型可以实现 dyn MemorySemantics,一个常量属性可以实现 dyn TypedAttr,等等。

备注

在底层,op_cast 是一个小的运行时查找(一次哈希表探测),将 op 的具体类型映射到宏注册的接口实现。它仅在 pass 分发时触发,不在 IR 构建的热路径中。好处是添加一个 dialect 就是添加一个 crate——永远不需要修改中心枚举或泛型纠缠的函数签名。

cuda-oxide 如何使用 pliron#

cuda-oxide 定义了三个 dialect,每个都是独立的 crate。编译器管道为你将它们注册在 Context 上(且 pliron 的 builtin dialect 从 pliron 0.14 起自动注册),因此 kernel 作者和 pass 作者永远不需要考虑 dialect 设置——依赖该 crate 是你唯一要做的事情。

dialect-mir——Rust 语义#

dialect-mir 将 Rust 的中层 IR 捕捉为 pliron 操作,保留了如果我们直接降级到 LLVM 将会丢失的语义信息。

  • 函数定义MirFuncOp——每个设备函数的入口点,带有与 Rust 函数签名匹配的类型化 block 参数。

  • 算术与比较:带检查和不带检查的二元操作(mir.addmir.submir.eqmir.lt,……)保留了 Rust 的溢出语义。

  • 聚合类型MirTupleTypeMirStructTypeMirEnumTypeMirSliceTypeMirArrayType——一等 Rust 复合类型,带有像 mir.extract_fieldmir.get_discriminant 这样的操作。

  • 内存与控制流mir.loadmir.storemir.refmir.gotomir.cond_brmir.return——带有 GPU 地址空间跟踪(globalsharedlocaltmem)。

dialect-llvm——接近机器的 IR#

dialect-llvm 将 LLVM IR 建模为 pliron 操作,提供到文本 .ll 文件的一对一映射。

  • 算术与类型转换:所有 19 种 LLVM 二元操作(从 llvm.addllvm.frem),加上 13 种类型转换操作(llvm.sextllvm.truncllvm.bitcast,……)。

  • 控制流llvm.brllvm.cond_brllvm.switchllvm.returnllvm.unreachable——block 参数在导出时转换为 PHI 节点。

  • 文本导出dialect_llvm::export 模块生成有效的 LLVM IR 文本,包括 GPU kernel 的 @llvm.used 数组和 !nvvm.annotations 元数据。

dialect-nvvm——GPU 内部函数#

dialect-nvvm 将 LLVM 的 NVPTX 后端内部函数包装为类型化的 pliron 操作。

  • 线程索引nvvm.read_ptx_sreg_tid_xnvvm.read_ptx_sreg_ctaid_xnvvm.barrier0——thread::index_1d() 的构建模块。

  • Warp 级原语:shuffle(nvvm.shfl_sync)、vote 和 reduce 操作,用于 warp 协作算法。

  • 加速器操作:TMA 批量复制、WGMMA 矩阵乘加(Hopper)和 tcgen05 tensor core 操作(Blackwell)——高级 GPU 功能 章节背后的硬件指令。

每个 dialect 在 Pliron Dialects 中有详细讲述。