torch_compile

torch.compile解析

torch.compile是PyTorch 2.2版本中的一个重要新特性。它是一种新的 PyTorch 编译器,它可以将 Python 和 TorchScript 模型编译成 TorchDynamo 图,从而提高模型的运行效率。是加速 PyTorch 代码速度的最新方法! torch.compile 通过将 PyTorch 代码 JIT(just in time)编译到优化的内核中,使 PyTorch 代码运行得更快, 同时需要最少的资源代码更改

torch.compile原理

​ 以一个简单模型为例子:

1
2
3
4
model = nn.Sequential(
nn.Conv2d(c_in, c_out, kernel_size),
nn.BatchNorm2d(c_out)
)

BatchNorm2d 函数有一个 affine 的步骤,假如如果是按照模型的定义来进行计算,affine 的操作是对特征图进行的,但是如果想着直接对卷积核进行 affine,把这两个网络部分合起来进行优化那么就可以提高计算效率。但是对于大型模型,模型结构非常复杂,我们不可能手动把可以优化的模块都放在一起写一份 cuda 或者 triton 代码,因此 pytorch 官方提出了 torch.compile 这个优化模型的特性

torch.compile 主要从两个方面优化模型的计算效率

  • 从模型编译上直接优化,同时考虑多个模块来优化计算效率
  • 根据 GPU 的型号自适应改变模型的计算过程,根据不同型号 GPU 的性能瓶颈来动态调整计算过程

torch.compile用法

torch.compile 需要在安装 pytorch 2.0 之后方可使用,若在 GPU 上运行还需依赖安装 Triton,如若未安装,直接 pip 安装 torchtriton 即可。

torch.compile 函数旨在不改变模型的定义下进行优化,有两种优化方式,直接对函数进行优化和对模型进行优化:

  • 使用 @torch.compile 对函数进行优化:
1
2
3
4
5
@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(x)
return a + b
  • 使用 torch.compile(model) 直接对 nn.Module 进行优化:
1
2
3
4
5
6
7
8
9
10
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)

def forward(self, x):
return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)

torch.compile 相对于之前的PyTorch编译器解决方案,如TorchScript和FX Tracing,有以下几个优势:

  1. 更灵活的模型定义:与 TorchScript 相比,torch.compile 允许你使用 Python 直接定义模型,而不需要将模型转换为 TorchScript 的静态图。这意味着你可以更灵活地定义模型,而不需要考虑TorchScript的限制。
  2. 更好的性能:TorchDynamo 图是一种优化的中间表示形式,它允许 PyTorch 编译器进行更多的优化,从而提高模型的运行效率。与 FX Tracing 相比,TorchDynamo 图可以提供更好的性能。
  3. 更易于调试:由于 torch.compile 允许你使用 Python 直接定义模型,因此你可以更容易地调试模型。你可以使用Python的调试工具来检查模型的输入和输出,从而更容易地找到和修复错误。

优化示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

@torch.compile
def carafe(x, kernel, window_size):
b, c, h, w = x.size()
windows = F.unfold(x, window_size, padding=window_size//2, stride=1)
windows = windows.view(b, c, window_size**2, h, w)
return torch.einsum('b c k h w, b k h w -> b c h w', windows, kernel)

if __name__ == '__main__':

x = torch.randn(8, 256, 64, 128).to('cuda')
kernel = torch.randn(8, 25, 64, 128).to('cuda')
window_size = 5

total_time = 0
total_memory = 0

for i in range(100):
start_time = time.time()
y = carafe(x, kernel, window_size)
torch.cuda.synchronize() # 等待所有CUDA核心完成
end_time = time.time()

total_time += (end_time - start_time)
total_memory += torch.cuda.memory_allocated()

avg_time = total_time / 100
avg_memory = total_memory / 100

print(f'Average Time: {avg_time:.6f} seconds')
print(f'Average Memory: {avg_memory / (1024**2):.2f} MB')
"""
without torch.compile:
Average Time: 0.044453 seconds
Average Memory: 142.38 MB

with torch.compile:
Average Time: 0.038044 seconds
Average Memory: 142.38 MB
"""
  • 从这个优化效果来看,torch.compile 主要是对简单的,神经网络中最常见的模块的计算上的优化,像这个 CARAFE 稍微高级一点的操作,如果想优化内存的话还是得自己手动写 cuda 和 triton