Spaces:
Sleeping
Sleeping
Sirreajohn
commited on
Commit
·
5d3b36b
1
Parent(s):
d09368b
updated models to be loaded on CPU by default
Browse files
models.py
CHANGED
@@ -5,13 +5,13 @@ from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
|
|
5 |
|
6 |
|
7 |
def get_vit_16_base_transformer():
|
8 |
-
vit_b_16_model = torch.load("models/vit_16_base_custom_head_3_classes.pth")
|
9 |
vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
|
10 |
|
11 |
return vit_b_16_model, vit_b_16_transforms
|
12 |
|
13 |
def get_effnet_b2():
|
14 |
-
eff_net_b2_model = torch.load("models/eff_netb2_custom_head_3_classes.pth")
|
15 |
eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
|
16 |
|
17 |
return eff_net_b2_model, eff_net_b2_transforms
|
|
|
5 |
|
6 |
|
7 |
def get_vit_16_base_transformer():
|
8 |
+
vit_b_16_model = torch.load("models/vit_16_base_custom_head_3_classes.pth", map_location = torch.device('cpu'))
|
9 |
vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
|
10 |
|
11 |
return vit_b_16_model, vit_b_16_transforms
|
12 |
|
13 |
def get_effnet_b2():
|
14 |
+
eff_net_b2_model = torch.load("models/eff_netb2_custom_head_3_classes.pth", map_location = torch.device('cpu'))
|
15 |
eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
|
16 |
|
17 |
return eff_net_b2_model, eff_net_b2_transforms
|