Extracting Features from a Pretrained PyTorch Model

Code

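```python
import torch
from torchvision import models

x = torch.rand(1, 3, 5, 5)  # dummy batch: 1 image, 3 channels, 5x5 pixels
# torchvision >= 0.13 takes `weights=`; older releases use `pretrained=True` instead
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
vgg16.eval()  # inference mode
with torch.no_grad():  # no gradients needed for feature extraction
    output = vgg16.features[:3](x)  # run only the first three modules of `features`
print(output)
```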

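Here, `vgg16.features[:3]` keeps only the first three modules of VGG16's convolutional `features` block (Conv2d, ReLU, Conv2d), and the random tensor is fed through them as input. The result is the 64-channel feature map produced by the second convolution.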

Reference

https://discuss.pytorch.org/t/extracting-and-using-features-from-a-pretrained-model/20723/11