File size: 1,068 Bytes
73599b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
### Train Dataset Means and stds
# Latitude/longitude statistics of the training set; the inference code
# multiplies model outputs by the std and adds the mean.
lat_mean, lat_std = 39.951572994535354, 0.0006556104083785816
lon_mean, lon_std = -75.19137012508818, 0.0006895844560639971

### Custom Model Class
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from transformers import ViTModel
class ViTGPSModel(nn.Module):
    """ViT backbone followed by a single linear regression layer.

    The pretrained ``google/vit-base-patch16-224-in21k`` weights are loaded
    when the model is constructed, and a linear layer maps the backbone's
    hidden size to ``output_size`` values.
    """

    def __init__(self, output_size=2):
        """Build the backbone and the regression head.

        Args:
            output_size: Number of values produced per input (defaults to 2).
        """
        super().__init__()
        backbone = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        head = nn.Linear(backbone.config.hidden_size, output_size)
        # Attribute names are kept as-is: saved checkpoints refer to them.
        self.vit = backbone
        self.regression_head = head

    def forward(self, x):
        """Regress from the first token of the backbone's final hidden state.

        Args:
            x: Input passed unchanged to the ViT backbone.

        Returns:
            The regression head's output for the token at sequence index 0.
        """
        hidden_states = self.vit(x).last_hidden_state
        # Position 0 along dim 1 is the [CLS] token embedding.
        first_token = hidden_states[:, 0, :]
        return self.regression_head(first_token)


### Running Inference

# Fetch the checkpoint file from the Hugging Face Hub; the call returns a local path.
model_path = hf_hub_download(repo_id="Latitude-Attitude/vit-gps-coordinates-predictor", filename="vit-gps-coordinates-predictor.pth")

# The checkpoint is used directly as the model object, i.e. it is a fully
# pickled module rather than a state dict. PyTorch >= 2.6 defaults to
# weights_only=True, which rejects such files, so opt out explicitly.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints
# from a source you trust.
# map_location lets a checkpoint saved from one device load onto `device`.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)  # the inputs are moved to `device` below; keep the model there too
model.eval()

# The de-normalization constants do not change between batches: build them once.
# NOTE(review): these are float32 tensors, which resolve only ~1e-5 at these
# magnitudes; switch to float64 if finer coordinate precision is required.
coord_std = torch.tensor([lat_std, lon_std])
coord_mean = torch.tensor([lat_mean, lon_mean])

batch_preds = []
with torch.no_grad():
    for images in dataloader:
        images = images.to(device)
        outputs = model(images)
        # Map the normalized outputs back to latitude/longitude using the
        # training-set statistics.
        preds = outputs.cpu() * coord_std + coord_mean
        # Keep every batch; `preds` alone only holds the most recent one.
        batch_preds.append(preds)

# Predictions for the whole dataloader, in iteration order
# (None if the dataloader yielded no batches).
all_preds = torch.cat(batch_preds) if batch_preds else None