Pytorch-hook

Pytorch 的 hook 机制:

简介:

在PyTorch中,hook机制是一种强大的工具,它允许开发者在不修改模型原始结构的情况下,插入自定义函数来访问、记录或修改模型运行过程中的中间层输出(forward pass)和梯度(backward pass)。

具体来说,在torch.nn.Module及其包含的torch.Tensor对象上可以注册hook。通过调用.register_forward_hook().register_backward_hook()方法,可以在前向传播结束时或反向传播过程中执行特定的回调函数:

  • register_forward_hook(hook_function):当模块的前向传播完成时,会调用这个hook函数,它可以访问到该模块的输入和输出张量。
  • register_backward_hook(hook_function):当模块接收到关于其输出的反向传播信号时,这个hook函数会被调用,它能够访问到模块输出的梯度以及输入的梯度。

Hook常用于以下场景:

  1. 可视化中间层特征图
  2. 监控和记录网络训练过程中的中间变量和梯度信息
  3. 对模型训练进行干预,如正则化或其他动态调整策略
  4. 实现元学习中的特殊计算需求
  5. 保存特定层的输出用于后续分析或微调

使用hook时需要注意释放资源,因为在每次迭代结束后,如果不手动保存这些中间结果,它们通常会被自动清理以节省内存。

PyTorch中的hook机制分为针对Tensor和Module(模型)的两种类型:

  1. Tensor Hook

    • torch.Tensor.register_hook(hook_function):这是一个作用在张量上的钩子函数,当该张量参与反向传播计算梯度时,会调用你注册的hook_function。这个函数接收一个参数,即该张量对应的梯度(grad),并且允许你在梯度计算完成后对其进行修改或观察,这种情况用得少,几乎只能用于修改反向传播的梯度值
  2. Module Hooks

    • torch.nn.Module.register_forward_hook(hook_function):当你需要监控或者修改模块执行前向传播过程时使用的钩子。在模块完成其forward()方法并生成输出后,hook_function会被调用,它接受三个参数:当前模块实例、模块接收到的输入以及模块计算出的输出。

    • torch.nn.Module.register_forward_pre_hook(hook_function):与register_forward_hook类似,但会在模块执行前向传播之前调用,此时输入尚未经过模块处理。

    • torch.nn.Module.register_backward_hook(hook_function):用于监控或修改模块在反向传播阶段的行为。当模块的输出梯度计算完毕后,hook_function会被调用,它接收四个参数:当前模块实例、输入梯度、输出梯度以及输出自身

    • 如果希望对整个模型或特定层的所有输入和输出张量添加hooks,可以使用模块级别的hook,例如:

      • nn.Module.register_forward_hook(hook_function)
      • nn.Module.register_backward_hook(hook_function)
      • nn.Module.register_full_backward_hook(hook_function)

      这些接口将允许在模型前向传播或反向传播的不同阶段添加自定义处理逻辑

torch.Tensor.register_hook(hook_function) 用法

函数签名:

1
tensor.grad_fn.register_hook(hook: Callable[[Tensor], None]) -> torch.utils.hooks.RemovableHandle

参数说明:

  • hook: Callable[[Tensor], None]
    • 这是一个接受一个张量作为输入的可调用对象(通常为函数)
    • 输入的张量是该操作在反向传播过程中产生的梯度,即当计算当前张量的上游梯度时,传递给hook_function的是当前张量的梯度
    • hook函数没有返回值要求,其目的是为了让你能够监控或者修改这个张量的梯度
    • 对于 grad_fn=True 的 Tensor,可以直接使用 tensor.register_hook 而不需要显式地先指向 grad_fn 属性

返回值

  • 无直接返回值,但会返回一个 torch.utils.hooks.RemovableHandle 对象。通过该对象加上 .remove() 可以随时移除之前注册的钩子函数

注意:尽量不要在动态图的中间节点的hook函数中使用某个节点的 grad_fn 的 inplace 操作,比如 c=a+b 如果在 b 的 hook 中原地改变 b 的梯度,则 a 的梯度也会一样改变,因为 Pytorch 的底层机制,反向传播过程中发现 a 和 b 的梯度相同,就直接将 b 的梯度信息复制过去了!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b

def c_hook(grad):
print(grad)
return grad + 2

c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))
c.retain_grad()

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d
e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad() # second-time retain_grad() is NOP (NOP:空操作)

e.backward()
print(f'a.grad {a.grad}')
print(f'b.grad {b.grad}')
print(f'c.grad {c.grad}')
print(f'd.grad {d.grad}')
print(f'e.grad {e.grad}')
#---------运行结果---------
tensor(8.)
tensor(10.)
a.grad 30.0
b.grad 20.0
c.grad 10.0
d.grad 112.0
e.grad 2.0

代码解释

register_forward_hook(hook_function) 用法:

函数签名

1
nn.Module.register_forward_hook(hook)

参数

  • hook:类型为 Callable[[Module, Tensor, Tensor], None] 的函数对象。这个函数需要接受三个参数,并且不返回任何值(因为返回了值也没什么用)
    • module(类型: torch.nn.Module):触发钩子的模块实例。
    • input类型: Tuple[Tensor, ...] 或单个 Tensor):传递给模块执行前向传播操作时的输入数据,如果模块接收多个输入,则是一个包含所有输入张量的元组。
    • output类型: TensorTuple[Tensor, ...]):模块执行完前向传播后生成的输出数据,根据模块的具体输出,可以是单个张量或多张量组成的元组

返回值

  • 无直接返回值,但会返回一个 torch.utils.hooks.RemovableHandle 对象。通过该对象加上 .remove() 可以随时移除之前注册的钩子函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn

class SumNet(nn.Module):
def __init__(self):
super(SumNet, self).__init__()

@staticmethod
def forward(a, b, c):
d = a + b + c
print('forward():')
print(' a:', a)
print(' b:', b)
print(' c:', c)
print(' d:', d)
return d
# forward_pre_hook在调用forward函数前自动调用
def forward_pre_hook(module, input_positional_args):
a, b = input_positional_args
new_input_positional_args = a + 10, b
print('forward_pre_hook():')
print(' module:', module)
print(' input_positional_args:', input_positional_args)
print(' new_input_positional_args:', new_input_positional_args)
return new_input_positional_args
# forward_hook在调用forward函数后自动调用
def forward_hook(module, input_positional_args, output):
new_output = output + 100
print('forward_hook():')
print(' module:', module)
print(' input_positional_args:', input_positional_args)
print(' output:', output)
print(' new_output:', new_output)
# return new_output

if __name__ == '__main__':
sum_net = SumNet()
sum_net.register_forward_pre_hook(forward_pre_hook)
sum_net.register_forward_hook(forward_hook)

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
c = torch.tensor(3.0, requires_grad=True)
print('start\n')
print('a:', a)
print('b:', b)
print('c:', c)
print('before model')
d = sum_net(a, b, c=c)

print('after model')
print('d:', d)
#--------------------------
start:
a: tensor(1., requires_grad=True)
b: tensor(2., requires_grad=True)
c: tensor(3., requires_grad=True)

before model:
forward_pre_hook():
module: SumNet()
input_positional_args: (tensor(1., requires_grad=True), tensor(2., requires_grad=True))
new_input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True))
forward():
a: tensor(11., grad_fn=<AddBackward0>)
b: tensor(2., requires_grad=True)
c: tensor(3., requires_grad=True)
d: tensor(16., grad_fn=<AddBackward0>)
forward_hook():
module: SumNet()
input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True))
output: tensor(16., grad_fn=<AddBackward0>)
new_output: tensor(116., grad_fn=<AddBackward0>)

after model:
d: tensor(116., grad_fn=<AddBackward0>)

register_backward_hook() 用法

函数签名:

1
nn.Module.register_backward_hook(hook)

参数说明:

  • hook:类型为 Callable[[Module, Tensor, Tensor], None] 的函数对象
    • module (类型: torch.nn.Module):当前注册 hook 的模块实例
    • grad_input (类型: tuple[Tensor, ...]):输入张量的梯度元组,每个元素对应模块正向传播接收到的一个输入张量的梯度
    • grad_output (类型: tuple[Tensor, ...]):输出张量的梯度元组,每个元素对应模块正向传播生成的一个输出张量的梯度

返回值

  • 无直接返回值,但会返回一个 torch.utils.hooks.RemovableHandle 对象。通过该对象加上 .remove() 可以随时移除之前注册的钩子函数

使用是同理的,通过使用 register_backward_hook(),开发者可以监视和/或修改模块在反向传播过程中的梯度信息,这对于调试、监控训练动态以及实现一些高级算法如梯度裁剪、正则化等非常有用

技巧:如果前向传播和反向传播的hook都有时,可以用一个类去封装hook使代码模块化

hook 应用实例:特征图提取特征可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from PIL import Image

vgg16_pretrained = models.vgg16(pretrained=True)
_, axes = plt.subplots(1, 4, figsize=(12, 5))


class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.encoder1 = vgg16_pretrained.features[:5]
self.encoder2 = vgg16_pretrained.features[5:10]
self.encoder3 = vgg16_pretrained.features[10:17]

def forward(self, x):
x = self.encoder1(x)
x = self.encoder2(x)
x = self.encoder3(x)
return x


def visualize(activation, cnt):
axes[cnt].imshow(activation[0, 1, :, :].detach().numpy(), cmap='gray')
axes[cnt].set_title(f'Encoder: {cnt}')


mymodel = MyModel()
mymodel.encoder1.register_forward_hook(lambda mymodel, input, output: visualize(output, 1))
mymodel.encoder2.register_forward_hook(lambda mymodel, input, output: visualize(output, 2))
mymodel.encoder3.register_forward_hook(lambda mymodel, input, output: visualize(output, 3))

input_image = Image.open("./1.jpg").convert("RGB")
transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
input_image = transforms(input_image).unsqueeze(0)
axes[0].imshow(input_image.permute(0, 2, 3, 1).squeeze(0), cmap='gray')
axes[0].set_title('input')
print(input_image.shape)
_ = mymodel(input_image)
plt.show()

visualize

参考文献:

elliotwaite/pytorch-hooks-tutorial: Examples of using PyTorch hooks, as covered in my YouTube tutorial video. (github.com)