timm

A Practical Guide to the timm Library

What is the timm library?

PyTorch Image Models (timm) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts. Its goal is to pull together a wide variety of SOTA models with the ability to reproduce ImageNet training results.

timm implements nearly all of the recent influential vision models. It not only provides model weights but also a well-designed framework for distributed training and evaluation, making it easy to build on. Even better, it keeps iterating: new training methods, new vision models, and code optimizations are added continuously.

Source code: huggingface/pytorch-image-models: The largest collection of PyTorch image encoders / backbones.

Introduction

Creating Models

import timm
import torch

model = timm.create_model('resnet34', num_classes=100)
x = torch.randn(1, 3, 224, 224)
model(x).shape
>>>
torch.Size([1, 100])

With timm you can create a model very quickly. The create_model function is a factory method that can be used to create over 300 models that are part of the timm library.

To create a pretrained model, simply pass in pretrained=True.

pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
>>>
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth

timm.list_models() returns a complete list of the models available in timm. To see only the models that come with pretrained weights, pass pretrained=True to list_models.

avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
>>>
(592,
 ['adv_inception_v3',
  'bat_resnext26ts',
  'beit_base_patch16_224',
  'beit_base_patch16_224_in22k',
  'beit_base_patch16_384'])
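
list_models also accepts a wildcard filter, which is convenient for finding a model family by name. A small sketch using the documented filter argument (the exact result depends on your installed timm version):

all_densenets = timm.list_models('*densenet*')
all_densenets[:3]
>>>
['densenet121', 'densenet161', 'densenet169']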

Optimizers and Learning-Rate Schedulers

timm ships a rich set of optimizer and learning-rate-scheduler implementations that support recent training techniques:

from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2

# Create an optimizer
optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
)

# Create a learning-rate scheduler. Unlike create_scheduler, which expects an
# argparse-style namespace, create_scheduler_v2 takes keyword arguments directly.
lr_scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=300,
    warmup_epochs=10,
)

Supported optimizers include SGD, Adam, AdamW, RMSprop, Adagrad, AdaHessian, and more. Scheduler strategies include StepLR, Cosine, Plateau, Tanh, and others.
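
Note that timm schedulers are stepped with an explicit epoch index, unlike the built-in torch schedulers. A minimal training-loop sketch (train_one_epoch stands in for your own training code):

for epoch in range(num_epochs):
    train_one_epoch(model, loader, optimizer)  # placeholder for your training step
    lr_scheduler.step(epoch + 1)               # timm schedulers take the epoch index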

Feature Extraction

features = model.forward_features(img)  # backbone output without the classifier head
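
For multi-scale feature maps, create_model also supports the features_only option; calling such a model returns one tensor per feature stage. A minimal sketch:

import torch
import timm

feat_model = timm.create_model('resnet34', features_only=True)
fmaps = feat_model(torch.randn(1, 3, 224, 224))
[f.shape for f in fmaps]                   # one entry per stage (strides 2/4/8/16/32)
print(feat_model.feature_info.channels())  # channel count of each stage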

Practical Components

The most practical feature of timm is its plug-and-play modular components, which save you from reinventing the wheel. The following shows how to use timm modules to quickly assemble complex structures:

Using the classic convolution block

import torch
from timm.layers import ConvNormAct  # timm.models.layers is a deprecated alias

# Replaces a hand-written Conv + BN + activation stack
block = ConvNormAct(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    stride=2,
    act_layer='gelu',  # activations can be given by name or as an nn.Module class
)
x = torch.randn(1, 64, 56, 56)
x = block(x)  # -> (1, 128, 28, 28)

Using the classic attention module from ViT:

In the timm source, timm/models/vision_transformer.py (ViT) contains the following module:

# Excerpt from timm/models/vision_transformer.py; imports added for context.
from typing import Final, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import use_fused_attn


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

If you want an attention module like the one in timm's ViT, there is no need to reinvent the wheel each time; you can simply import it:

from timm.models.vision_transformer import Attention
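
A quick smoke test of the imported module (the dimensions below are illustrative, chosen to match ViT-Base):

import torch

attn = Attention(dim=768, num_heads=12)
x = torch.randn(1, 197, 768)  # (batch, tokens, dim); 197 = 14*14 patches + CLS token
attn(x).shape
>>>
torch.Size([1, 197, 768])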

More Advanced Usage

Feature Extraction and Visualization

timm models offer flexible access to intermediate-layer features:

import torch
import timm
import matplotlib.pyplot as plt

model = timm.create_model('resnet34', pretrained=True)
model.eval()

# Layer names for hook targets can be inspected via model.named_modules()
# Register forward hooks to capture intermediate features
features = {}
def get_features(name):
    def hook(model, input, output):
        features[name] = output.detach()
    return hook

model.layer3.register_forward_hook(get_features('layer3'))
model.layer4.register_forward_hook(get_features('layer4'))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))  # the hooks fill the features dict

# Visualize the captured feature maps
def visualize_features(feature_maps, layer_name):
    plt.figure(figsize=(12, 6))
    plt.title(f"{layer_name} Feature Maps")
    for i in range(16):  # show the first 16 channels
        plt.subplot(4, 4, i + 1)
        plt.imshow(feature_maps[0, i].cpu().numpy(), cmap='viridis')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

visualize_features(features['layer3'], 'layer3')

Model Fusion and Ensembling

timm provides building blocks for model weight averaging, and ensembling is straightforward to add on top:

import timm
import torch

# Build an ensemble of heterogeneous models
model1 = timm.create_model('resnet50', pretrained=True)
model2 = timm.create_model('efficientnet_b3', pretrained=True)
model3 = timm.create_model('vit_base_patch16_224', pretrained=True)

# Weighted-average ensembling
def ensemble_predict(models, input, weights=(0.4, 0.3, 0.3)):
    outputs = [model(input) for model in models]
    return sum(w * o for w, o in zip(weights, outputs))

# Exponential moving average (EMA) of model weights
from timm.utils import ModelEmaV2
ema_model = ModelEmaV2(model1, decay=0.999)

# Stochastic weight averaging (SWA): timm has no SWA wrapper of its own;
# use PyTorch's torch.optim.swa_utils instead
from torch.optim.swa_utils import AveragedModel
swa_model = AveragedModel(model1)
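
During training, the EMA copy is refreshed after each optimizer step. A sketch, assuming optimizer, loader, and criterion are defined as in the earlier examples:

for input, target in loader:
    loss = criterion(model1(input), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_model.update(model1)  # fold the latest weights into the moving average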

Custom Model Architectures

timm's modular design makes it simple to create custom models:

import torch.nn as nn

from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp  # timm.models.layers is a deprecated alias

class CustomViT(VisionTransformer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Add a new module
        self.extra_mlp = Mlp(
            in_features=self.embed_dim,
            hidden_features=self.embed_dim * 4,
            act_layer=nn.GELU,
        )

        # Replace the classification head
        self.head = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embed_dim // 2, self.num_classes),
        )

    def forward_features(self, x):
        # Reuse the original ViT feature extraction
        x = super().forward_features(x)

        # Custom processing on the token sequence
        x = self.extra_mlp(x)
        return x

# Instantiate the custom model
custom_vit = CustomViT(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=1000,
)
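
A quick sanity check of the custom model (the input size matches img_size above):

import torch

custom_vit(torch.randn(1, 3, 224, 224)).shape
>>>
torch.Size([1, 1000])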

Reference: Pytorch Image Models (timm) | timmdocs