增删pth文件特定block
增删pth文件特定block
我在不同任务上测试不同上下采样算子效果的时候,经常要重新训练网络,在之前无论多大的模型我都是从头开始训练,但是这样非常耗时间,本文介绍如何在预训练好的模型文件中替换某一个特定的 block(算子)进行训练
简单替换:
该方式适用于简单的 block 替换,比如我们想把原分类网络从 5 分类变为 10 分类,我们可以这样做:
1 |
|
从pth文件中复制:
下面代码的思路是:先实例化我们修改过 block 后的模型,再把对应 pth 文件中的对应不用替换的 block 加载进来(对于大模型这种方法仍然不是很实用)
1 |
|
其中表达式 k: v for k, v in ... if ...
是字典推导式(dictionary comprehension)的语法。k: v
表示的是字典中的键值对关系。冒号 :
在这里用来分隔键和值,工作原理如下:
- 遍历
pretrained_dict.items()
,这会迭代字典pretrained_dict
中的所有键值对 - 对于每一对
(k, v)
,检查条件if k in model_dict
是否成立,即检查当前键k
是否存在于model_dict
中 - 如果条件满足,键值对
k: v
就会被包含进新的字典中