# Image to GPS Model: EfficientNet

## Training Data Statistics
The following mean and standard deviation values were used to normalize the GPS coordinates:

- **Latitude Mean**: 39.95156391970743
- **Latitude Std**: 0.0007633062105681285
- **Longitude Mean**: -75.19148737056214
- **Longitude Std**: 0.0007871346840888362
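These statistics suggest a per-axis z-score, `z = (value - mean) / std`, so the model's raw outputs have to be mapped back to degrees. The sketch below assumes exactly that normalization and a `[latitude, longitude]` output order; the helper names `normalize_gps` and `denormalize_gps` are hypothetical and not part of the released code.

```python
# Training statistics copied from the model card above.
LAT_MEAN = 39.95156391970743
LAT_STD = 0.0007633062105681285
LON_MEAN = -75.19148737056214
LON_STD = 0.0007871346840888362


def normalize_gps(lat: float, lon: float) -> tuple[float, float]:
    """Map raw coordinates in degrees to z-scores (assumed training target format)."""
    return (lat - LAT_MEAN) / LAT_STD, (lon - LON_MEAN) / LON_STD


def denormalize_gps(lat_z: float, lon_z: float) -> tuple[float, float]:
    """Invert normalize_gps: turn normalized model outputs back into degrees."""
    return lat_z * LAT_STD + LAT_MEAN, lon_z * LON_STD + LON_MEAN


# Example: a prediction one standard deviation north and east of the mean location.
print(denormalize_gps(1.0, 1.0))
```

Because the standard deviations are under 0.001 degrees, a normalized error of 1.0 corresponds to roughly 85 m in latitude at this location.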

## How to use the model

Define the model class before loading the checkpoint:

```python
# Import the dependencies
import torch.nn as nn
import torchvision.models as models
from huggingface_hub import PyTorchModelHubMixin


class EfficientNetGPSModel(nn.Module, PyTorchModelHubMixin):
    """EfficientNet backbone whose head regresses normalized (latitude, longitude)."""

    def __init__(self, eff_name="efficientnet_b0", num_outputs=2):
        super().__init__()

        # Load an ImageNet-pretrained torchvision EfficientNet (requires torchvision >= 0.13).
        # from_pretrained later overwrites these weights with the fine-tuned checkpoint.
        self.efficientnet = getattr(models, eff_name)(weights="DEFAULT")

        # Replace the default (Dropout, Linear) classifier with a single linear layer
        # that maps the pooled features directly to the GPS outputs.
        in_features = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Sequential(
            nn.Linear(in_features, num_outputs)
        )

    def forward(self, x):
        # x: image batch of shape [N, 3, H, W]; returns [N, num_outputs]
        return self.efficientnet(x)

    def save_model(self, save_path):
        # Save weights and config locally via PyTorchModelHubMixin.save_pretrained
        self.save_pretrained(save_path)

    def push_model(self, repo_name):
        # Upload weights and config to the Hugging Face Hub
        self.push_to_hub(repo_name)
```

Then download the model from the Hugging Face Hub; `from_pretrained` also loads the checkpoint weights automatically:

```python
model = EfficientNetGPSModel.from_pretrained("cis519/efficient-Net")
```