areebalam commited on
Commit
715ab45
·
verified ·
1 Parent(s): 27672e7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +25 -0
README.md CHANGED
@@ -29,3 +29,28 @@ state_dict = torch.load("resnet_gps_model.pth")
29
  resnet.load_state_dict(state_dict)
30
  resnet.eval()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  resnet.load_state_dict(state_dict)
30
  resnet.eval()
31
 
32
+ This is our customresnetmodel
33
+
34
+ class CustomResNetModel(nn.Module):
35
+ def __init__(self, model_name="microsoft/resnet-18", num_classes=2):
36
+ super(CustomResNetModel, self).__init__()
37
+ # Load pre-trained ResNet from Hugging Face
38
+ self.resnet = AutoModelForImageClassification.from_pretrained(model_name)
39
+
40
+ # Adjust the classifier layer to output the desired number of classes
41
+ in_features = self.resnet.classifier[0].in_features # Assuming the last layer is a Linear layer
42
+ self.resnet.classifier = nn.Sequential(
43
+ nn.Flatten(),
44
+ nn.Linear(in_features, num_classes)
45
+ )
46
+
47
+ def forward(self, x):
48
+ return self.resnet(x)
49
+
50
+ def save_model(self, save_path):
51
+ """Save model locally using the Hugging Face format."""
52
+ self.save_pretrained(save_path)
53
+
54
+ def push_model(self, repo_name):
55
+ """Push the model to the Hugging Face Hub."""
56
+ self.push_to_hub(repo_name)