1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| import torch import torch.nn as nn import torchvision.models as models from torchvision.transforms import transforms import torchvision.utils as vutils from tensorboardX import SummaryWriter import matplotlib.pyplot as plt from PIL import Image
# ImageNet-pretrained VGG-16; MyModel slices its ``features`` stack into stages.
# torchvision >= 0.13 deprecated ``pretrained=True`` in favour of weight enums;
# ``VGG16_Weights.IMAGENET1K_V1`` selects the same weights the old flag loaded.
# Older torchvision versions have no weight enums, so fall back to the flag there.
if hasattr(models, "VGG16_Weights"):
    vgg16_pretrained = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
else:
    vgg16_pretrained = models.vgg16(pretrained=True)

# TensorBoard event files for this run are written under ./Visualize_featuremap.
writer = SummaryWriter("./Visualize_featuremap")
class MyModel(nn.Module):
    """Three-stage encoder built from slices of ``vgg16_pretrained.features``.

    The stages are kept as separate attributes (``encoder1``..``encoder3``) so
    that a forward hook can be attached to each one individually.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg16_pretrained.features
        # Stage boundaries in the VGG-16 feature stack: [0, 5), [5, 10), [10, 17).
        self.encoder1 = backbone[:5]
        self.encoder2 = backbone[5:10]
        self.encoder3 = backbone[10:17]

    def forward(self, x):
        """Run ``x`` through the three stages in order and return the last output."""
        for stage in (self.encoder1, self.encoder2, self.encoder3):
            x = stage(x)
        return x
def visualize(module, input, output, tag='output:', global_step=None):
    """Forward hook: log a module's output feature maps to TensorBoard as one grid.

    Parameters
    ----------
    module : nn.Module
        Module the hook is attached to (unused; required by the hook signature).
    input : tuple
        Module inputs (unused; required by the hook signature).
    output : torch.Tensor
        Module output. Assumed to be (N, C, H, W) -- TODO confirm for other
        modules. Only the first sample of the batch is visualized.
    tag : str
        TensorBoard tag for the image.
    global_step : int or None
        Step to log at. When None, a call counter is used so that successive
        calls (e.g. one per hooked encoder) are recorded at distinct steps.
        Without distinct steps, every call logs the same tag at the same step,
        so only one image is displayed.
    """
    if global_step is None:
        global_step = visualize.calls
    visualize.calls += 1

    # (C, H, W) for the first sample -> (C, 1, H, W) so that each channel becomes
    # its own single-channel tile in the grid. Indexing with [0] (instead of
    # squeeze(0)) also works when the batch holds more than one image.
    feature_maps = output[0].detach().unsqueeze(1)
    output_image = vutils.make_grid(feature_maps, nrow=8, padding=2, normalize=True)
    writer.add_image(tag, img_tensor=output_image, global_step=global_step)


# Number of times visualize() has run; used as the default TensorBoard step.
visualize.calls = 0
mymodel = MyModel()

# Attach the same logging hook to every encoder stage, so each stage's output
# is passed to visualize() whenever that stage runs.
for _stage in (mymodel.encoder1, mymodel.encoder2, mymodel.encoder3):
    _stage.register_forward_hook(
        lambda module, input, output: visualize(module, input, output)
    )
input_image = Image.open("./Lena.jpg").convert("RGB")

# Preprocess in two steps, without rebinding the imported ``transforms`` module.
# The image is converted to a [0, 1] tensor first so that the un-normalized
# version can be logged: add_image expects float images in [0, 1], and the
# mean/std-normalized tensor contains negative values that would be clipped.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
image_tensor = to_tensor(input_image)
input_image = normalize(image_tensor).unsqueeze(0)  # add a leading batch dimension

try:
    writer.add_image('original_image', img_tensor=image_tensor)
    # Inference only: no autograd graph is needed. Forward hooks still fire.
    with torch.no_grad():
        _ = mymodel(input_image)
finally:
    # Flush and close the event file even if the forward pass raises.
    writer.close()
|