File size: 570 Bytes
12f775f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from pathlib import Path

import torch
from torchvision.models import EfficientNet_B2_Weights, ViT_B_16_Weights


def get_vit_16_base_transformer():
    """Load the fine-tuned ViT-B/16 model (101-class custom head) and its
    matching inference transforms.

    Returns:
        tuple: ``(model, transforms)`` — the deserialized torch model and the
        torchvision preprocessing pipeline from ``ViT_B_16_Weights.DEFAULT``.
    """
    # Build the path with pathlib so it resolves on both Windows and POSIX
    # (the previous raw backslash string r"models\..." broke on Linux/macOS).
    checkpoint = Path("models") / "ViT_16_base_101_classes_pretrained_custom_head.pth"
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints. Device placement uses torch's default; confirm whether a
    # map_location (e.g. "cpu") is needed for CPU-only deployments.
    vit_b_16_model = torch.load(checkpoint)
    vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()

    return vit_b_16_model, vit_b_16_transforms

def get_effnet_b2():
    """Load the fine-tuned EfficientNet-B2 model (101-class custom head) and
    its matching inference transforms.

    Returns:
        tuple: ``(model, transforms)`` — the deserialized torch model and the
        torchvision preprocessing pipeline from ``EfficientNet_B2_Weights.DEFAULT``.
    """
    # Build the path with pathlib so it resolves on both Windows and POSIX
    # (the previous raw backslash string r"models\..." broke on Linux/macOS).
    checkpoint = Path("models") / "effnet_b2_101_classes_pretrained_custom_head.pth"
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints. Device placement uses torch's default; confirm whether a
    # map_location (e.g. "cpu") is needed for CPU-only deployments.
    eff_net_b2_model = torch.load(checkpoint)
    eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()

    return eff_net_b2_model, eff_net_b2_transforms