<h1>Train Dataset Means and Standard Deviations</h1>
```
lat_mean = 39.951648580775
lat_std = 0.0006491166433423773
lon_mean = -75.19144282374886
lon_std = 0.0006635364490202568
```
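These statistics can be used to z-score the coordinate targets for training and to invert that normalization at inference. A minimal sketch (the helper names are illustrative, not from the repository):

```
import torch

# Statistics from the section above.
lat_mean, lat_std = 39.951648580775, 0.0006491166433423773
lon_mean, lon_std = -75.19144282374886, 0.0006635364490202568

MEANS = torch.tensor([lat_mean, lon_mean])
STDS = torch.tensor([lat_std, lon_std])

def normalize(coords: torch.Tensor) -> torch.Tensor:
    """Map raw (lat, lon) pairs to z-scored training targets."""
    return (coords - MEANS) / STDS

def denormalize(preds: torch.Tensor) -> torch.Tensor:
    """Map model outputs back to raw (lat, lon) degrees."""
    return preds * STDS + MEANS

coords = torch.tensor([[39.9517, -75.1915]])
restored = denormalize(normalize(coords))  # round-trips back to coords
```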
<h1>Custom Class Model</h1>
```
import torch.nn as nn
from transformers import ViTModel

class ViTGPSModel(nn.Module):
    def __init__(self, output_size=2):
        super().__init__()
        # Pretrained ViT backbone; a linear head regresses normalized (lat, lon).
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.regression_head = nn.Linear(self.vit.config.hidden_size, output_size)

    def forward(self, x):
        # Use the [CLS] token embedding as the image representation.
        cls_embedding = self.vit(x).last_hidden_state[:, 0, :]
        return self.regression_head(cls_embedding)
```
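To see what the forward pass does without downloading weights, the head simply maps the [CLS] token of the backbone's `last_hidden_state` to two outputs. A dummy-tensor sketch (the hidden size and sequence length are the standard ViT-base values, assumed here):

```
import torch
import torch.nn as nn

hidden_size = 768                          # ViT-base hidden size
hidden = torch.randn(4, 197, hidden_size)  # simulated last_hidden_state: 196 patches + [CLS]

cls_embedding = hidden[:, 0, :]            # take the [CLS] token -> shape (4, 768)
regression_head = nn.Linear(hidden_size, 2)
preds = regression_head(cls_embedding)     # shape (4, 2): normalized (lat, lon)
```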
<h1>Running Inference</h1>
```
import torch
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="Latitude-Attitude/vit-gps-coordinates-predictor-with-filter",
    filename="vit-gps-coordinates-predictor-with-filter-6.pth",
)

# The checkpoint stores the full model object, so torch.load returns a ViTGPSModel.
# (On PyTorch >= 2.6, pass weights_only=False to allow full-model unpickling.)
model = torch.load(model_path, map_location=device)
model.eval()

with torch.no_grad():
    for images in dataloader:
        images = images.to(device)
        outputs = model(images)
        # Undo the z-score normalization to recover raw (lat, lon) degrees.
        preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
```