MingGatsby's picture
Upload 2 files
7c91758
raw
history blame
No virus
2.07 kB
import timm
import torch.nn as nn
class Build_Custom_Model(nn.Module):
    """Wrap a timm backbone and replace its classifier with a new linear layer.

    The backbone is created through ``timm.create_model`` with a single input
    channel (``in_chans=1``). For the model names listed in ``_HEAD_ATTRS``,
    the layer found at the recorded attribute path is replaced by
    ``nn.Linear(n_features, target_size)``, where ``n_features`` is that
    layer's ``in_features``. Any other timm model falls back to timm's generic
    ``reset_classifier(target_size)``, so the output size always matches
    ``target_size``.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        target_size: number of outputs of the new classifier layer.
        pretrained: forwarded to ``timm.create_model``.

    Attributes:
        model: the wrapped timm model with its classifier replaced.
        n_features: input width of the new classifier.
    """

    # Attribute path (relative to ``self.model``) of the classifier layer to
    # replace, keyed by timm model name.
    _HEAD_ATTRS = {
        "vit_base_patch16_224": ("head",),
        "swin_base_patch4_window7_224": ("head",),
        "resnet34d": ("fc",),
        "resnet18d": ("fc",),
        "tf_efficientnet_b7_ns": ("classifier",),
        "tf_efficientnet_b0_ns": ("classifier",),
        "tf_efficientnet_lite0": ("classifier",),
        "mobilenetv2_050": ("classifier",),
        "eca_nfnet_l0": ("head", "fc"),
    }

    def __init__(self, model_name, target_size, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)

        path = self._HEAD_ATTRS.get(model_name)
        if path is None:
            # Unlisted model: previously the default timm head (with timm's
            # default class count) was silently kept and n_features was never
            # set. Use timm's model-agnostic classifier reset instead.
            self.n_features = self.model.num_features
            self.model.reset_classifier(target_size)
            return

        # Walk to the module that owns the classifier, then swap the leaf.
        *parent_path, leaf = path
        parent = self.model
        for attr in parent_path:
            parent = getattr(parent, attr)
        self.n_features = getattr(parent, leaf).in_features
        setattr(parent, leaf, nn.Linear(self.n_features, target_size))

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        output = self.model(x)
        return output
def reshape_transform(tensor, height=7, width=7):
    """Rearrange a token sequence into a channels-first grid, as in CNNs.

    ``tensor`` is reshaped to ``(tensor.size(0), height, width, tensor.size(2))``
    and its last axis is then moved to position 1. The result has shape
    ``(tensor.size(0), tensor.size(2), height, width)``.

    NOTE(review): assumes ``tensor`` is (batch, tokens, channels) with
    exactly ``height * width`` tokens; otherwise ``reshape`` raises.
    """
    leading = tensor.size(0)
    channels = tensor.size(2)
    grid = tensor.reshape(leading, height, width, channels)
    # Move channels from the last axis to axis 1 (channels-first layout).
    return grid.permute(0, 3, 1, 2)