import torch
import torchvision
from torch import nn


def create_effnet_b2_instance(num_classes: int = 3, *, dropout: float = 0.3):
    """Build an ImageNet-pretrained EfficientNet-B2 with a fresh classifier head.

    All pretrained parameters are frozen; only the newly created classifier
    head (dropout + linear layer) is trainable.

    Args:
        num_classes: Number of output classes for the new linear layer.
            Must be >= 1.
        dropout: Dropout probability applied before the new linear layer.
            Defaults to 0.3, the value used originally.

    Returns:
        A ``(transforms, model)`` tuple: the preprocessing transforms that
        belong to the pretrained weights, and the modified model.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Validate before building the model so a bad argument does not cost a
    # weight download first.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")

    # Fetch the model's pretrained weights.
    effnetb2_pretrained_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT

    # Fetch the preprocessing transforms that match those weights.
    effnetb2_transforms = effnetb2_pretrained_weights.transforms()

    # Build the model and load the pretrained weights.
    effnetb2 = torchvision.models.efficientnet_b2(weights=effnetb2_pretrained_weights)

    # Freeze every pretrained parameter; the head created below is new and
    # therefore remains trainable.
    for param in effnetb2.parameters():
        param.requires_grad = False

    # Read the head's input width from the existing linear layer instead of
    # hard-coding it, so the new head always matches the backbone.
    in_features = effnetb2.classifier[1].in_features

    # Replace the classification head with one sized for num_classes.
    effnetb2.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    return effnetb2_transforms, effnetb2