torch_compile
torch.compile
解析
torch.compile
是PyTorch 2.2版本中的一个重要新特性。它是一种新的 PyTorch 编译器,它可以将 Python 和 TorchScript 模型编译成 TorchDynamo 图,从而提高模型的运行效率。是加速 PyTorch 代码速度的最新方法! torch.compile
通过将 PyTorch 代码 JIT(just in time)编译到优化的内核中,使 PyTorch 代码运行得更快, 同时需要最少的资源代码更改
torch.compile
原理
以一个简单模型为例子:
1 |
|
BatchNorm2d
函数有一个 affine 的步骤,假如如果是按照模型的定义来进行计算,affine 的操作是对特征图进行的,但是如果想着直接对卷积核进行 affine,把这两个网络部分合起来进行优化那么就可以提高计算效率。但是对于大型模型,模型结构非常复杂,我们不可能手动把可以优化的模块都放在一起写一份 cuda 或者 triton 代码,因此 pytorch 官方提出了 torch.compile
这个优化模型的特性
torch.compile
主要从两个方面优化模型的计算效率
- 从模型编译上直接优化,同时考虑多个模块来优化计算效率
- 根据 GPU 的型号自适应改变模型的计算过程,根据不同型号 GPU 的性能瓶颈来动态调整计算过程
torch.compile
用法
torch.compile
需要在安装 pytorch 2.0 之后方可使用,若在 GPU 上运行还需依赖安装 Triton,如若未安装,直接 pip 安装 torchtriton 即可。
torch.compile
函数旨在不改变模型的定义下进行优化,有两种优化方式,直接对函数进行优化和对模型进行优化:
- 使用
@torch.compile
对函数进行优化:
1 |
|
- 使用
torch.compile(model)
直接对nn.Module
进行优化:
1 |
|
torch.compile
相对于之前的PyTorch编译器解决方案,如TorchScript和FX Tracing,有以下几个优势:
- 更灵活的模型定义:与 TorchScript 相比,
torch.compile
允许你使用 Python 直接定义模型,而不需要将模型转换为 TorchScript 的静态图。这意味着你可以更灵活地定义模型,而不需要考虑TorchScript的限制。 - 更好的性能:TorchDynamo 图是一种优化的中间表示形式,它允许 PyTorch 编译器进行更多的优化,从而提高模型的运行效率。与 FX Tracing 相比,TorchDynamo 图可以提供更好的性能。
- 更易于调试:由于
torch.compile
允许你使用 Python 直接定义模型,因此你可以更容易地调试模型。你可以使用Python的调试工具来检查模型的输入和输出,从而更容易地找到和修复错误。
优化示例
1 |
|
- 从这个优化效果来看,
torch.compile
主要是对简单的,神经网络中最常见的模块的计算上的优化,像这个 CARAFE 稍微高级一点的操作,如果想优化内存的话还是得自己手动写 cuda 和 triton