triton-MM

triton Matrix Multiplication

triton入门:Matrix Multiplication

​ 直接从 Triton 的官方教程入手。对 MM 做优化, 不论用的是 Triton, 还是 CUDA 目标都是一样的,一句话概括:计算一般是没办法省的, 主要是优化内存的使用, 尽可能的用高速但是比较小的 shared memory

  • 驱动程序(Driver Program):这是在 CPU 上运行的 Python 代码,用于准备数据、配置内核参数,并调用 Triton 内核
  • 算子(Operator)本身:这是用 Triton 编写的 GPU 内核,实际执行高性能计算

核函数定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
):

  • triton.jit 基本上是告诉 Triton 编译器这个函数需要被即时编译(Just-In-Time, JIT)。这意味着当这个函数第一次被调用时,Triton 会将其编译成针对特定硬件优化的机器码,从而提高运行效率

  • a_ptr b_ptr c_ptr:分别指向输入矩阵 A、B 和输出矩阵 C 的指针。这些矩阵存储在 GPU 的全局内存(global memory)中

  • Triton 里面一般都是通过 pointer 这种方式来表示变量。为什么?Triton 本质上是要追求效率的, 而效率中很重要的一环就是 memory IO, 也就是 data load. 所以 Triton 希望用户能够直接操控数据的载入过程,来写出来更加高效的 kernel。所以,在 Triton 中,每个数字 scalar 都需要通过他的 pointer,手动 load 到 memory 中来

  • 输入中还有各种 stride,这些变量表示指针在不同维度上移动一个元素时应增加的数量。我们可以直接从 PyTorch 来看

1
2
3
a = torch.rand([3,6])
a.stride()
# (6, 1)

​ 这里的第一个维度的 stride 是 6,因为从 a[m, k] 的地址 到 a[m+1, k] 的地址,中间差了 6 个元素 (具体差了多少 byte 取决于数据类型)。第二个维度的 stride 是 1, 因为从 a[m, k] 的地址 到 a[m, k+1] 的地址, 中间差了 1 个元素

stride 的作用就是为了更加方便的找到每个元素的 pointer (地址)

  • tl.constexpr 标记表明这些参数是编译时常量。这意味着它们的值在编译时就已经确定,并且不会在运行时改变。这可以帮助编译器进行更多的优化。可以理解成这个 kernel 的 hyper-parameters,这个超参数后面可以由 Triton compiler 来进行搜索不同的值进行效率优化

虚拟的"循环"

1
pid = tl.program_id(axis=0)

program_id 是另外一个非常重要的概念。我们写的 Kernel, 比如这里的 matmul_kernel 其实要被重复执行很多次,每次执行处理输入 (比如这里的 a_ptr b_ptr c_ptr 这三个 tensor) 的一部分, 直到所有的部分都处理完

​ 但是,由于 triton 为向量化并行编程,这里的并没有以上过程的 for 循环。但是我们在编程的时候, 心里要有这么一个概念,我们姑且称之为 “循环”。这里的 program_id 就是这个虚拟的 for “循环” 里面的 index (第几次循环)

​ 这里面的 axis=0。是说这个 “循环”, 有几层,这里因为只有 axis=0,所以只有一层. 如果还有 axis=1,那说明还有嵌套的第二层。另外,我们在调用这个 kernel 的时候, 也是需要说明这个循环有多少层,这个就是 grid 的概念

 这里 grid 的概念和 cuda 编程中的 grid 概念类似,grid 对应 cuda 编程中的一个核函数(对应一个 grid),grid 的数量就是线程块的数量,triton 将 block 以下级别的运算都给隐藏起来了,因此使用 triton 编程的时候只需要考虑 block 的资源分配就行了

1
2
3
4
5
6
7
8
9
10
11
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
ACTIVATION=activation
)
  • lambda 函数:一个匿名函数,接受一个字典 META 作为参数。
  • triton.cdiv:Triton 提供的一个函数,用于计算向上取整的除法结果。triton.cdiv(a, b) 等于 ceil(a / b)
  • 返回值:lambda 函数返回一个元组,其中包含一个值(值的个数要和 tl.program_idaxis 参数匹配,),该值表示网格(grid)的大小。网格大小决定了将启动多少个线程块(block)来执行内核
  • matmul_kernel[grid]:这里使用了 Triton 的语法糖,matmul_kernel[grid] 表示使用 grid 函数来计算网格大小,并启动 matmul_kernel 内核

Memory id(本篇文章最重要的部分)

​ 我们要算的是矩阵乘法 A×B=CA\times B =C,A 的大小是 M x K,B 是 K x N,C 的大小是 M x N, kernel function 的要义不是要一把算完,而是每次算出 C 的一部分,我们假设每次"循环"算的大小是 BLOCK_SIZE_M x BLOCK_SIZE_N,那么总共要 “循环” MBLOCK_SIZE_M×NBLOCK_SIZE_N\frac{M}{BLOCK\_SIZE\_M} \times \frac{N}{BLOCK\_SIZE\_N}

1
2
3
4
5
6
7
8
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

​ 这个地方是 kernel 级别的编程特有的。我们虽然确定了总共要 “循环” MBLOCK_SIZE_M×NBLOCK_SIZE_N\frac{M}{BLOCK\_SIZE\_M} \times \frac{N}{BLOCK\_SIZE\_N} 次,但是按照什么顺序来循环其实是有很多不同选择的, 而不同的选择其实导致的 efficiency 也是不一样的,我们下面展示两种矩阵乘法的示意图:

visualize

​ 上图中, 我们主要看矩阵 C. 每一个黄色小块的大小是 BLOCK_SIZE_M x BLOCK_SIZE_N。也就是一个 block。图中展示了两种 “循环” 的方式:

  • 第一种叫做 row-major ordering,就是最直接的按照行的模型一个 block 一个 block 地往下进行. 这种方式是最自然的

  • 另外一种叫做 grouped ordering,示意图中是先做9个block, 形成一个大点的 super-block, 然后再进行到下一个 super-block,一个 super-block 一个 super-block 进行计算

​ 为什么要使用第二种看起来没那么直接的方式?答案是 A, B 矩阵中,同样是算 9 个 C 的 block, 第二种方式中,A B 中需要用到的 block (黄色小块) 数量远远小于第一种方式,也就是说,数据的 reuse 更好,这样可以提高前面提到的 cache hit rate

​ 现在再回头看看上面代码, 相对好理解了很多, 这么多代码其实就是为了一件事情, 当我们从 0 到 80 逐渐增加 pid 的时候, 我们希望处理黄色小块的顺序是按照第二种顺序 (图中数字标号) 来处理的.

再换句话说, 我们希望:

1
2
for pid in range(81):
pid --> (pid_m, pid_n) # 这里的 (pid_m, pid_n) 是黄色小块的坐标, 我们希望这个坐标按照第二种方式行进

这段代码本质上就是干这件事情的

1
2
3
4
5
6
7
8
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

group

  • num_pid_mnum_pid_n 分别得到矩阵长宽两个维度各自分成多少黄色小块,这俩乘起来就是总的 pid (“循环” 总共进行多少次)

  • num_pid_in_group 是上图中红色框 (高是 GROUP_SIZE_M,宽是 num_pid_n) 里总共有多少黄色小块,这么一个红色框叫做一个 group

  • group_id 当前的这次 “循环”, 是在第几个红色框里

  • first_pid_m 当前所在的 group 中的第一个黄色小块,它在全局中是第几个黄色小块 (在 m 维度上)

  • group_size_m 这里重复算一遍 group_size_m 是因为最后一个 group 可能占不满,会小一点

最后我们

1
2
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

​ 得到我们当前这步循环到底要处理哪个 ([pid_m, pid_n]) 黄色小块,第一行保证 pid_m 一定是小于 first_pid_m + group_size_m 的。第二行保证 pid_n 一定是从左到右一列一列来的, 也就是 n 这个维度是 0, 0, 0, 1, 1, 1, 2, 2, 2, ... 这样来的

pid

​ 但是,如果是上面提到的第一种 row-major ordering 的方式(这种方式简单的多, 但是不高效),那么 pid_mpid_n 的计算方式是:

1
2
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n

raw_major

具体到"循环"某处

​ 前面讲的,是 kernel function 不同 call 之间的关系(pid 和网格序号之间的关系),实际上还没有具体到当前的这次 call。kernel function 的每次 call,对应于一个 pid,前面讲的其实就是找到当前这个 pid 对应 C 中的哪个 block (黄色小块)。现在我们具体到 block 内部. 假设我们要计算 C 中的第一个 block,block-0

计算这个 block 需要的 A 和 B 矩阵中的 9 个 block:

add

​ 先从 A 中取 9 个 block 中的第一个 block,从 B 中取 9 个 block 中的第一个 block,然后二者相乘,加到一个 accumulator 中,直到 9 个 block 做完, 就得到了 C 中的第一个 block

​ 所以,在每个 pid 内部也是需要循环的,这个循环是在 9 个 block 上面做的。但是,为了让这个循环开始, 我们要先 load A 和 B 中 9 个 block 中的第一个 block

​ 一个 block 中, 有多个 elements,接下来要找到每个 element 的 pointer

1
2
3
4
5
6
7
8
9
10
11
12
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetics` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

​ 以 A 的第一个 block 为例,a_ptr 是整个 A 矩阵第一个元素的地址,offs_amoffs_bn 是 A 矩阵 9 个 block 中第一个 block 中,每个元素在整个 A 矩阵中的坐标 (也就是 m 维度的 index 和 k 维度的 index),所以,它们其实是一个 list,这里用的是 tl.arange, 其实和 numpy.arange 差不多,Triton 之所以自己搞了 tl.arange 为了后面 compile 的时候比较方便

​ 有了 m 维度 和 k 维度的 index,就可以让它们各自和 m 维度 和 k 维度的 stride 相乘,然后和 a_ptr 相加,就可以得到 A 矩阵 9 个 block 中第一个 block 中每个 element 的地址

1
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)

乘加计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

​ 这里面有个知识点是 maskmask 的作用是,很多时候 K 可能不能被 BLOCK_SIZE_K 整除,到每一行最后一个 block 的时候,实际大小是不足 BLOCK_SIZE_K 的,所以我们在 load 时候,需要把这块考虑进去

  • tl.load 函数用于从 global memory 中加载数据进入 shared memory

OPs Fusion

1
2
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)

在外部, 你可以定义 leaky_relu

1
2
3
4
5
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
x = x + 1
return tl.where(x >= 0, x, 0.01 * x)

​ 自己写 kernel,所以可以把一些操作 fuse 起来。Fuse 的好处在哪?在于你可以 load 一次数据。在这个数据上面进行多种计算。这样就把本来需要多次 load 的时间省下来了

完整定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
c = accumulator.to(tl.float16)

# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)


def matmul(a, b, activation=""):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
ACTIVATION=activation #
)
return c

来源:Matrix Multiplication — Triton documentation (triton-lang.org)

进一步思考

​ 上面的代码中进行了分块和分组的操作,分块(block)和分组(group)的设计是为了优化并行性和内存访问效率,使用了多级的 cache 提高效率:

  • 分块的目的(对应 cuda 编程中的 block):为了高效使用 L1 缓存。将大矩阵分解成较小的块,每个块由一个线程块(block)处理。这样可以减少内存带宽的需求,并提高缓存利用率。每个线程块内部的线程可以并行执行
  • 分组的目的(并不直接对应 cuda 编程中某一硬件层级,这个概念只是高效利用 L2 cache):为了高效使用 L2 缓存。将多个线程块组织成组,以促进 L2 缓存的重用。通过分组,可以确保同一组内的线程块在处理相同的数据时能够有效地重用 L2 缓存中的数据,从而减少内存访问延迟。同一组内的块可以并行执行,不同组之间的块也可以并行执行,分组使用 L2 cache 在代码中体现为:
1
2
3
4
5
6
7
8
9
10
GROUP_SIZE_M = 8
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

​ 注意:数据的加载和存储都是通过 tl.loadtl.store 指令完成的,而 L2 缓存的管理是由硬件自动处理的

避免竞争冒险

​ 可以思考这样一个问题:for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):为什么这个计算不并行呢?

​ 这个循环是我们求每一块内的数据结果的时候,对 K 维度进行分块,我们最终的结果是要遍历整个 K 维度累加数据,才能把最终结果给正确算出来的,在累加的过程中,多个线程块之间容易出现竞争冒险的问题,因此使用顺序执行而非并行执行