MMCV组件1

MMCV核心组件FileHandler

注意:在2.0.0版本以后,FileHandlerFileClient的功能已经转到mmengine中了(它们在mmengine.fileio模块中),但是功能和用法大体不变,导入和使用方式如下:

1
2
3
4
5
6
7
8
9
10
11
12
# 在2.0.0半本以后,mmcv.load 函数已经被移除。在新版本中,
# 此功能已经被重新设计和整合到了 mmcv.config 模块中
from mmengine import Config
cfg = Config.fromfile("/path_to_your_file")

# ,FileClient 类已经被迁移到了 MMEngine 项目中
from mmengine.fileio import FileClient, FileHandler

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# 在2.0.0版本以前,使用方式是:
from mmcv.fileio import FileClient
from mmcv import load

MMCV 整体概述

MMCV 从一开始的定位就是提供底层通用组件,故在设计之初就已经考虑到了灵活性和可扩展性,其主要特性是:

  • 统一可扩展的 io api
  • 支持非常丰富的图像/视频处理算子
  • 图片/视频的标注文件可视化
  • 常用的工具类例如 timer 和 progress bar 等等
  • 上层框架需要的 hook 机制以及可以直接使用的 runner
  • 高度灵活的 config 模式和注册器机制
  • 高效高质量的 cuda operator

MMCV 核心组件FileHandler分析:

fileio 中有两个核心组件:涉及文件读写的 FileHandler 和文件获取后端 FileClient

  • FileHandler 的作用是对外提供统一的文件读写 API,其根据待读写的文件后缀名自动选择对应的 handler 进行具体操作
  • FileClient 的作用是对外提供统一的文件内容获取 API,主要用于训练过程中数据的读取,通过用户选择或者自定义不同的 FileClient 后端,可以轻松实现文件缓存、文件加速读取等功能

以上两个核心组件都是支持可扩展的

实现逻辑

要实现这个功能,mmcv采用面向接口编程的思想,核心代码如下:

  • 在 base 类中定义接口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from abc import ABCMeta, abstractmethod
# 继承ABCMeta元类,使其无法直接实例化
class BaseFileHandler(metaclass=ABCMeta):

#@abstractmethod表示子类必须要实现该方法,否则报错
# 文件读取
@abstractmethod
def load_from_fileobj(self, file, **kwargs):
pass
# 文件存储,需要传入对象obj和file
@abstractmethod
def dump_to_fileobj(self, obj, file, **kwargs):
pass

#dump成字符串返回,当你不想保存时候使用
@abstractmethod
def dump_to_str(self, obj, **kwargs):
pass
# 对外实际上是采用下面两个api
def load_from_path(self, filepath, mode='r', **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)

def dump_to_path(self, obj, filepath, mode='w', **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)

上述核心就是先定义几个抽象方法,然后再定义几个对外调用 API 即可。考虑到不同的读写具体子类在进行读写操作时候可能参数不一样,故上述每个方法上面都加了 **kwargs 可变字典参数

  • 子类实现抽象方法

以 json 读写为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class JsonHandler(BaseFileHandler):
# 直接json.load即可
def load_from_fileobj(self, file):
return json.load(file)
# 直接json.dump即可
def dump_to_fileobj(self, obj, file, **kwargs):
#setdefault 是字典的一个内置方法,用于设置字典中键的默认值。当尝试访问字典中不存在的键时
#如果使用 get 方法,它会返回 None。而 setdefault 方法在键不存在时,
#会返回默认值并在字典中创建一个新的键值对
kwargs.setdefault('default', set_default)
#json 模块中的 dump 函数用于将Python对象转换为JSON格式的字符串并写入到一个文件中
json.dump(obj, file, **kwargs)

# 直接json.dumps返回格式化的json str
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('default', set_default)
return json.dumps(obj, **kwargs)
  • 对外读写接口,屏蔽掉具体 handler 子类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 目前已经提供的handler
file_handlers = {
'json': JsonHandler(),
'yaml': YamlHandler(),
'yml': YamlHandler(),
'pickle': PickleHandler(),
'pkl': PickleHandler()
}

# 对外统一文件读取接口
def load(file, file_format=None, **kwargs):
# 1 输入参数检查
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
file_format = file.split('.')[-1]
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')

# 2 基于文件格式,选择不同的handler
handler = file_handlers[file_format]

# 3 读取文件内容
if is_str(file):
obj = handler.load_from_path(file, **kwargs)
elif hasattr(file, 'read'):
obj = handler.load_from_fileobj(file, **kwargs)
else:
raise TypeError('"file" must be a filepath str or a file-object')
return obj

# 文件写流程也是一样的
def dump(obj, file=None, file_format=None, **kwargs):

那么它的用法(对外接口)就十分简洁了:

1
2
3
4
5
6
7
8
import mmcv

# load data from a file
data = mmcv.load('test.json')
data = mmcv.load('test.yaml')
data = mmcv.load('test.pkl')

mmcv.dump(data, 'out.pkl')

自定义文件类型开发:

如果你需要的文件格式不在上述列表,如何进行自定义扩展开发呢?这里以读写 .npy 文件为例进行简要代码构建。

  • 继承 BaseFileHandler,然后实现抽象方法,最后注册到 MMCV 中
1
2
3
4
5
6
7
8
9
10
11
@register_handler('npy')
class NpyHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
return np.load(file)

def dump_to_fileobj(self, obj, file, **kwargs):
np.save(file, obj)

def dump_to_str(self, obj, **kwargs):
# 实际上这么写没有意义,这里只是举例
return obj.tobytes()

需要特别说明的是 @register_handler('npy'),这是一个装饰器,目的是把我们刚才实现的 handler 注册到 MMCV 中,然后 MMCV 就可以直接找到该 handler 了,装饰器的核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def register_handler(file_formats, **kwargs):
def wrap(cls):
# 这句话其实核心是:file_handlers[ext] = handler
# 把我们写的handler类设置到file_handlers的字典中
_register_handler(cls(**kwargs), file_formats)
return cls

return wrap

>>> #file_handlers变成:
file_handlers = {
'json': JsonHandler(),
'yaml': YamlHandler(),
'yml': YamlHandler(),
'pickle': PickleHandler(),
'pkl': PickleHandler(),
'npy': NpyHandler()
}