TensorBoardX教程(2)

Tensorboard的示例应用

Tensorboard进行中间特征图可视化:

​ 采用了注册钩子函数,再钩子函数中记录特征图,并将其可视化下来的方案,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.transforms import transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
from PIL import Image


vgg16_pretrained = models.vgg16(pretrained=True)
writer = SummaryWriter("./Visualize_featuremap")


class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.encoder1 = vgg16_pretrained.features[:5]
self.encoder2 = vgg16_pretrained.features[5:10]
self.encoder3 = vgg16_pretrained.features[10:17]

def forward(self, x):
x = self.encoder1(x)
x = self.encoder2(x)
x = self.encoder3(x)
return x


def visualize(module, input, output):
output = output.squeeze(0).unsqueeze(1)
output_image = vutils.make_grid(output, nrow=8, padding=2, normalize=True)
writer.add_image('output:', img_tensor=output_image)

mymodel = MyModel()
mymodel.encoder1.register_forward_hook(lambda module, input, output: visualize(module, input, output))
mymodel.encoder2.register_forward_hook(lambda module, input, output: visualize(module, input, output))
mymodel.encoder3.register_forward_hook(lambda module, input, output: visualize(module, input, output))

input_image = Image.open("./Lena.jpg").convert("RGB")
transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# writer.add_image('original_image', img_tensor=torch.Tensor(input_image).permute(2, 0, 1))
input_image = transforms(input_image).unsqueeze(0)

writer.add_image('original_image', img_tensor=input_image.squeeze(0))
_ = mymodel(input_image)
writer.close()

需要注意的小细节:

  • torchvision.utils.make_grid函数本质上是将batch沿着长与宽展开,他并不会改变通道数,如果需要将多通道数的特征图可视化,需要将通道数积累到batch维度上面去,并且保留通道维度(把他当成灰度图像可视化)
  • TensorBoard本身是不能调整图像配色的,如果需要更加好看的图像配色,可以使用matplotlib将灰度像素映射到color bar中,再传入Tensorboard

效果如下图所示:

Lena图提取特征后的feature map