|
import timm |
|
import torch.nn as nn |
|
|
|
class Build_Custom_Model(nn.Module):
    """Wrap a timm backbone with a single-channel stem and a new linear head.

    The backbone is created with ``in_chans=1``. Its final linear layer is
    replaced by ``nn.Linear(n_features, target_size)``, so the last dimension
    of the model output is ``target_size``.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        target_size: Output width of the replacement final linear layer.
        pretrained: Forwarded to ``timm.create_model``.

    Attributes:
        model: The wrapped timm model (with its head replaced).
        n_features: Input width of the final linear layer.
    """

    # Dotted attribute path (relative to the timm model) of the final linear
    # layer for each explicitly supported architecture. The new layer is put
    # back at the same path, so state_dict keys are unchanged.
    _HEAD_PATHS = {
        "vit_base_patch16_224": "head",
        "swin_base_patch4_window7_224": "head",
        "resnet34d": "fc",
        "resnet18d": "fc",
        "tf_efficientnet_b7_ns": "classifier",
        "tf_efficientnet_b0_ns": "classifier",
        "tf_efficientnet_lite0": "classifier",
        "mobilenetv2_050": "classifier",
        "eca_nfnet_l0": "head.fc",
    }

    def __init__(self, model_name, target_size, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
        self.n_features = self._attach_head(self.model, model_name, target_size)

    @classmethod
    def _attach_head(cls, model, model_name, target_size):
        """Give ``model`` a ``target_size``-wide final layer; return its input width.

        Known architectures get their linear layer at ``_HEAD_PATHS`` swapped
        in place. Any other name (or a listed one whose layer at that path is
        not a plain ``nn.Linear``) goes through ``model.reset_classifier``,
        using ``model.num_features`` as the input width. Previously such names
        silently kept their default head and ignored ``target_size``.
        """
        head_path = cls._HEAD_PATHS.get(model_name)
        if head_path is not None:
            *parent_names, leaf = head_path.split(".")
            parent = model
            for name in parent_names:
                parent = getattr(parent, name)
            layer = getattr(parent, leaf)
            if isinstance(layer, nn.Linear):
                n_features = layer.in_features
                # nn.Module.__setattr__ registers the new layer as a submodule.
                setattr(parent, leaf, nn.Linear(n_features, target_size))
                return n_features
        # Generic fallback: defer to the model's own classifier-reset API,
        # which knows how its head is structured.
        n_features = model.num_features
        model.reset_classifier(target_size)
        return n_features

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)
|
|
|
def reshape_transform(tensor, height=7, width=7):
    """Fold a token sequence into a channels-first ``height`` x ``width`` map.

    The batch size is taken from ``tensor.size(0)`` and the channel count from
    ``tensor.size(2)``. The data is regrouped as (batch, height, width,
    channels) and the channel axis is then moved to position 1, giving
    (batch, channels, height, width). For a 3-D input this requires
    ``tensor.size(1) == height * width``; otherwise ``reshape`` raises.

    The axis reorder returns a permuted view, not a contiguous copy.

    NOTE(review): presumably used as a CAM-style reshape hook for transformer
    backbones (hence the 7x7 default) — confirm against the caller.
    """
    batch = tensor.size(0)
    channels = tensor.size(2)
    grid = tensor.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)
|
|