import timm
import torch.nn as nn


class Build_Custom_Model(nn.Module):
    """Wrap a timm backbone and replace its classifier with a fresh linear head.

    The backbone is created with ``in_chans=1`` (single-channel input). Its
    final linear classifier is then swapped for a new ``nn.Linear`` that
    produces ``target_size`` outputs.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        target_size: number of outputs of the new classifier.
        pretrained: forwarded to ``timm.create_model``.

    Attributes:
        model: the wrapped timm model with its classifier replaced.
        n_features: input width of the replaced classifier.
    """

    # Attribute path (relative to the timm model) of the final nn.Linear
    # classifier for each explicitly supported architecture. Replacing the
    # original if-chain with this table keeps behaviour identical for these
    # names while removing the duplicated branches.
    _HEAD_PATHS = {
        "vit_base_patch16_224": ("head",),
        "swin_base_patch4_window7_224": ("head",),
        "resnet34d": ("fc",),
        "resnet18d": ("fc",),
        "tf_efficientnet_b7_ns": ("classifier",),
        "tf_efficientnet_b0_ns": ("classifier",),
        "tf_efficientnet_lite0": ("classifier",),
        "mobilenetv2_050": ("classifier",),
        "eca_nfnet_l0": ("head", "fc"),
    }

    def __init__(self, model_name, target_size, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)

        head_path = self._HEAD_PATHS.get(model_name)
        if head_path is None:
            # Previously an unlisted name silently kept timm's default head
            # (ignoring target_size) and never set n_features. Fall back to
            # timm's generic model interface instead.
            # NOTE(review): relies on the timm model exposing `num_features`
            # and `reset_classifier` — confirm for any new architecture.
            self.n_features = self.model.num_features
            self.model.reset_classifier(target_size)
            return

        # Walk to the module that owns the classifier attribute, then swap it.
        *parent_names, head_attr = head_path
        parent = self.model
        for name in parent_names:
            parent = getattr(parent, name)
        self.n_features = getattr(parent, head_attr).in_features
        # nn.Module.__setattr__ registers the new submodule exactly as plain
        # attribute assignment would.
        setattr(parent, head_attr, nn.Linear(self.n_features, target_size))

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)


def reshape_transform(tensor, height=7, width=7):
    """Reshape a token sequence into a channels-first feature map.

    The input must be rank 3, and dimension 1 must hold exactly
    ``height * width`` tokens. Presumably the layout is
    (batch, tokens, channels) with no class token; TODO confirm against
    callers. The 7x7 defaults appear intended for a specific backbone's
    final-stage grid; verify before reuse.

    Returns:
        A tensor of shape (batch, channels, height, width).
    """
    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    # Bring the channels to the first (non-batch) dimension, like in CNNs.
    return result.permute(0, 3, 1, 2)