MIR Importer#

前几章解释了 rustc 如何产生 Stable MIR(rustc_public)以及 pliron 如何提供 IR 框架(Pliron)。本章是两者交汇的地方:mir-importer 接受 rustc 交给我们的 Stable MIR,并将其翻译成 dialect-mir——保留 Rust 语义的 pliron dialect。翻译器最初生成 alloca/load/store 形式——生成成本低、易于推理,并且在输入时对于 pliron 来说是等同的。随后的 pliron::opts::mem2reg pass 然后将这些槽位提升回 SSA 形式,使 dialect-mir 准备好降级到 dialect-llvm

但翻译只是工作的一半。mir-importer 还编排了整个编译管道:翻译、验证、降级、导出和生成 PTX。它既是翻译器,也是舞台导演。

该 crate 位于 crates/mir-importer 中,分为两部分:

  • translator/——MIR 到 pliron 的翻译逻辑(有趣的部分)。

  • pipeline.rs——将每个阶段串联在一起的编排(负责任的部分)。


管道编排#

在深入翻译细节之前,先看大局。run_pipeline() 函数是 rustc-codegen-cuda 在收集设备函数后调用的入口点。它接受一个 CollectedFunction 结构体列表和一个 PipelineConfig,然后运行六个阶段:

步骤 1:翻译 Rust MIR → `dialect-mir`
步骤 2:验证 `dialect-mir` 模块
步骤 3:运行 `pliron::opts::mem2reg` 将 alloca 槽位提升回 SSA
步骤 4:降级 `dialect-mir` → `dialect-llvm`(通过 mir-lower)
步骤 5:导出 `dialect-llvm` 为文本 LLVM IR(.ll)
步骤 6:运行 llc 将 .ll 编译为 .ptx

每个 CollectedFunction 携带管道需要了解的关于一个设备函数的全部信息:

pub struct CollectedFunction {
    pub instance: Instance,
    pub is_kernel: bool,
    pub export_name: String,
}

instance 是来自 rustc_public 的单态化函数。is_kernel 将 kernel 入口点与设备辅助函数区分开(kernel 在 LLVM IR 中获得特殊元数据,以便 NVPTX 后端将它们生成 .entry 点)。export_name 是出现在最终 PTX 中的符号名称——对于设备函数,这通常是一个完全限定名称(FQDN),与 CrateDef::name() 对同一函数返回的匹配。

对于每个函数,管道:

  1. 通过 instance.body() 检索 MIR 体。

  2. 调用 translate_function() 生成包含 dialect-mir 表示的 pliron 模块(使用 mir.alloca 槽位作为局部变量)。

  3. 在模块上运行 pliron 的验证器,以早期捕获结构错误——类型不匹配、缺少操作数、支配关系破坏——在它们变成下游晦涩的 LLVM 失败之前。

  4. 运行 pliron::opts::mem2reg 将 alloca 槽位提升回 dialect-mir 中的 SSA 值。

  5. 运行 lower_mir_to_llvm(来自 mir-lower crate)以通过 DialectConversion 将每个 dialect-mir 操作降级为其 dialect-llvm 等价物。

  6. dialect-llvm 模块导出为文本 .ll 字符串,写入磁盘,并调用 llc 生成最终的 .ptx 文件。

如果有任何步骤失败,管道停止并返回一个类型化错误(NoBodyTranslationVerificationLoweringExportPtxGeneration),附带足够的上下文来诊断问题。没有静默损坏,没有神秘的空输出文件。


翻译架构#

translator/ 目录是 Stable MIR 变成 pliron IR 的地方。每个模块处理 MIR 结构的一个层次,它们组合得干净利落:

模块

用途

body

函数级翻译、alloca 槽位播种、FQDN 名称清理

block

基本块翻译协调器

statement

语句翻译(赋值、存储)

terminator

终止器翻译(goto、call、return、基于 FQDN 的内部函数分发)

rvalue

表达式翻译(二元操作、类型转换、聚合)

types

Rust 类型到 dialect-mir 类型的转换

values

MIR local → alloca 槽位映射(ValueMap)+ 槽位地址空间推断

调用流程遵循 MIR 的结构自上而下:

translate_function()
  └─ body::translate_body()
       ├─ emit_entry_allocas()            // 每个非 ZST local 一个 mir.alloca
       │     └─ SlotAddrSpaceMap::analyze // 指针槽位的地址空间推断
       └─ 对于每个基本块:
            └─ block::translate_block()
                  ├─ statement::translate_statement()
                  │     └─ rvalue::translate_rvalue()
                  └─ terminator::translate_terminator()

translate_body() 设置函数的签名,创建与 MIR 基本块对应的 pliron block,清理函数名(将来自 instance.name() 的 FQDN 中的 :: 转换为 __),在入口块顶部为每个非 ZST local 生成一个 mir.alloca,并将传入的函数参数存储到它们各自的槽位中。每个非入口块保持无参数——跨块的数据流通过 alloca 槽位承载,而不是通过 block 参数。然后它按顺序遍历每个 block,逐个翻译语句和终止器。

translate_statement() 处理 block 中的扁平操作——赋值、存储活动/死亡标记和判别式写入。当赋值涉及右侧表达式(MIR Rvalue)时,它委托给 translate_rvalue(),后者处理二元操作、一元操作、类型转换、聚合构造、判别式读取、指针算术以及 Rust 编译产生的其他十几种东西。

translate_terminator() 处理结束 block 的操作:GotoSwitchIntCallReturnAssertDropUnreachable。这里也是内部函数分发所在的地方——但它有自己专门的章节。


SSA 挑战(以及我们如何推迟它)#

这是翻译中最棘手的部分,值得仔细解释。

Rust MIR 处于严格的 SSA 形式。局部变量是命名存储位置,任何 block 都可以读写。如果 _3bb0 中被赋值,它可以在 bb1bb5 或任何其他地方自由使用——MIR 不在乎。

Pliron IR(类似 MLIR)最终期望严格的 SSA:一个值必须支配所有使用,如果一个值需要从一个 block 流向另一个 block,它必须显式地作为 block 参数传递。你不能只是跨 block 抓取局部变量。

我们分两个阶段解决这种紧张关系:

  1. Importer:alloca + load/store。 mir-importer 特意直接构造 SSA。每个非 ZST 的 MIR local 由一个栈槽位支持——在入口块顶部生成一个 mir.alloca 并记录在 ValueMap 中。对局部变量的每次写入变成存储到其槽位的 mir.store;每次读取变成 mir.load。因此分支终止器都是零操作数的,且每个非入口 block 是无参数的:所有跨 block 数据流通过 alloca 槽位传递,而不是通过 block 参数。

  2. pliron::opts::mem2reg:槽位 → SSA。dialect-mir 模块验证后,pipeline.rs 运行 pliron 内置的 mem2reg pass。它将每个有资格的 alloca 提升回 SSA 值,将每个 load 重新连接到到达定义,并在值沿多条控制路径汇合的地方插入 block 参数(pliron 中 phi 节点的写法)。地址逃逸的槽位——那些我们真正需要保留在栈上的——在 dialect-mirdialect-llvm 降级时保持不动,以便转换为真正的 alloca

这是一个微型的问题示例。给定 MIR 中 _1bb0 中被写入并在 bb1 中被读取:

// Rust MIR
bb0: { _1 = 42_i32; goto -> bb1; }
bb1: { _0 = _1;     return; }

Importer 首先生成这个基于 alloca 的 dialect-mir(在 mem2reg 之前):

^bb0:
  %s1 = mir.alloca           : !mir.ptr<i32>
  %c  = mir.constant 42_i32  : i32
  mir.store %c, %s1
  mir.goto ^bb1                     // 零操作数;_1 通过 %s1 流动
^bb1:                               // 无 block 参数
  %r = mir.load %s1 : i32
  mir.return %r : i32

pliron::opts::mem2reg 提升了 %s1 之后,同一函数变为 SSA 形式的 dialect-mir,block 参数仅在实际需要汇合到达定义的地方出现:

^bb0:
  %c = mir.constant 42_i32 : i32
  mir.goto ^bb1(%c : i32)
^bb1(%r : i32):
  mir.return %r : i32

(具有单个前驱后继的函数,如上面的例子,最终只有一个 block 参数;具有多个到达定义的汇合处,是 mem2reg 引入非平凡 phi 风格参数的地方。)

要点:importer 从不运行活跃性分析,也从不跨 block 传递值。所有"哪个值在哪里是活跃的?"的推理都推迟给 mem2reg,它已经在一个 pass 中为整个 dialect-mir 模块正确解决了这个问题。translator/terminator/mod.rs 模块文档字符串以嵌入式源代码的形式携带了相同的示例,以供快速参考。


类型翻译#

types 模块将 Rust 类型(通过 rustc_public 看到的)转换为 dialect-mir 类型。大多数映射是直接的,但有几个值得关注:

Rust 类型

dialect-mir 类型

备注

i32u64

IntegerType

带有有符号性跟踪

f32f64

Float32Type / Float64Type

标准 IEEE 754

bool

IntegerType(1)

1 位整数,传统惯例

(A, B, C)

MirTupleType

异构乘积类型

&[T]

MirSliceType

指针 + 长度

DisjointSlice<T>

MirDisjointSliceType

安全性验证的 mutable slice

struct Foo

MirStructType

带有来自 rustc layout 的字段偏移

*mut T / *const T

MirPtrType

带有 GPU 地址空间

enum Option<T>

MirEnumType

判别式 + 变体

动态结构体布局#

这里有一个能为用户省去真正麻烦的微妙之处。考虑这个结构体:

struct Extreme {
    a: u8,
    b: i128,
}

Rust 的布局算法可能会为对齐重新排序字段:

用户编写:             struct Extreme { a: u8, b: i128 }
rustc 可能的布局:     [b: i128 @ 偏移 0][a: u8 @ 偏移 16]
MirStructType:        mem_to_decl 映射、偏移量、total_size
LLVM 结构体:          { i128, i8, [15 x i8] }   // 显式填充

cuda-oxide 向 rustc 查询每个字段的确切字节偏移,并用显式的填充字节构建结构体类型。MirStructType 存储一个 mem_to_decl 映射(内存顺序到声明顺序)、每个字段的偏移量以及总大小。降级到 LLVM 时,填充被具体化为字段之间的 [N x i8] 数组。

实际效果:#[repr(C)] 不是必需的,对于宿主和设备代码之间共享的类型。cuda-oxide 自动匹配 rustc 的布局,因此你的结构体可以使用 Rust 的默认 repr(Rust) 布局,编译器会在两侧做正确的事。少一个要记住的属性,少一个要踩的地雷。


内部函数分发#

当翻译器遇到 Call 终止器时,它不会立即生成 mir.call 操作。首先,它检查被调用者是否为已知的内部函数——来自 cuda_device 的函数,直接映射到 GPU 硬件指令而非带有函数体的函数。

try_dispatch_intrinsic() 函数根据被调用者的**完全限定域名(FQDN)**进行匹配,该名称从 CrateDef::name() 获取:

match name {
    "cuda_device::thread::threadIdx_x" => emit_nvvm_intrinsic(ReadPtxSregTidXOp),
    "cuda_device::warp::shuffle_xor"   => emit_warp_shuffle_i32(ShflSyncBflyI32Op),
    "cuda_device::sync::syncthreads"   => emit_nvvm_intrinsic(Barrier0Op),
    // ... 100+ 个内部函数
    _ => translate_as_normal_call()
}

匹配使用完整的 FQDN(例如 cuda_device::thread::threadIdx_x,而不仅仅是 threadIdx_x),以避免不同模块中同名函数之间的歧义。对于非泛型、非内部函数的调用,相同的 FQDN 也用作调用目标名称——collector 生成匹配的名称,降级层在两侧将 :: 转换为 __

如果函数是识别的内部函数,翻译器直接生成相应的 dialect-nvvm 操作——没有函数体,没有调用开销,只有硬件指令。线程索引、warp shuffle、barrier、共享内存操作、TMA 批量复制和矩阵乘法指令都走这条路径。

如果函数不是内部函数,它落到正常路径:生成一个 mir.call 操作,通过符号名称引用被调用者。被调用者的体将被单独翻译(它也在收集的函数列表中),所以一切都能链接起来。

备注

内部函数分发表是向 cuda-oxide 添加新 GPU 操作的主要扩展点。如果 NVIDIA 推出了新指令并且你想暴露它,你在 cuda_device 中添加一个函数,添加一个 dialect-nvvm op,并在这里添加一个匹配分支。参见 添加新的内部函数 以获取逐步指南。


处理展开路径#

MIR 忠实地建模了 Rust 的 panic 语义。每个函数调用有两个可能的后续——一个返回目标和一个展开目标:

_2 = mul(_1, _3) -> [return: bb1, unwind: bb2]

在 CPU 上,展开路径很重要:它运行析构函数,展开栈,并要么捕获 panic 要么中止进程。在 GPU 上,CUDA 工具链今天没有暴露这种能力——nvcc/ptxas 剥离 landing pad,且没有异常处理基础设施留存到 PTX 中。硬件本身可以支持展开(Volta 之后的绝对分支 + 每线程调用栈跟踪是足够的),但编译器和运行时没有将其连接起来。NVIDIA 有一个活跃的项目,用于为汽车安全添加 C++ 异常支持;当前的 cuda-oxide 设计向前兼容该工作。

目前,cuda-oxide 将所有展开路径视为不可达。如果在运行时发生了 panic——比如说,debug 模式下的整数溢出或显式的 panic!()——GPU 会 trap,kernel 会崩溃。这在语义上等同于 panic=abort,而不需要用户设置该标志。

在实践中,翻译器简单地忽略每个 CallAssert 终止器中的展开目标,只生成返回路径分支。展开 block 从未被翻译。它们消失了,就像从未存在过一样。

这并不像听起来那么可怕。Rust 的借用检查器和类型系统阻止了大多数会引起 panic 的 bug。而对于那些漏网之鱼(数组边界检查、None 上的 unwrap),GPU trap 无论如何都是正确的行为——一个 GPU 线程在 kernel 中间无法对逻辑错误做任何有用的"恢复"。


汇总#

让我们跟踪一个简单的 kernel 通过完整管道,看看所有部分如何连接。这是一个向量加法 kernel:

#[kernel]
pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
    let idx = thread::index_1d();
    if let Some(c_elem) = c.get_mut(idx) {
        *c_elem = a[idx.get()] + b[idx.get()];
    }
}

mir-importer 将 Stable MIR 翻译为 dialect-mir(并且 pliron::opts::mem2reg 将 alloca 槽位提升回 SSA)后,结果看起来大致如下(简化版,为清晰起见省略了许多细节):

mir.func @vecadd(%a: mir.slice<f32>, %b: mir.slice<f32>,
                 %c: mir.disjoint_slice<f32>) {
^entry:
    %idx = nvvm.read_ptx_sreg_tid_x : i32
    %len = mir.extract_field %c[1]       // slice 长度
    %in_bounds = mir.lt %idx, %len
    mir.cond_br %in_bounds, ^compute, ^exit

^compute:
    %a_val = mir.load ...                // a[idx]
    %b_val = mir.load ...                // b[idx]
    %sum = mir.add %a_val, %b_val : f32
    mir.store %sum, ...                  // c[idx] = sum
    mir.goto ^exit

^exit:
    mir.return
}

注意几点:

  • thread::index_1d() 作为内部函数分发,变成了 nvvm.read_ptx_sreg_tid_x——直接的 GPU 寄存器读取,而不是函数调用。

  • DisjointSlice::get_mut() 变成了边界检查(mir.lt)和一个条件分支。Rust 中的 if let Some 模式变成了显式的控制流。

  • Block 参数在这里不存在。记住,mir-importer 本身将每个分支终止器生成为零操作数,且每个非入口 block 无参数——值通过 alloca 槽位跨越 block 边界,直到 pliron::opts::mem2reg 运行。在这个函数中,mem2reg 可以看到没有任何东西需要在 ^compute^exit 处汇合(没有值在分支中存活下来),因此它将这些槽位提升走而不引入任何 block 参数。一个有循环的 kernel,其值跨越回边存活,将在 mem2reg 之后以 mir.goto ^header(%i, %acc)^header(%i: i32, %acc: f32) 结束——这是 pliron 中 phi 节点的写法。

  • 没有展开路径。 原始 MIR 在每一个可能 panic 的操作上都有展开目标。它们消失了。

从这里开始,管道接管:

  1. 验证——pliron 检查每个操作的类型是否匹配,每个 block 的参数是否正确,以及支配关系是否成立。

  2. 降级——lower_mir_to_llvmmir.add 转换为 llvm.fadd,将 mir.load 转换为 llvm.load,将 mir.slice 转换为指针和长度的 LLVM 结构体,等等。

  3. 导出——dialect-llvm 被打印为文本 .ll 文件,带有适当的 !nvvm.annotations 元数据,将 vecadd 标记为 kernel 入口点。

  4. llc——LLVM 的 NVPTX 后端将 .ll 编译为 .ptx,结果被写入宿主二进制文件旁边。

dialect-mir 忠实地捕捉了 Rust 语义。下一步是将其降级为 LLVM 能够理解的内容——在 降级管道 中讲述。