sethsuk's picture
Updated Model Inference Information (#2)
73599b7 verified

Training Dataset Means and Standard Deviations

# Mean and standard deviation of the training-set latitude/longitude labels.
# The inference code uses them to map model outputs back to degrees.
lat_mean = 39.951572994535354
lat_std = 0.0006556104083785816
lon_mean = -75.19137012508818
lon_std = 0.0006895844560639971

Custom Model Class

from torch import nn
from transformers import ViTModel


class ViTGPSModel(nn.Module):
    """Vision Transformer backbone with a linear regression head.

    The head maps the backbone's [CLS] embedding to ``output_size`` values
    (two by default; the inference code in this document de-normalizes them
    as latitude and longitude).
    """

    def __init__(self, output_size=2):
        """Load the pretrained ViT backbone and create the regression head.

        Args:
            output_size: Number of values the regression head predicts.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.regression_head = nn.Linear(self.vit.config.hidden_size, output_size)

    def forward(self, x):
        """Return the regression output for a batch of images ``x``."""
        # Position 0 of the last hidden state is the [CLS] token, used here as
        # the summary embedding of the whole image.
        cls_embedding = self.vit(x).last_hidden_state[:, 0, :]
        return self.regression_head(cls_embedding)

Running Inference

import torch
from huggingface_hub import hf_hub_download

# NOTE: `device` (a torch.device) and `dataloader` (yielding batches of image
# tensors) are expected to be defined before this point, as are the
# lat/lon mean and std constants listed earlier in this document.
model_path = hf_hub_download(
    repo_id="Latitude-Attitude/vit-gps-coordinates-predictor",
    filename="vit-gps-coordinates-predictor.pth",
)

# The checkpoint is used directly as a model object, i.e. it is a pickled
# module rather than a state dict, so:
#   * the ViTGPSModel class presumably has to be defined/importable before
#     loading (confirm against how the checkpoint was saved);
#   * weights_only=False is required on PyTorch >= 2.6, where the default
#     became True and rejects arbitrary pickled classes;
#   * map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints
# from a source you trust.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)  # keep the model on the same device as the input batches
model.eval()

# Build the de-normalization tensors once instead of on every batch.
label_std = torch.tensor([lat_std, lon_std])
label_mean = torch.tensor([lat_mean, lon_mean])

batch_preds = []
with torch.no_grad():
    for images in dataloader:
        images = images.to(device)
        outputs = model(images)
        # Undo the label normalization: degrees = output * std + mean.
        preds = outputs.cpu() * label_std + label_mean
        batch_preds.append(preds)

# All predictions as one (num_images, 2) tensor of [latitude, longitude];
# `preds` above still holds only the final batch.
all_preds = torch.cat(batch_preds) if batch_preds else torch.empty(0, 2)