ezelpub commited on
Commit
cceb9a3
·
verified ·
1 Parent(s): 4ab7786

Updated the convNext-GPSPredictor model to the Huggingface

Browse files
Files changed (1) hide show
  1. README.md +51 -9
README.md CHANGED
@@ -1,9 +1,51 @@
1
- ---
2
- tags:
3
- - model_hub_mixin
4
- - pytorch_model_hub_mixin
5
- ---
6
-
7
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
- - Library: [More Information Needed]
9
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Information about the Dataset
2
+
3
+ **Mean Latitude**: 39.95156391970743
4
+ **Latitude Std**: 0.0007633062105681285
5
+ **Mean Longitude**: -75.19148737056214j
6
+ **Longitude Std**: 0.0007871346840888362
7
+
8
+ # Model definition
9
+
10
+ ```python
11
+ class ConvNeXtGPSPredictor(nn.Module, PyTorchModelHubMixin):
12
+ def __init__(self, model_name="facebook/convnext-tiny-224", num_outputs=2):
13
+ super(ConvNeXtGPSPredictor, self).__init__()
14
+
15
+ # Load the ConvNeXt backbone from Hugging Face
16
+ self.backbone = AutoModel.from_pretrained(model_name)
17
+
18
+ # Get feature dimension from the backbone's output
19
+ config = AutoConfig.from_pretrained(model_name)
20
+ feature_dim = config.hidden_sizes[-1] # Corrected attribute for ConvNeXt
21
+
22
+ # Define the GPS regression head
23
+ self.gps_head = nn.Sequential(
24
+ nn.AdaptiveAvgPool2d((1, 1)), # Pool to a single spatial dimension
25
+ nn.Flatten(), # Flatten the tensor
26
+ nn.LayerNorm(feature_dim),
27
+ nn.Linear(feature_dim, num_outputs) # Directly map to 2 GPS coordinates
28
+ )
29
+
30
+ def forward(self, x):
31
+ # Extract features from the backbone
32
+ features = self.backbone(x).last_hidden_state
33
+
34
+ # Pass through the GPS head
35
+ coords = self.gps_head(features)
36
+ return coords
37
+
38
+
39
+ def save_model(self, save_path):
40
+ self.save_pretrained(save_path)
41
+
42
+ def push_model(self, repo_name):
43
+ self.push_to_hub(repo_name)
44
+ ```
45
+
46
+ # How to load the model
47
+
48
+ You can simply load the model by
49
+ ```python
50
+ model = ConvNeXtGPSPredictor.from_pretrained("cis519/convNext-GPSPredictor")
51
+ ```