rgny commited on
Commit
3a90c4a
1 Parent(s): 4f15341

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +14 -14
model.py CHANGED
@@ -1,14 +1,14 @@
1
- from torch import nn
2
- import torchvision
3
- import torch
4
- def create_model(num_classes:int=3)
5
- weights = torchvision.models.EfficientNet_B2_Weights.IMAGENET1K_V1
6
- model=torchvision.models.efficientnet_b2(weights=weights)
7
- transform=weights.transforms()
8
-
9
- for param in model.parameters():
10
- param.requires_grad=False
11
-
12
- model.classifier=nn.Sequential(nn.Dropout(p=0.3, inplace=True),
13
- nn.Linear(in_features=1408, out_features=num_classes, bias=True))
14
- return model,transform
 
1
+ from torch import nn
2
+ import torchvision
3
+ import torch
4
+ def create_model(num_classes:int=3):
5
+ weights = torchvision.models.EfficientNet_B2_Weights.IMAGENET1K_V1
6
+ model=torchvision.models.efficientnet_b2(weights=weights)
7
+ transform=weights.transforms()
8
+
9
+ for param in model.parameters():
10
+ param.requires_grad=False
11
+
12
+ model.classifier=nn.Sequential(nn.Dropout(p=0.3, inplace=True),
13
+ nn.Linear(in_features=1408, out_features=num_classes, bias=True))
14
+ return model,transform