Update README.md
Browse files
README.md
CHANGED
@@ -29,3 +29,28 @@ state_dict = torch.load("resnet_gps_model.pth")
|
|
29 |
resnet.load_state_dict(state_dict)
|
30 |
resnet.eval()
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
resnet.load_state_dict(state_dict)
|
30 |
resnet.eval()
|
31 |
|
32 |
+
This is our customresnetmodel
|
33 |
+
|
34 |
+
class CustomResNetModel(nn.Module):
|
35 |
+
def __init__(self, model_name="microsoft/resnet-18", num_classes=2):
|
36 |
+
super(CustomResNetModel, self).__init__()
|
37 |
+
# Load pre-trained ResNet from Hugging Face
|
38 |
+
self.resnet = AutoModelForImageClassification.from_pretrained(model_name)
|
39 |
+
|
40 |
+
# Adjust the classifier layer to output the desired number of classes
|
41 |
+
in_features = self.resnet.classifier[0].in_features # Assuming the last layer is a Linear layer
|
42 |
+
self.resnet.classifier = nn.Sequential(
|
43 |
+
nn.Flatten(),
|
44 |
+
nn.Linear(in_features, num_classes)
|
45 |
+
)
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return self.resnet(x)
|
49 |
+
|
50 |
+
def save_model(self, save_path):
|
51 |
+
"""Save model locally using the Hugging Face format."""
|
52 |
+
self.save_pretrained(save_path)
|
53 |
+
|
54 |
+
def push_model(self, repo_name):
|
55 |
+
"""Push the model to the Hugging Face Hub."""
|
56 |
+
self.push_to_hub(repo_name)
|