import torch # for model import torch.nn as nn import torchvision.models as models #to load vgg 19 model class VGGNet(nn.Module): def __init__(self): super(VGGNet, self).__init__() self.chosen_features = ['0', '5', '10', '19', '28'] self.vgg = models.vgg19(pretrained = True).features #select only certain layers to extract fetaures def forward(self,x): features = [] #returns features from selected conv layers from VGG19 pretrained model for layer_num, layer in self.vgg._modules.items(): x = layer(x) if layer_num in self.chosen_features: features.append(x) return features