|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class Encoder(nn.Module):
    """Feature extractor wrapping a base model loaded from a local ``torch.hub`` repo.

    The base model's ``global_pool`` and ``classifier`` attributes are replaced
    with ``nn.Identity``. ``forward`` returns every intermediate activation
    (input included) instead of only the final output.
    """

    def __init__(self, basemodel_name='tf_efficientnet_b5_ap', repo_path=None):
        """Load the base model and neutralise its pooling/classifier heads.

        Args:
            basemodel_name: Entry-point name passed to ``torch.hub.load``.
                Defaults to ``'tf_efficientnet_b5_ap'``, the value that was
                previously hard-coded.
            repo_path: Directory of the local hub repository. When ``None``,
                it defaults to the ``efficientnet_repo`` directory next to
                this file, matching the previous hard-coded location.

        Raises:
            Any exception raised by ``torch.hub.load`` when the repository or
            entry point cannot be loaded.
        """
        super().__init__()

        if repo_path is None:
            repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')

        # Fixed: the literal previously used '()' instead of '{}', so
        # str.format never inserted the model name into the message.
        print('Loading base model ({})...'.format(basemodel_name), end='')
        # NOTE(review): pretrained=False is forwarded to the hub entry point;
        # presumably this skips loading pretrained weights - confirm.
        basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')
        print('Done.')

        print('Removing last two layers (global_pool & classifier).')
        # Swapping in Identity keeps the attributes present (so iteration in
        # forward still visits them) while making them pass-throughs.
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x):
        """Run ``x`` through the base model, collecting every intermediate output.

        The top-level child modules of ``original_model`` are applied in
        registration order, each to the previous output. The child named
        ``'blocks'`` is expanded, so each of its own sub-modules contributes a
        separate entry.

        Args:
            x: Input fed to the first child module. Presumably an image batch
                tensor; the expected shape is not established here, so
                confirm with the caller.

        Returns:
            list: ``[x, out_1, out_2, ...]``, i.e. the input followed by the
            output of every module applied, in order.
        """
        features = [x]
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Expand the block stack so every stage's output is retained.
                for block in module._modules.values():
                    features.append(block(features[-1]))
            else:
                features.append(module(features[-1]))
        return features
|
|
|
|
|
|