共享内存与同步
CUDA 线程块中的每个线程都可以访问一小块高速的暂存区,称为共享内存(shared memory) 。它位于芯片上,紧邻 SM 的执行单元——比全局内存大约快 100 倍,比 L1 缓存大约快 10 倍。但代价是容量:取决于架构,通常每个 SM 为 48–228 KB,由运行在该 SM 上的所有 block 共享。
cuda-oxide 通过 SharedArray 和 DynamicSharedArray 暴露共享内存,两者在使用体验上设计得类似于 Rust 数组,同时编译为 PTX 地址空间 3。本章介绍如何使用它们、何时同步,以及将朴素 kernel 变为快速 kernel 的分块(tiling)模式。
为什么共享内存很重要
考虑异步 MLP 流水线 项目中的朴素 GEMM。每个线程通过从全局内存读取 A 的整行和 B 的整列来计算一个输出元素。对于一个 64×64 的矩阵,每个线程需要 128 次全局加载,而 block 中的其他线程会加载许多相同的元素。硬件忠实地从 DRAM 中取出每一个——或者如果你运气好,从 L2 缓存中取出。
共享内存改变了经济性。一个线程块协同地将 A 的一个**分块(tile)**和 B 的一个分块加载到共享内存中,同步,然后每个线程从分块中读取。每次全局加载被所有复用该元素的线程分摊。对于 16×16 的分块,全局内存流量减少了 16 倍。
分块计算模式。线程协同地将 A 和 B 的分块从全局内存加载到 SharedArray 中,同步,从快速的片上内存计算部分乘积,再次同步,然后沿 K 维度对下一个分块重复此过程。
SharedArray —— 静态共享内存
SharedArray<T, N, ALIGN> 是一个在编译时于共享内存中分配的固定大小数组。你在 kernel 内部将其声明为 static mut :
use cuda_device :: thread :: Runtime2DIndex ;
use cuda_device ::{ kernel , thread , DisjointSlice , SharedArray };
const TILE : usize = 16 ;
#[kernel]
pub fn tiled_sgemm (
m : u32 , n : u32 , k : u32 ,
a : & [ f32 ], b : & [ f32 ],
mut c : DisjointSlice < f32 , Runtime2DIndex > ,
) {
static mut TILE_A : SharedArray < f32 , 256 > = SharedArray :: UNINIT ;
static mut TILE_B : SharedArray < f32 , 256 > = SharedArray :: UNINIT ;
let n_sz = n as usize ;
let row = thread :: index_2d_row ();
let col = thread :: index_2d_col ();
let tx = thread :: threadIdx_x () as usize ;
let ty = thread :: threadIdx_y () as usize ;
let mut sum = 0.0 f32 ;
let mut t = 0 u32 ;
while t < k / TILE as u32 {
let tile_offset = t as usize * TILE ;
// Phase 1: cooperative load
unsafe {
TILE_A [ ty * TILE + tx ] = a [ row * k as usize + tile_offset + tx ];
TILE_B [ ty * TILE + tx ] = b [( tile_offset + ty ) * n_sz + col ];
}
// All threads must finish loading before any thread reads
thread :: sync_threads ();
// Phase 2: compute from shared memory
let mut i = 0 usize ;
while i < TILE {
unsafe {
sum += TILE_A [ ty * TILE + i ] * TILE_B [ i * TILE + tx ];
}
i += 1 ;
}
// Must sync before overwriting the tile in the next iteration
thread :: sync_threads ();
t += 1 ;
}
// SAFETY: every thread sees the same `n_sz`.
if let Some ( c_idx ) = unsafe { thread :: index_2d_runtime ( n_sz ) } {
if let Some ( c_elem ) = c . get_mut ( c_idx ) {
* c_elem = sum ;
}
}
}
声明规则
SharedArray 必须在 kernel 函数体内声明为 static mut 。这告诉编译器将其分配在 PTX 共享地址空间(.shared )中。UNINIT 常量跳过初始化——在线程写入之前,内容是未定义的。
小技巧
SharedArray 设计上是 !Sync 的——它包装了 UnsafeCell ,防止编译器假设跨线程的不可变性。这是正确的:共享内存在本质上可以被 block 中的所有线程修改,程序员负责同步。
API 概览
impl < T , const N : usize , const ALIGN : usize > SharedArray < T , N , ALIGN > {
pub const UNINIT : Self ;
pub const fn len () -> usize ;
pub fn as_ptr ( & self ) -> * const T ;
pub fn as_mut_ptr ( & mut self ) -> * mut T ;
}
// 索引(通过 static mut 进行 unsafe 访问)
impl Index < usize > for SharedArray < T , N , ALIGN > { .. . }
impl IndexMut < usize > for SharedArray < T , N , ALIGN > { .. . }
索引在 debug 构建中进行边界检查。在 release 模式下(以及在 GPU 上),边界检查会被省略。如果越界索引,你会得到未定义行为——与 Rust 中任何 static mut 访问的规则相同。
DynamicSharedArray —— 运行时大小分配
有时你在编译时不知道分块大小,或者你想在多个逻辑数组之间共享同一个共享内存池。DynamicSharedArray 从动态共享内存区域分配,其大小在 launch 时通过 LaunchConfig::shared_mem_bytes 设置:
use cuda_device ::{ kernel , thread , DynamicSharedArray };
#[kernel]
pub fn reduce_dynamic ( input : & [ f32 ], n : u32 , mut output : DisjointSlice < f32 > ) {
let tid = thread :: threadIdx_x () as usize ;
// 获取指向动态共享内存区域的指针
let smem : * mut f32 = DynamicSharedArray :: < f32 > :: get ();
unsafe {
// 从全局内存加载
let idx = thread :: index_1d ();
* smem . add ( tid ) = if idx . get () < n as usize {
input [ idx . get ()]
} else {
0.0
};
}
thread :: sync_threads ();
// 树形规约
let mut stride = thread :: blockDim_x () as usize / 2 ;
while stride > 0 {
if tid < stride {
unsafe {
* smem . add ( tid ) += * smem . add ( tid + stride );
}
}
thread :: sync_threads ();
stride /= 2 ;
}
if tid == 0 {
let block_idx = thread :: blockIdx_x () as usize ;
unsafe {
* output . get_unchecked_mut ( block_idx ) = * smem ;
}
}
}
Launch 配置:
let config = LaunchConfig {
grid_dim : (( n + 255 ) / 256 , 1 , 1 ),
block_dim : ( 256 , 1 , 1 ),
shared_mem_bytes : 256 * std :: mem :: size_of :: < f32 > () as u32 ,
};
分区动态共享内存
如果需要从同一个动态池中分配多个数组,使用 offset() :
let pool_a : * mut f32 = DynamicSharedArray :: < f32 > :: get ();
let pool_b : * mut f32 = DynamicSharedArray :: < f32 > :: offset (
256 * std :: mem :: size_of :: < f32 > ()
);
offset 接受一个从池起始位置的字节偏移量 。确保总量不超过 shared_mem_bytes ——没有运行时保护。
对齐
DynamicSharedArray 默认为 16 字节对齐。对于 TMA 操作(Hopper+),使用 DynamicSharedArray<f32, 128> 以获得所需的 128 字节对齐。对齐信息编码在 PTX .align 指令中。
同步:sync_threads()
thread::sync_threads() 是一个block 级屏障 。它编译为 PTX bar.sync 0 ,并保证两件事:
block 中的所有线程 在任何一个线程继续执行之前都已到达屏障。
这些线程的所有内存写入 在屏障之后对所有线程可见(它充当共享内存的内存屏障)。
如果在加载和计算阶段之间没有使用 sync_threads() ,某些线程可能会在另一个线程写入共享内存位置之前读取它。硬件不保证 warp 内共享内存存储的任何顺序——即使同一个 warp 中的线程,在没有屏障的情况下也可能看到过时的值。
何时同步
规则很简单:在读取另一个线程写入的数据之前进行同步 。
上述分块 GEMM 每次迭代有两个同步点:一个在加载之后(计算之前),一个在计算之后(下一次 load 覆盖分块之前)。缺少任何一个都是数据竞争。
小技巧
一个常见错误是将 sync_threads() 放在并非所有线程都会进入的条件分支中。block 中的每个线程都必须到达同一个 sync_threads() 调用,否则 kernel 会死锁。如果你需要发散的控制流,请重新组织代码,使屏障位于分支之外。
共享内存 vs. 其他方案
共享内存是程序员在缓存不够用时的工具。L1 和 L2 缓存可以自动提供帮助,但它们受制于访问模式和驱逐策略。共享内存给予你显式控制:你决定加载什么、何时加载、以及保留多久。
Bank 冲突
共享内存被划分为 32 个 bank (每个 warp lane 一个)。如果同一个 warp 中的两个线程访问映射到同一个 bank 的不同地址,这些访问会被串行化——这就是 bank 冲突 。2 路冲突的惩罚是 2× 延迟,最坏情况下 32 路冲突可达 32× 延迟。
映射规则很直观:连续的 32 位字映射到连续的 bank。因此 TILE[0] 在 bank 0,TILE[1] 在 bank 1,……,TILE[32] 回到 bank 0。一种常见的无冲突访问模式是:
// 每个线程读取不同的列:线程 k 读取 TILE[row + k]
// 如果 TILE_WIDTH = 32(或者是 32 的倍数),添加填充:SharedArray<f32, 33 * 16>
对于 16×16 的分块 GEMM,内层循环中没有 bank 冲突,因为 TILE_A[ty * 16 + i] 按行读取(连续元素 = 不同的 bank),而 TILE_B[i * 16 + tx] 以步长 16 按列读取。
总结
以下是从朴素到分块的演进过程,应用于我们 MLP 流水线中的 GEMM:
分块版本用共享内存带宽和计算换取全局内存带宽——当 kernel 受内存带宽限制时,这是一个很好的交换。双缓冲变体将下一个分块的加载与当前分块的计算重叠,隐藏了更多延迟,但需要每个矩阵两个 SharedArray 以及更复杂的同步。