增删pth文件特定block

增删pth文件特定block

​ 我在不同任务上测试不同上下采样算子效果的时候,经常要重新训练网络,在之前无论多大的模型我都是从头开始训练,但是这样非常耗时间,本文介绍如何在预训练好的模型文件中替换某一个特定的 block(算子)进行训练

简单替换:

​ 该方式适用于简单的 block 替换,比如我们想把原分类网络从 5 分类变为 10 分类,我们可以这样做:

1
2
3
model = torchvision.models.resnet50(pretrained=True)
# 修改最后线性层的输出通道数
model.fc = nn.Linear(2048,10)

从pth文件中复制:

​ 下面代码的思路是:先实例化我们修改过 block 后的模型,再把对应 pth 文件中的对应不用替换的 block 加载进来(对于大模型这种方法仍然不是很实用)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class CNN(nn.Module):
"""
在里面复制源码再稍微修改,改成我们想要的模型
"""

resnet50 = torchvision.models.resnet50(pretrained=True)
cnn = CNN(Bottleneck, [3, 4, 6, 3])
#读取参数
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
cnn.load_state_dict(model_dict)
print(cnn)

其中表达式 k: v for k, v in ... if ... 是字典推导式(dictionary comprehension)的语法。k: v 表示的是字典中的键值对关系。冒号 : 在这里用来分隔键和值,工作原理如下:

  • 遍历 pretrained_dict.items(),这会迭代字典 pretrained_dict 中的所有键值对
  • 对于每一对 (k, v),检查条件 if k in model_dict 是否成立,即检查当前键 k 是否存在于 model_dict
  • 如果条件满足,键值对 k: v 就会被包含进新的字典中