MMCV核心组件Hook
MM的整个算法过程就像一个黑盒子:给定输入后(配置文件),黑盒子就会吐出算法结果。整个过程封装度非常高,几乎不需要手写什么代码,Hook机制的作用就是在算法执行过程中加入自定义操作呢
Hook就是一种一种触发器,可以在程序预定义的位置执行预定义的函数。MMCV根据算法的生命周期预定义了6个可以插入自定义函数的位点,用户可以在每个位点自由地插入任意数量的函数操作,如下图所示:
这6个位置基本涵盖了自定义操作可能出现的位置,MMCV已经实现了部分常用Hook,其中默认Hook不需要用户自行注册,通过配置文件配置对应的参数即可;定制Hook则需要用户在配置文件中手动配置custom_hooks
字段进行注册
Hook的注册:
Hook 划分为默认 Hook 和定制 Hook,之所以划分为默认 Hook 和定制 Hook,原因是默认 Hook不需要用户自行注册,用户通过 hook 名_config
配置对应参数即可
对于默认 Hook,在 MMDetection 框架训练过程中,其注册代码为:
1 2 3
| runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config, cfg.get('momentum_config', None))
|
register_training_hooks
函数的接收参数其实是字典参数,Runner 内部会根据配置自动生成对应的 Hook 实例,典型的 lr_config
为:
1 2 3 4 5 6
| lr_config = dict( policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 22])
|
由于 lr_config
没有显示的调用 Hook 类,故对于用户而言其实不知道是采用 Hook 机制实现的。
但是对于定制类 Hook,其注册源码如下:
1 2 3 4 5 6 7 8 9 10
| if cfg.get('custom_hooks', None): custom_hooks = cfg.custom_hooks for hook_cfg in cfg.custom_hooks: hook_cfg = hook_cfg.copy() priority = hook_cfg.pop('priority', 'NORMAL') hook = build_from_cfg(hook_cfg, HOOKS) runner.register_hook(hook, priority=priority)
|
和其他模块不同,当我们定义好一个Hook(并注册到HOOKS
注册器中)之后,还需要注册到Runner中才能使用,前后一共进行两次注册。第一次注册到HOOKS
是为了程序能够根据Hook名称找到对应的模块,第二次注册到Runner中是为了程序执行到预定义位置时能够调用对应的函数
Hook注册原理:
类似前面的核心组件,Hook同样使用了面向接口编程的思想,Hook
类本身只提供预定义位置的接口函数,任何自定义的Hook都需要继承Hook
类,然后根据需要重写对应的接口函数。比如检查点保存操作通常发生在每次迭代或epoch后,所以我们需要重写after_train_iter
和after_train_epoch
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| class Hook: def before_run(self, runner): pass
def after_run(self, runner): pass
def before_epoch(self, runner): pass
def after_epoch(self, runner): pass
def before_iter(self, runner): pass
def after_iter(self, runner): pass
@HOOKS.register_module() class CheckpointHook(Hook): def __init__(self, interval=-1, by_epoch=True, save_optimizer=True, out_dir=None, max_keep_ckpts=-1, **kwargs): ... def after_train_iter(self, runner): ... def after_train_epoch(self, runner): ...
|
和其他模块不同,当我们定义好一个Hook(并注册到HOOKS
注册器中)之后,还需要注册到Runner中才能使用,前后一共进行两次注册。第一次注册到HOOKS
是为了程序能够根据Hook名称找到对应的模块,第二次注册到Runner中是为了程序执行到预定义位置时能够调用对应的函数
Runner是MMCV用来管理训练过程的一个类,它内部会维护一个list类型变量self._hooks
,我们需要把训练过程会调用的Hook实例对象按照优先级顺序全部添加到self._hooks
中,这个过程通过Runner.register_hook()
函数实现。MMCV预定义了几种优先级, 数字越小表示优先级越高, 如果觉得默认的分级方式颗粒度过大, 也可以直接传入0~100的整数进行精细划分
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| def register_hook(self, hook, priority='NORMAL'): """ 预定义优先级 +--------------+------------+ | Level | Value | +==============+============+ | HIGHEST | 0 | +--------------+------------+ | VERY_HIGH | 10 | +--------------+------------+ | HIGH | 30 | +--------------+------------+ | ABOVE_NORMAL | 40 | +--------------+------------+ | NORMAL | 50 | +--------------+------------+ | BELOW_NORMAL | 60 | +--------------+------------+ | LOW | 70 | +--------------+------------+ | VERY_LOW | 90 | +--------------+------------+ | LOWEST | 100 | +--------------+------------+ """ hook.priority = priority inserted = False for i in range(len(self._hooks) - 1, -1, -1): if priority >= self._hooks[i].priority: self._hooks.insert(i + 1, hook) inserted = True break if not inserted: self._hooks.insert(0, hook)
|
将Hook实例加入到self._hooks
中之后,然后就可以在预定义位置调用call_hook()
来调用各个Hook实例中的对应方法。call_hook()
称为回调函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| self.call_hook('after_train_epoch')
while self.epoch < self._max_epochs:
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(self.data_loader): self.call_hook('before_train_iter')
self.model.train_step()
self.call_hook('after_train_iter')
self.call_hook('after_train_epoch')
self.call_hook('after_train_epoch')
|
调用call_hook()
时会遍历self._hooks
中所有Hook实例,并根据fn_name
调用Hook实例的指定成员函数。比如fn_name='before_train_epoch'
时,call_hook()
会挨个调用所有Hook的before_train_epoch()
函数。而且由于self._hooks
已经按照优先级进行过排序,call_hook()
会先调用优先级高的Hook方法
1 2 3
| def call_hook(self, fn_name): for hook in self._hooks: getattr(hook, fn_name)(self)
|
Hook机制小结
Hook是一种设置在程序固定位置的触发器,当程序执行到预设位点时则会触发断点,执行Hook函数的流程,结束后再回到断点位置继续执行主流程的代码。实现一个Hook包含5个步骤:
- 定义一个类,继承Hook基类
- 根据自定义Hook的功能有选择地重写Hook基类中对应的函数
- 注册自定义Hook模块到HOOKS查询表中(
register_module
)
- 实例化Hook模块并注册到Runner中(
register_hook
)
- 使用回调函数调用重写的Hook函数(
call_hook
)