Sirreajohn commited on
Commit
5d3b36b
·
1 Parent(s): d09368b

updated models to be loaded on CPU by default

Browse files
Files changed (1) hide show
  1. models.py +2 -2
models.py CHANGED
@@ -5,13 +5,13 @@ from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
5
 
6
 
7
  def get_vit_16_base_transformer():
8
- vit_b_16_model = torch.load("models/vit_16_base_custom_head_3_classes.pth")
9
  vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
10
 
11
  return vit_b_16_model, vit_b_16_transforms
12
 
13
  def get_effnet_b2():
14
- eff_net_b2_model = torch.load("models/eff_netb2_custom_head_3_classes.pth")
15
  eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
16
 
17
  return eff_net_b2_model, eff_net_b2_transforms
 
5
 
6
 
7
  def get_vit_16_base_transformer():
8
+ vit_b_16_model = torch.load("models/vit_16_base_custom_head_3_classes.pth", map_location = torch.device('cpu'))
9
  vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
10
 
11
  return vit_b_16_model, vit_b_16_transforms
12
 
13
  def get_effnet_b2():
14
+ eff_net_b2_model = torch.load("models/eff_netb2_custom_head_3_classes.pth", map_location = torch.device('cpu'))
15
  eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
16
 
17
  return eff_net_b2_model, eff_net_b2_transforms