Pytorch-grad-cam 特征可视化:
简介:
Pytorch-grad-cam
是一个用于PyTorch的库,它提供了多种类激活映射(Class Activation Mapping, CAM)方法,用于可视化卷积神经网络(CNN)的特征。这些方法可以帮助我们理解模型在进行预测时关注的输入图像的哪些区域。pytorch-grad-cam
库支持多种CAM方法,包括但不限于:
GradCAM:通过权重激活映射来可视化模型的注意力区域
GradCAM++:在GradCAM的基础上使用二阶梯度
XGradCAM:类似于GradCAM,但通过归一化激活来缩放梯度
AblationCAM:通过关闭激活并测量输出下降来可视化
ScoreCAM:通过扰动图像并测量输出下降来可视化
EigenCAM:使用主成分分析来可视化
LayerCAM:通过正梯度加权激活来可视化
这个库还支持对象检测和语义分割任务,并且提供了一些高级功能,如图像增强和性能评估指标。其实上面几个方法可视化效果大同小异,用的时候随便选一个 GradCAM++ 就行了
使用方式:
先介绍两个函数:
pytorch_grad_cam.GradCAMPlusPlus
函数
1 cam = pytorch_grad_cam.GradCAMPlusPlus(model, target_layers,reshape_transform=None )
model:为需要可视化的模型,把先前训练好的模型实例传入就好
target_layers:需要可视化的层,类型是 List[torch.nn.Module]
,可以传入单个或者多个层,如果传入多个层,CAM 的注意力会在这些层之间平均,这样会在你不知道哪个层表现最好的时候很有用
reshape_transform:看不懂是什么参数,平时也不用,不管
返回值 cam:一个 callabel,他也是一个神经网络,使用时输入和 model 一样的数据,返回 batch 大小的注意力灰度图
1 2 3 4 5 6 7 8 import pytorch_grad_cam cam = pytorch_grad_cam.GradCAMPlusPlus(model=resnet18, target_layers=target_layers) grayscale_cam = cam(net_input, targets=targets) grayscale_cam = grayscale_cam[0 , :]print (grayscale_cam.shape)
其中 cam(net_input, targets=targets)
的 targets 是指定哪种预测类别注意力使用的,如果没有指定 targets,该函数会自动选用预测概率最大的类别进行可视化,该函数还可以传入 aug_smooth=True and eigen_smooth=True
来进行图像平滑操作
show_cam_on_image
方法
show_cam_on_image
方法用于在单个图像上显示 CAM 结果
1 2 from pytorch_grad_cam.utils.image import show_cam_on_image visualization_img = show_cam_on_image(img: np.ndarray, mask: np.ndarray, use_rgb: bool =False , colormap: int =cv2.COLORMAP_JET, image_weight: float )
img:网络输入图像,即需要可视化注意力的图像
mask:注意力灰度图,就是上面返回的 grayscale_cam
use_rgb:告诉传入的 img 使用 RGB 图像还是 BGR 图像
colormap:类似于 matplotlib 库中的 cmap
返回值 visualization_img:返回传入层 target layers 平均注意力的灰度图 ,用于可视化的 RGB 图像
应用实例:查看 resnet 各个层的注意力
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 import numpy as npimport matplotlib.pyplot as pltfrom PIL import Imageimport torchvision.models as modelsimport torchvision.transforms as transformsimport pytorch_grad_cam from pytorch_grad_cam.utils.image import show_cam_on_image resnet18 = models.resnet18(pretrained=True ) resnet18.eval () target_layers = [resnet18.layer1[1 ].bn2, resnet18.layer2[1 ].bn2, resnet18.layer3[1 ].bn2, resnet18.layer4[1 ].bn2] rgb_img = Image.open ('./1.jpg' ).convert("RGB" ) trans = transforms.Compose([ transforms.ToTensor(), transforms.Resize(224 , antialias=True ), transforms.CenterCrop(224 ) ]) crop_img = trans(rgb_img) net_input = transforms.Normalize((0.485 , 0.456 , 0.406 ), (0.229 , 0.224 , 0.225 ))(crop_img).unsqueeze(0 ) canvas_img = (crop_img*255 ).byte().numpy().transpose(1 , 2 , 0 ) src_img = np.float32(canvas_img) / 255 fig, axes = plt.subplots(1 , 5 ) axes[0 ].imshow(src_img)for cnt, layer in enumerate (target_layers, 1 ): cam = pytorch_grad_cam.GradCAMPlusPlus(model=resnet18, target_layers=[layer]) grayscale_cam = cam(net_input) grayscale_cam = grayscale_cam[0 , :] visualization_img = show_cam_on_image(src_img, grayscale_cam, use_rgb=True ) axes[cnt].imshow(visualization_img) plt.show()