A Practical Guide to the timm Library
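What is the timm library?
PyTorch Image Models (timm) is a collection of image models, layers, utilities, optimizers, schedulers, data loaders / augmentations, and reference training / validation scripts. Its goal is to bring together a wide range of SOTA models with the ability to reproduce ImageNet training results.
timm implements nearly all of the influential recent vision models. Beyond model weights, it provides a solid framework for distributed training and evaluation that others can build on. It is also actively maintained, with new training methods, vision models, and optimized code added over time.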
Source code: huggingface/pytorch-image-models: The largest collection of PyTorch image encoders / backbones.
Introduction
Creating a model

    import timm
    import torch

    # Build a ResNet-34 with a 100-class classification head
    model = timm.create_model('resnet34', num_classes=100)
    x = torch.randn(1, 3, 224, 224)
    model(x).shape  # torch.Size([1, 100])

With timm, creating a model takes only a few lines. The create_model function is a factory method that can be used to create over 300 models that are part of the timm library.
To create a pretrained model, simply pass in pretrained=True.

    pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
    # Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth

timm.list_models() returns the complete list of models available in timm. To see only the models that have pretrained weights, pass pretrained=True to list_models:

    avail_pretrained_models = timm.list_models(pretrained=True)
    len(avail_pretrained_models), avail_pretrained_models[:5]
    # (592,
    #  ['adv_inception_v3',
    #   'bat_resnext26ts',
    #   'beit_base_patch16_224',
    #   'beit_base_patch16_224_in22k',
    #   'beit_base_patch16_384'])

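Optimizers and learning-rate schedulers
timm provides a broad set of optimizer and learning-rate scheduler implementations that support recent training techniques: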

    from timm.optim import create_optimizer_v2
    from timm.scheduler import create_scheduler_v2

    optimizer = create_optimizer_v2(
        model,
        opt='adamw',
        lr=1e-3,
        weight_decay=0.05,
    )

    # create_scheduler() expects an argparse-style args object, so the
    # keyword-based create_scheduler_v2() is used here instead.
    # It returns the scheduler and the total number of epochs.
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        num_epochs=300,
        warmup_epochs=10,
    )

Supported optimizers include SGD, Adam, AdamW, RMSprop, Adagrad, AdaHessian, and others. Supported learning-rate schedules include StepLR, Cosine, Plateau, Tanh, and more.
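
As a rough picture of what a cosine schedule with warmup does over the 300 epochs configured above, here is a minimal pure-Python sketch. It is a simplified illustration and not timm's exact implementation; the function name cosine_lr_with_warmup and the warmup_lr and min_lr defaults are assumptions made for this example.

```python
import math


def cosine_lr_with_warmup(epoch, base_lr=1e-3, min_lr=0.0, warmup_lr=1e-5,
                          warmup_epochs=10, num_epochs=300):
    """Hypothetical helper: linear warmup followed by cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr up to base_lr
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1]
    progress = (epoch - warmup_epochs) / max(1, num_epochs - warmup_epochs)
    # Half-cosine decay from base_lr down to min_lr
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Learning rate at a few selected epochs
schedule = {epoch: cosine_lr_with_warmup(epoch) for epoch in (0, 5, 10, 155, 300)}
```

The rate climbs linearly during the first 10 epochs, peaks at the base learning rate, and then decays smoothly toward the minimum.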
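Feature extraction

    # img is a float tensor of shape (N, 3, H, W);
    # forward_features returns the feature map before pooling and the classifier head
    features = model.forward_features(img)
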
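Useful features
timm's most practical feature is its plug-and-play modular components, which save you from reinventing the wheel. The following shows how to use timm modules to assemble complex structures quickly:
Using a classic convolution block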
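
    import torch
    from timm.layers import ConvNormAct  # timm.models.layers in older releases

    x = torch.randn(1, 64, 56, 56)  # example input with 64 channels

    # Convolution + normalization + activation in a single module
    x = ConvNormAct(
        in_channels=64,
        out_channels=128,
        kernel_size=3,
        stride=2,
        act_layer="gelu",
    )(x)
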
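
Conceptually, ConvNormAct bundles a convolution, a normalization layer, and an activation. A rough plain-PyTorch counterpart of the call above might look like the following sketch. The choices of BatchNorm2d, padding of 1, and a bias-free convolution are assumptions about timm's defaults, not something verified against a specific release.

```python
import torch
import torch.nn as nn

# Approximate stand-in for ConvNormAct(64, 128, kernel_size=3, stride=2, act_layer="gelu")
conv_norm_act = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),  # assumed padding/bias
    nn.BatchNorm2d(128),  # assumed default norm layer
    nn.GELU(),
)

x = torch.randn(1, 64, 56, 56)
y = conv_norm_act(x)  # stride 2 halves the spatial resolution
```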
Using ViT's classic attention module:
In the timm library, the source file timm/models/vision_transformer.py (ViT) contains the following module:

    # Excerpt from the timm source; it relies on that file's imports
    # (torch, nn, F, Final, Type, use_fused_attn).
    class Attention(nn.Module):
        fused_attn: Final[bool]

        def __init__(
                self,
                dim: int,
                num_heads: int = 8,
                qkv_bias: bool = False,
                qk_norm: bool = False,
                proj_bias: bool = True,
                attn_drop: float = 0.,
                proj_drop: float = 0.,
                norm_layer: Type[nn.Module] = nn.LayerNorm,
        ) -> None:
            super().__init__()
            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5
            self.fused_attn = use_fused_attn()

            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim, bias=proj_bias)
            self.proj_drop = nn.Dropout(proj_drop)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)

            if self.fused_attn:
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                q = q * self.scale
                attn = q @ k.transpose(-2, -1)
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = attn @ v

            x = x.transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x

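
The two branches in forward compute the same quantity: the fused path calls PyTorch's scaled_dot_product_attention, while the fallback spells out the scale, softmax, and weighted sum. The small check below illustrates this, assuming PyTorch 2.0 or later and no dropout; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 8  # batch, heads, tokens, head_dim
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Fused path (dropout disabled)
fused = F.scaled_dot_product_attention(q, k, v)

# Manual path, mirroring the else-branch of the module above
scale = D ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
manual = attn @ v

max_diff = (fused - manual).abs().max().item()
```

Any remaining difference should be at the level of floating-point rounding.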
To use an attention module like the one in timm's ViT, there is no need to reinvent the wheel each time; it can be imported directly:

    from timm.models.vision_transformer import Attention

More advanced usage
Feature extraction and visualization
timm provides flexible access to intermediate-layer features:

    import matplotlib.pyplot as plt

    # Names of the modules timm marks as feature stages. timm has no top-level
    # get_feature_info(); for ResNet-style models the metadata is on model.feature_info.
    feature_layer_names = [info['module'] for info in model.feature_info]

    features = {}

    def get_features(name):
        def hook(module, inputs, output):
            features[name] = output.detach()
        return hook

    # layer3 and layer4 are ResNet stage names
    model.layer3.register_forward_hook(get_features('layer3'))
    model.layer4.register_forward_hook(get_features('layer4'))

    def visualize_features(feature_maps, layer_name):
        fig = plt.figure(figsize=(12, 6))
        fig.suptitle(f"{layer_name} Feature Maps")
        # Show the first 16 channels of the first sample
        for i in range(16):
            plt.subplot(4, 4, i + 1)
            plt.imshow(feature_maps[0, i].cpu().numpy(), cmap='viridis')
            plt.axis('off')
        plt.tight_layout()
        plt.show()

    # The hooks fill `features` during the next forward pass
    model(x)
    visualize_features(features['layer3'], 'layer3')

Model fusion and ensembling
Several model-fusion techniques can be used with timm models:

    from timm.utils.model_ema import ModelEma
    from torch.optim.swa_utils import AveragedModel

    model1 = timm.create_model('resnet50', pretrained=True)
    model2 = timm.create_model('efficientnet_b3', pretrained=True)
    model3 = timm.create_model('vit_base_patch16_224', pretrained=True)

    def ensemble_predict(models, x, weights=(0.4, 0.3, 0.3)):
        # Weighted average of the models' outputs
        outputs = [m(x) for m in models]
        avg_output = sum(w * o for w, o in zip(weights, outputs))
        return avg_output

    # Exponential moving average of the model weights
    ema_model = ModelEma(model, decay=0.999)

    # Stochastic weight averaging: timm.optim does not provide a Swa class, so this
    # uses PyTorch's torch.optim.swa_utils. Call swa_model.update_parameters(model)
    # in the training loop (for example every 5 epochs from epoch 10 on).
    swa_model = AveragedModel(model)

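
To make the idea behind weight EMA concrete, here is a minimal plain-PyTorch sketch of the update rule ema = decay * ema + (1 - decay) * current. It is an illustration only, not timm's ModelEma implementation; the ema_update helper is hypothetical and, for brevity, averages parameters but not buffers.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.999) -> None:
    """Hypothetical helper: blend the current parameters into the EMA copy in place."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


model = nn.Linear(2, 1)
ema_model = copy.deepcopy(model)

# Use easily checked values: EMA weights start at 0, live weights are 1
with torch.no_grad():
    for p in ema_model.parameters():
        p.zero_()
    for p in model.parameters():
        p.fill_(1.0)

ema_update(ema_model, model, decay=0.9)  # each EMA weight becomes 0.9 * 0 + 0.1 * 1
```

In training, such an update would typically run once after every optimizer step, and the EMA copy would be used for evaluation.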
Custom model architectures
timm's modular design makes it straightforward to build custom models:

    import torch.nn as nn
    from timm.layers import Mlp  # timm.models.layers in older releases
    from timm.models.vision_transformer import VisionTransformer

    class CustomViT(VisionTransformer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            # Extra MLP applied to the token features
            self.extra_mlp = Mlp(
                in_features=self.embed_dim,
                hidden_features=self.embed_dim * 4,
                act_layer=nn.GELU,
            )

            # Replace the default linear classifier with a two-layer head
            self.head = nn.Sequential(
                nn.Linear(self.embed_dim, self.embed_dim // 2),
                nn.ReLU(),
                nn.Linear(self.embed_dim // 2, self.num_classes),
            )

        def forward_features(self, x):
            x = super().forward_features(x)
            x = self.extra_mlp(x)
            return x

    custom_vit = CustomViT(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=1000,
    )

Pytorch Image Models (timm) | timmdocs