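
### Training-Set Means and Standard Deviations

The model outputs normalized coordinates. Multiply them by these standard deviations and add these means to recover latitude and longitude.

```python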
# Training-set statistics used to recover latitude/longitude from the
# model's normalized outputs: coordinate = output * std + mean.
lat_mean, lat_std = 39.951572994535354, 0.0006556104083785816
lon_mean, lon_std = -75.19137012508818, 0.0006895844560639971
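```

### Custom Model Class

The published checkpoint stores the full pickled model rather than a `state_dict`, so this class must be defined before the checkpoint is loaded.

```python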
import torch
from torch import nn
from transformers import ViTModel


class ViTGPSModel(nn.Module):
    """Regress coordinates from an image with a pretrained ViT encoder.

    The first token (the [CLS] embedding) of the encoder's final hidden layer
    is passed through a single linear layer.

    Args:
        output_size: Number of regression outputs. Defaults to 2, i.e. the
            normalized (latitude, longitude) pair.
    """

    def __init__(self, output_size: int = 2) -> None:
        super().__init__()
        backbone = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit = backbone
        # The head's input width follows the backbone's embedding size.
        self.regression_head = nn.Linear(backbone.config.hidden_size, output_size)

    def forward(self, x):
        """Return a ``(batch, output_size)`` tensor of predictions.

        ``x`` is forwarded unchanged to the ViT backbone; presumably a batch
        of preprocessed pixel values (batch, channels, height, width) — confirm.
        """
        hidden_states = self.vit(x).last_hidden_state
        cls_token = hidden_states[:, 0, :]  # first token of every sequence
        return self.regression_head(cls_token)
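```

### Running Inference

This snippet assumes `device` and a `dataloader` that yields batches of preprocessed image tensors are already defined.

```python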
# `hf_hub_download` is provided by the `huggingface_hub` package
# (`from huggingface_hub import hf_hub_download`); `device` and `dataloader`
# must already be defined.
model_path = hf_hub_download(
    repo_id="Latitude-Attitude/vit-gps-coordinates-predictor",
    filename="vit-gps-coordinates-predictor.pth",
)

# The checkpoint is a whole pickled model (not a state_dict), so
# `weights_only=False` is required on PyTorch >= 2.6, where it defaults to True.
# Unpickling can execute arbitrary code -- only load checkpoints you trust.
# `map_location` puts the weights on the same device as the input batches and
# lets a GPU-saved checkpoint load on a CPU-only machine.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()

# De-normalization constants, built once rather than on every batch.
coord_std = torch.tensor([lat_std, lon_std])
coord_mean = torch.tensor([lat_mean, lon_mean])

batch_preds = []
with torch.no_grad():
    for images in dataloader:  # assumes the loader yields image tensors only -- confirm
        images = images.to(device)
        outputs = model(images)
        batch_preds.append(outputs.cpu() * coord_std + coord_mean)

# One [latitude, longitude] row per image across all batches, instead of only
# the last batch as before.
preds = torch.cat(batch_preds)
```