ailm's picture
Upload 3 files
ed7df54 verified
raw
history blame contribute delete
684 Bytes
import torch # for model
import torch.nn as nn
import torchvision.models as models #to load vgg 19 model
class VGGNet(nn.Module):
    """Feature extractor returning intermediate activations of a pretrained VGG19.

    Only the convolutional trunk (``vgg19(...).features``) is kept; the
    classifier head is discarded. Parameters are left trainable; callers that
    only want fixed features should freeze them / call ``.eval()`` themselves.
    """

    def __init__(self, chosen_features=None):
        """Load ImageNet-pretrained VGG19 and record which layers to tap.

        Args:
            chosen_features: names of the children of ``vgg19().features``
                whose outputs ``forward`` returns (strings such as ``'5'`` or
                plain ints). Defaults to ``['0', '5', '10', '19', '28']``, the
                first conv of each of the five blocks (conv1_1 ... conv5_1) in
                torchvision's VGG19 layout.
        """
        super().__init__()
        if chosen_features is None:
            chosen_features = ['0', '5', '10', '19', '28']
        # Children of an nn.Sequential are named '0', '1', ...; normalise to
        # str so integer indices match as well.
        self.chosen_features = [str(name) for name in chosen_features]
        try:
            weights = models.VGG19_Weights.IMAGENET1K_V1
        except AttributeError:
            # torchvision < 0.13 has no weight enums; use the legacy flag.
            backbone = models.vgg19(pretrained=True)
        else:
            # Same ImageNet weights that pretrained=True selects, without the
            # deprecation warning newer torchvision emits for that flag.
            backbone = models.vgg19(weights=weights)
        self.vgg = backbone.features

    def forward(self, x):
        """Run ``x`` through the VGG trunk and collect the chosen activations.

        Args:
            x: input to the VGG feature extractor (presumably a normalized
                (N, 3, H, W) image batch; TODO confirm with callers).

        Returns:
            list of tensors, one per chosen layer found in ``self.vgg``, in
            network order (not the order of ``chosen_features``).
        """
        wanted = set(self.chosen_features)
        features = []
        for name, layer in self.vgg.named_children():
            # Once every requested activation is collected, later layers only
            # matter if they modify a stored tensor in place. torchvision's VGG
            # follows each conv with ReLU(inplace=True), which overwrites the
            # tensor just appended, so the returned "conv" features are really
            # post-ReLU activations. Keep running in-place layers to preserve
            # that, and stop at the first layer that is not in-place.
            if len(features) == len(wanted) and not getattr(layer, "inplace", False):
                break
            x = layer(x)
            if name in wanted:
                features.append(x)
        return features