MMCV组件3

MMCV核心组件Registry

从2.0.0版本开始,configregistry 机制移动到了 mmengine 里面,本文是以mmengine模块中的registry讲解的,其实老版本的registry也差不多,调用的时候从mmcv.utils里面调用就行

​ OpenMMLab 的算法库支持了丰富的算法和数据集,因此实现了很多功能相近的模块。例如 ResNet 和 SE-ResNet 的算法实现分别基于 ResNetSEResNet 类,这些类有相似的功能和接口,都属于算法库中的模型组件。为了管理这些功能相似的模块,MMEngine 实现了注册器。OpenMMLab 大多数算法库均使用注册器来管理它们的代码模块

什么是注册器

​ MMEngine 实现的注册器可以看作一个映射表和模块构建方法的组合。映射表维护了一个字符串到类或者函数的映射,例如维护字符串 "ResNet"ResNet 类或函数的映射,使得用户可以通过 "ResNet" 找到 ResNet 类;而模块构建方法则定义了如何根据字符串查找到对应的类或函数以及如何实例化这个类或者调用这个函数,例如,通过字符串 "bn" 找到 nn.BatchNorm2d 并实例化 BatchNorm2d 模块;又或者通过字符串 "build_batchnorm2d" 找到 build_batchnorm2d 函数并返回该函数的调用结果。MMEngine 中的注册器默认使用 build_from_cfg 函数来查找并实例化字符串对应的类或者函数

​ 一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 MODELS 可以被视作所有模型的抽象,管理了 ResNetSEResNetRegNetX 等分类网络的类以及 build_ResNet, build_SEResNetbuild_RegNetX 等分类网络的构建函数

入门用法

函数声明

1
2
from mmengine import Registry
Registry(name,build_func=None,parent=None,scope=None,locations=[])
  1. name (str): 注册表的名称。这个名称用于标识一个特定的注册表实例,以便在代码中引用
  2. build_func (callable, optional): 用于根据提供的字符串名称创建对象。这个函数接受一个字符串参数,并返回一个对象实例。如果没有提供,那么默认的行为是直接使用字符串作为对象的名称来检索对象
  3. parent (Registry, optional): 父注册表。这允许创建一个继承自另一个注册表的注册表
  4. scope (str, optional): 作用域。这个参数定义了注册表的作用域,例如,它可以是 'singleton',表示注册表中的每个条目都是单例的,或者 'instance',表示每次请求都会创建一个新的实例。
  5. locations (list, optional): 一个可选的字符串列表,用于指定在创建对象时应该搜索的模块位置

下面使用例子讲解

使用注册器管理代码库中的模块,需要以下三个步骤

  1. 创建注册器
  2. 创建一个用于实例化类的构建方法(可选,在大多数情况下可以只使用默认方法)
  3. 将模块加入注册器中

假设我们要实现一系列激活模块并且希望仅修改配置就能够使用不同的激活模块而无需修改代码。

1
2
3
4
5
# 注册注册器
from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])

locations 指定的模块 mmengine.models.activations 对应了 mmengine/models/activations.py 文件。在使用注册器构建模块的时候,ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 mmengine/models/activations.py 文件中实现不同的激活函数,例如 SigmoidReLUSoftmax

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn
# 使用注册器管理模块
# 1.不需要传入任何参数,此时默认实例化的配置字符串是 str (类名)
@ACTIVATION.register_module()
class Sigmoid(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
print('call Sigmoid.forward')
return x
# 2.传入指定 str,实例化时候只需要传入对应相同 str 即可
@ACTIVATION.register_module(ReLu)
class ReLU(nn.Module):
def __init__(self, inplace=False):
super().__init__()

def forward(self, x):
print('call ReLU.forward')
return x

@ACTIVATION.register_module()
class Softmax(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
print('call Softmax.forward')
return x

使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 ACTIVATION 中。通过 @ACTIVATION.register_module() 装饰所实现的模块,字符串和类或函数之间的映射就可以由 ACTIVATION 构建和维护通过注册,我们就可以通过 ACTIVATION 建立字符串与类或函数之间的映射

1
2
3
4
5
6
print(ACTIVATION.module_dict)
# {
# 'Sigmoid': __main__.Sigmoid,
# 'ReLU': __main__.ReLU,
# 'Softmax': __main__.Softmax
# }

只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:

  1. locations 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 REGISTRY.build(cfg)
  2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块
  3. 在配置中使用 custom_imports 字段。 详情请参考导入自定义Python模块

​ 模块成功注册后,我们可以通过配置文件使用这个激活模块。

1
2
3
4
5
6
7
8
9
import torch

input = torch.randn(2)

act_cfg = dict(type='Sigmoid')
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call Sigmoid.forward
print(output)

​ 如果我们想使用 ReLU,仅需修改配置

1
2
3
4
5
act_cfg = dict(type='ReLU', inplace=True)
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call ReLU.forward
print(output)

** 如果希望在创建实例前检查输入参数的类型(或者任何其他操作)(类似于一个钩子函数)**,我们可以实现一个构建方法并将其传递给注册器从而实现自定义构建流程

1
2
3
4
5
6
7
def build_activation(cfg, registry, *args, **kwargs):
cfg_ = cfg.copy()
act_type = cfg_.pop('type')
print(f'build activation: {act_type}')
act_cls = registry.get(act_type)
act = act_cls(*args, **kwargs, **cfg_)
return act

并将 build_activation 传递给 build_func 参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])

@ACTIVATION.register_module()
class Tanh(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
print('call Tanh.forward')
return x

act_cfg = dict(type='Tanh')
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# build activation: Tanh
# call Tanh.forward
print(output)

在大多数情况下,使用默认的方法就可以了

进阶用法(以后再写)

MMEngine 的注册器支持层级注册,利用该功能可实现跨项目调用,即可以在一个项目中使用另一个项目的模块。虽然跨项目调用也有其他方法的可以实现,但 MMEngine 注册器提供了更为简便的方法。为了方便跨库调用,MMEngine 提供了 22 个根注册器([查手册](注册器(Registry) — mmengine 0.10.3 文档))