共享内存与同步#

CUDA 线程块中的每个线程都可以访问一小块高速的暂存区,称为共享内存(shared memory)。它位于芯片上,紧邻 SM 的执行单元——比全局内存大约快 100 倍,比 L1 缓存大约快 10 倍。但代价是容量:取决于架构,通常每个 SM 为 48–228 KB,由运行在该 SM 上的所有 block 共享。

cuda-oxide 通过 SharedArrayDynamicSharedArray 暴露共享内存,两者在使用体验上设计得类似于 Rust 数组,同时编译为 PTX 地址空间 3。本章介绍如何使用它们、何时同步,以及将朴素 kernel 变为快速 kernel 的分块(tiling)模式。

参见

CUDA Programming Guide — Shared Memory 了解 bank 结构、广播规则以及各架构容量的硬件细节。


为什么共享内存很重要#

考虑异步 MLP 流水线项目中的朴素 GEMM。每个线程通过从全局内存读取 A 的整行和 B 的整列来计算一个输出元素。对于一个 64×64 的矩阵,每个线程需要 128 次全局加载,而 block 中的其他线程会加载许多相同的元素。硬件忠实地从 DRAM 中取出每一个——或者如果你运气好,从 L2 缓存中取出。

共享内存改变了经济性。一个线程块协同地将 A 的一个**分块(tile)**和 B 的一个分块加载到共享内存中,同步,然后每个线程从分块中读取。每次全局加载被所有复用该元素的线程分摊。对于 16×16 的分块,全局内存流量减少了 16 倍。

advanced/images/shared-memory-tiling.svg

分块计算模式。线程协同地将 A 和 B 的分块从全局内存加载到 SharedArray 中,同步,从快速的片上内存计算部分乘积,再次同步,然后沿 K 维度对下一个分块重复此过程。#


SharedArray —— 静态共享内存#

SharedArray<T, N, ALIGN> 是一个在编译时于共享内存中分配的固定大小数组。你在 kernel 内部将其声明为 static mut

use cuda_device::thread::Runtime2DIndex;
use cuda_device::{kernel, thread, DisjointSlice, SharedArray};

const TILE: usize = 16;

#[kernel]
pub fn tiled_sgemm(
    m: u32, n: u32, k: u32,
    a: &[f32], b: &[f32],
    mut c: DisjointSlice<f32, Runtime2DIndex>,
) {
    static mut TILE_A: SharedArray<f32, 256> = SharedArray::UNINIT;
    static mut TILE_B: SharedArray<f32, 256> = SharedArray::UNINIT;

    let n_sz = n as usize;
    let row = thread::index_2d_row();
    let col = thread::index_2d_col();
    let tx = thread::threadIdx_x() as usize;
    let ty = thread::threadIdx_y() as usize;

    let mut sum = 0.0f32;
    let mut t = 0u32;

    while t < k / TILE as u32 {
        let tile_offset = t as usize * TILE;

        // Phase 1: cooperative load
        unsafe {
            TILE_A[ty * TILE + tx] = a[row * k as usize + tile_offset + tx];
            TILE_B[ty * TILE + tx] = b[(tile_offset + ty) * n_sz + col];
        }

        // All threads must finish loading before any thread reads
        thread::sync_threads();

        // Phase 2: compute from shared memory
        let mut i = 0usize;
        while i < TILE {
            unsafe {
                sum += TILE_A[ty * TILE + i] * TILE_B[i * TILE + tx];
            }
            i += 1;
        }

        // Must sync before overwriting the tile in the next iteration
        thread::sync_threads();

        t += 1;
    }

    // SAFETY: every thread sees the same `n_sz`.
    if let Some(c_idx) = unsafe { thread::index_2d_runtime(n_sz) } {
        if let Some(c_elem) = c.get_mut(c_idx) {
            *c_elem = sum;
        }
    }
}

声明规则#

SharedArray 必须在 kernel 函数体内声明为 static mut。这告诉编译器将其分配在 PTX 共享地址空间(.shared)中。UNINIT 常量跳过初始化——在线程写入之前,内容是未定义的。

参数

含义

T

元素类型(f32u32 等)

N

元素数量(编译时固定)

ALIGN

字节对齐(默认 0 = 自然对齐;TMA 目标使用 128)

小技巧

SharedArray 设计上是 !Sync 的——它包装了 UnsafeCell,防止编译器假设跨线程的不可变性。这是正确的:共享内存在本质上可以被 block 中的所有线程修改,程序员负责同步。

API 概览#

impl<T, const N: usize, const ALIGN: usize> SharedArray<T, N, ALIGN> {
    pub const UNINIT: Self;
    pub const fn len() -> usize;
    pub fn as_ptr(&self) -> *const T;
    pub fn as_mut_ptr(&mut self) -> *mut T;
}

// 索引(通过 static mut 进行 unsafe 访问)
impl Index<usize> for SharedArray<T, N, ALIGN> { ... }
impl IndexMut<usize> for SharedArray<T, N, ALIGN> { ... }

索引在 debug 构建中进行边界检查。在 release 模式下(以及在 GPU 上),边界检查会被省略。如果越界索引,你会得到未定义行为——与 Rust 中任何 static mut 访问的规则相同。


DynamicSharedArray —— 运行时大小分配#

有时你在编译时不知道分块大小,或者你想在多个逻辑数组之间共享同一个共享内存池。DynamicSharedArray 从动态共享内存区域分配,其大小在 launch 时通过 LaunchConfig::shared_mem_bytes 设置:

use cuda_device::{kernel, thread, DynamicSharedArray};

#[kernel]
pub fn reduce_dynamic(input: &[f32], n: u32, mut output: DisjointSlice<f32>) {
    let tid = thread::threadIdx_x() as usize;

    // 获取指向动态共享内存区域的指针
    let smem: *mut f32 = DynamicSharedArray::<f32>::get();

    unsafe {
        // 从全局内存加载
        let idx = thread::index_1d();
        *smem.add(tid) = if idx.get() < n as usize {
            input[idx.get()]
        } else {
            0.0
        };
    }

    thread::sync_threads();

    // 树形规约
    let mut stride = thread::blockDim_x() as usize / 2;
    while stride > 0 {
        if tid < stride {
            unsafe {
                *smem.add(tid) += *smem.add(tid + stride);
            }
        }
        thread::sync_threads();
        stride /= 2;
    }

    if tid == 0 {
        let block_idx = thread::blockIdx_x() as usize;
        unsafe {
            *output.get_unchecked_mut(block_idx) = *smem;
        }
    }
}

Launch 配置:

let config = LaunchConfig {
    grid_dim: ((n + 255) / 256, 1, 1),
    block_dim: (256, 1, 1),
    shared_mem_bytes: 256 * std::mem::size_of::<f32>() as u32,
};

分区动态共享内存#

如果需要从同一个动态池中分配多个数组,使用 offset()

let pool_a: *mut f32 = DynamicSharedArray::<f32>::get();
let pool_b: *mut f32 = DynamicSharedArray::<f32>::offset(
    256 * std::mem::size_of::<f32>()
);

offset 接受一个从池起始位置的字节偏移量。确保总量不超过 shared_mem_bytes——没有运行时保护。

对齐#

DynamicSharedArray 默认为 16 字节对齐。对于 TMA 操作(Hopper+),使用 DynamicSharedArray<f32, 128> 以获得所需的 128 字节对齐。对齐信息编码在 PTX .align 指令中。


同步:sync_threads()#

thread::sync_threads() 是一个block 级屏障。它编译为 PTX bar.sync 0,并保证两件事:

  1. block 中的所有线程在任何一个线程继续执行之前都已到达屏障。

  2. 这些线程的所有内存写入在屏障之后对所有线程可见(它充当共享内存的内存屏障)。

如果在加载和计算阶段之间没有使用 sync_threads(),某些线程可能会在另一个线程写入共享内存位置之前读取它。硬件不保证 warp 内共享内存存储的任何顺序——即使同一个 warp 中的线程,在没有屏障的情况下也可能看到过时的值。

何时同步#

规则很简单:在读取另一个线程写入的数据之前进行同步

情况

需要同步?

线程 A 写 TILE[i],线程 B 读 TILE[i]

线程 A 写 TILE[i],线程 A 读 TILE[i]

否(同一线程)

为下一个循环迭代覆盖分块

是(在新 load 覆盖其他线程可能仍在读取的数据之前)

读取 DisjointSlice(每个线程读取自己的索引)

上述分块 GEMM 每次迭代有两个同步点:一个在加载之后(计算之前),一个在计算之后(下一次 load 覆盖分块之前)。缺少任何一个都是数据竞争。

小技巧

一个常见错误是将 sync_threads() 放在并非所有线程都会进入的条件分支中。block 中的每个线程都必须到达同一个 sync_threads() 调用,否则 kernel 会死锁。如果你需要发散的控制流,请重新组织代码,使屏障位于分支之外。


共享内存 vs. 其他方案#

方案

延迟

容量

程序员工作量

全局内存(朴素)

~500 cycles

GB 级

L1/L2 缓存(隐式)

~30–100 cycles

128 KB–40 MB

共享内存(显式)

~5 cycles

每 SM 48–228 KB

分块、同步、bank 感知

寄存器

~1 cycle

每 SM 64K×32 位

编译器管理

共享内存是程序员在缓存不够用时的工具。L1 和 L2 缓存可以自动提供帮助,但它们受制于访问模式和驱逐策略。共享内存给予你显式控制:你决定加载什么、何时加载、以及保留多久。


Bank 冲突#

共享内存被划分为 32 个 bank(每个 warp lane 一个)。如果同一个 warp 中的两个线程访问映射到同一个 bank 的不同地址,这些访问会被串行化——这就是 bank 冲突。2 路冲突的惩罚是 2× 延迟,最坏情况下 32 路冲突可达 32× 延迟。

映射规则很直观:连续的 32 位字映射到连续的 bank。因此 TILE[0] 在 bank 0,TILE[1] 在 bank 1,……,TILE[32] 回到 bank 0。一种常见的无冲突访问模式是:

// 每个线程读取不同的列:线程 k 读取 TILE[row + k]
// 如果 TILE_WIDTH = 32(或者是 32 的倍数),添加填充:SharedArray<f32, 33 * 16>

对于 16×16 的分块 GEMM,内层循环中没有 bank 冲突,因为 TILE_A[ty * 16 + i] 按行读取(连续元素 = 不同的 bank),而 TILE_B[i * 16 + tx] 以步长 16 按列读取。

参见

CUDA Programming Guide — Shared Memory Bank Conflicts 详细了解 bank 冲突规则和填充策略。


总结#

以下是从朴素到分块的演进过程,应用于我们 MLP 流水线中的 GEMM:

版本

每线程全局加载次数(64×64)

每线程共享加载次数

加速比

朴素(sgemm_naive

128

0

分块(16×16 分块)

8(4 次迭代 × 2 个分块)

128

~4–10×

分块 + 双缓冲

8

128(重叠执行)

~6–15×

分块版本用共享内存带宽和计算换取全局内存带宽——当 kernel 受内存带宽限制时,这是一个很好的交换。双缓冲变体将下一个分块的加载与当前分块的计算重叠,隐藏了更多延迟,但需要每个矩阵两个 SharedArray 以及更复杂的同步。

参见

  • Warp 级编程 —— 基于 shuffle 的规约,作为小规模规约时共享内存的替代方案

  • Tensor Memory Accelerator —— 硬件加速的全局→共享拷贝,替代手动加载循环