Sirreajohn commited on
Commit
c442067
·
1 Parent(s): ed99578

fixed map location issue in models.py

Browse files
Files changed (1) hide show
  1. models.py +2 -2
models.py CHANGED
@@ -3,13 +3,13 @@ from torchvision.models import ViT_B_16_Weights, EfficientNet_B2_Weights
3
 
4
 
5
  def get_vit_16_base_transformer():
6
- vit_b_16_model = torch.load(r"models\ViT_16_base_101_classes_pretrained_custom_head.pth")
7
  vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
8
 
9
  return vit_b_16_model, vit_b_16_transforms
10
 
11
  def get_effnet_b2():
12
- eff_net_b2_model = torch.load(r"models\effnet_b2_101_classes_pretrained_custom_head.pth")
13
  eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
14
 
15
  return eff_net_b2_model, eff_net_b2_transforms
 
3
 
4
 
5
  def get_vit_16_base_transformer():
6
+ vit_b_16_model = torch.load("models/ViT_16_base_101_classes_pretrained_custom_head.pth", map_location = torch.device("cpu"))
7
  vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
8
 
9
  return vit_b_16_model, vit_b_16_transforms
10
 
11
  def get_effnet_b2():
12
+ eff_net_b2_model = torch.load("models/effnet_b2_101_classes_pretrained_custom_head.pth", map_location = torch.device("cpu"))
13
  eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
14
 
15
  return eff_net_b2_model, eff_net_b2_transforms