mmcv组件6

MMCV核心组件Hook

​ MM的整个算法过程就像一个黑盒子:给定输入后(配置文件),黑盒子就会吐出算法结果。整个过程封装度非常高,几乎不需要手写什么代码,Hook机制的作用就是在算法执行过程中加入自定义操作呢

Hook就是一种一种触发器,可以在程序预定义的位置执行预定义的函数。MMCV根据算法的生命周期预定义了6个可以插入自定义函数的位点,用户可以在每个位点自由地插入任意数量的函数操作,如下图所示:

img

​ 这6个位置基本涵盖了自定义操作可能出现的位置,MMCV已经实现了部分常用Hook,其中默认Hook不需要用户自行注册,通过配置文件配置对应的参数即可;定制Hook则需要用户在配置文件中手动配置custom_hooks字段进行注册

img

Hook的注册:

​ Hook 划分为默认 Hook 和定制 Hook,之所以划分为默认 Hook 和定制 Hook,原因是默认 Hook不需要用户自行注册,用户通过 hook 名_config 配置对应参数即可

对于默认 Hook,在 MMDetection 框架训练过程中,其注册代码为:

1
2
3
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))

register_training_hooks 函数的接收参数其实是字典参数,Runner 内部会根据配置自动生成对应的 Hook 实例,典型的 lr_config 为:

1
2
3
4
5
6
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[16, 22])

由于 lr_config 没有显示的调用 Hook 类,故对于用户而言其实不知道是采用 Hook 机制实现的。

​ 但是对于定制类 Hook,其注册源码如下:

1
2
3
4
5
6
7
8
9
10
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
for hook_cfg in cfg.custom_hooks:
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
# 通过配置实例化定制 hook
hook = build_from_cfg(hook_cfg, HOOKS)
# 注册
runner.register_hook(hook, priority=priority)

​ 和其他模块不同,当我们定义好一个Hook(并注册到HOOKS注册器中)之后,还需要注册到Runner中才能使用,前后一共进行两次注册。第一次注册到HOOKS是为了程序能够根据Hook名称找到对应的模块,第二次注册到Runner中是为了程序执行到预定义位置时能够调用对应的函数

Hook注册原理:

​ 类似前面的核心组件,Hook同样使用了面向接口编程的思想,Hook类本身只提供预定义位置的接口函数,任何自定义的Hook都需要继承Hook类,然后根据需要重写对应的接口函数。比如检查点保存操作通常发生在每次迭代或epoch后,所以我们需要重写after_train_iterafter_train_epoch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Hook:
def before_run(self, runner):
pass

def after_run(self, runner):
pass

def before_epoch(self, runner):
pass

def after_epoch(self, runner):
pass

def before_iter(self, runner):
pass

def after_iter(self, runner):
pass

@HOOKS.register_module()
class CheckpointHook(Hook):
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
**kwargs):
...
def after_train_iter(self, runner):
...
def after_train_epoch(self, runner):
...

​ 和其他模块不同,当我们定义好一个Hook(并注册到HOOKS注册器中)之后,还需要注册到Runner中才能使用,前后一共进行两次注册。第一次注册到HOOKS是为了程序能够根据Hook名称找到对应的模块,第二次注册到Runner中是为了程序执行到预定义位置时能够调用对应的函数

Runner是MMCV用来管理训练过程的一个类,它内部会维护一个list类型变量self._hooks,我们需要把训练过程会调用的Hook实例对象按照优先级顺序全部添加到self._hooks中,这个过程通过Runner.register_hook()函数实现。MMCV预定义了几种优先级, 数字越小表示优先级越高, 如果觉得默认的分级方式颗粒度过大, 也可以直接传入0~100的整数进行精细划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def register_hook(self, hook, priority='NORMAL'):
"""
预定义优先级
+--------------+------------+
| Level | Value |
+==============+============+
| HIGHEST | 0 |
+--------------+------------+
| VERY_HIGH | 10 |
+--------------+------------+
| HIGH | 30 |
+--------------+------------+
| ABOVE_NORMAL | 40 |
+--------------+------------+
| NORMAL | 50 |
+--------------+------------+
| BELOW_NORMAL | 60 |
+--------------+------------+
| LOW | 70 |
+--------------+------------+
| VERY_LOW | 90 |
+--------------+------------+
| LOWEST | 100 |
+--------------+------------+
"""
hook.priority = priority
# 插入法排序将Hooks按照priority大小升序排列
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)

将Hook实例加入到self._hooks中之后,然后就可以在预定义位置调用call_hook()来调用各个Hook实例中的对应方法。call_hook()称为回调函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 开始运行时调用
self.call_hook('after_train_epoch')

while self.epoch < self._max_epochs:

# 开始 epoch 迭代前调用
self.call_hook('before_train_epoch')

for i, data_batch in enumerate(self.data_loader):
# 开始 iter 迭代前调用
self.call_hook('before_train_iter')

self.model.train_step()

# 经过一次迭代后调用
self.call_hook('after_train_iter')

# 经过一个 epoch 迭代后调用
self.call_hook('after_train_epoch')

# 运行完成前调用
self.call_hook('after_train_epoch')

调用call_hook()时会遍历self._hooks中所有Hook实例,并根据fn_name调用Hook实例的指定成员函数。比如fn_name='before_train_epoch'时,call_hook()会挨个调用所有Hook的before_train_epoch()函数。而且由于self._hooks已经按照优先级进行过排序,call_hook()会先调用优先级高的Hook方法

1
2
3
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)

Hook机制小结

Hook是一种设置在程序固定位置的触发器,当程序执行到预设位点时则会触发断点,执行Hook函数的流程,结束后再回到断点位置继续执行主流程的代码。实现一个Hook包含5个步骤:

  1. 定义一个类,继承Hook基类
  2. 根据自定义Hook的功能有选择地重写Hook基类中对应的函数
  3. 注册自定义Hook模块到HOOKS查询表中(register_module
  4. 实例化Hook模块并注册到Runner中(register_hook
  5. 使用回调函数调用重写的Hook函数(call_hook