File size: 541 Bytes
3a90c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch import nn
import torchvision
import torch
def create_model(num_classes: int = 3, dropout: float = 0.3):
    """Build an ImageNet-pretrained EfficientNet-B2 with a fresh classifier head.

    Every pretrained parameter is frozen (``requires_grad=False``). The
    classifier head is replaced after freezing, so only the new head's
    parameters are trainable.

    Args:
        num_classes: Number of output classes for the new head. Must be >= 1.
        dropout: Dropout probability applied before the final linear layer.
            Defaults to 0.3, the value that was previously hard-coded.

    Returns:
        A ``(model, transform)`` tuple. ``transform`` is the preprocessing
        pipeline returned by ``IMAGENET1K_V1.transforms()``. It is intended
        to be applied to inputs before they are fed to ``model``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Validate before loading weights so a bad argument fails fast.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.EfficientNet_B2_Weights.IMAGENET1K_V1
    # NOTE: torchvision may download the weight file here if it is not cached.
    model = torchvision.models.efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the pretrained network; the head created below is left trainable
    # because freshly constructed modules default to requires_grad=True.
    for param in model.parameters():
        param.requires_grad = False

    # Read the feature width from the existing head's final Linear layer
    # instead of hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model, transform