File size: 671 Bytes
9d6c5d2
 
 
 
 
 
 
5d3b36b
9d6c5d2
 
 
 
 
5d3b36b
9d6c5d2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch

from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights


def get_vit_16_base_transformer():
    """Load the fine-tuned ViT-B/16 model (3-class custom head) and its preprocessing.

    Returns:
        tuple: (model, transforms) where `model` is the full pickled module
        loaded onto CPU and `transforms` is the default inference
        preprocessing pipeline for ViT_B_16 weights.
    """
    # weights_only=False is required because the checkpoint is a full pickled
    # model, not a state_dict; torch >= 2.6 defaults to weights_only=True and
    # would otherwise fail. NOTE: only safe for trusted, locally-produced files.
    vit_b_16_model = torch.load(
        "models/vit_16_base_custom_head_3_classes.pth",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()

    return vit_b_16_model, vit_b_16_transforms

def get_effnet_b2():
    """Load the fine-tuned EfficientNet-B2 model (3-class custom head) and its preprocessing.

    Returns:
        tuple: (model, transforms) where `model` is the full pickled module
        loaded onto CPU and `transforms` is the default inference
        preprocessing pipeline for EfficientNet_B2 weights.
    """
    # weights_only=False is required because the checkpoint is a full pickled
    # model, not a state_dict; torch >= 2.6 defaults to weights_only=True and
    # would otherwise fail. NOTE: only safe for trusted, locally-produced files.
    eff_net_b2_model = torch.load(
        "models/eff_netb2_custom_head_3_classes.pth",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()

    return eff_net_b2_model, eff_net_b2_transforms