FrancoMango commited on
Commit
50d1f36
·
verified ·
1 Parent(s): 7d2f136

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +32 -6
README.md CHANGED
@@ -1,13 +1,39 @@
1
  <h1>Train Dataset Means and stds</h1>
2
- lat_mean = 39.95157130295544
3
 
4
- lat_std = 0.0006593704228342234
 
 
 
 
 
5
 
6
- lon_mean = -75.19136178838008
7
 
8
- lon_std = 0.0006865423903444358
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  <h1>Running Inference</h1>
11
- model_path = hf_hub_download(repo_id="Latitude-Attitude/vit-gps-coordinates-predictor-with-filter", filename="vit-gps-coordinates-predictor-with-filter-3.pth") model = torch.load(model_path) model.eval()
12
 
13
- with torch.no_grad(): for images in dataloader: images = images.to(device) outputs = model(images) preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
 
 
 
 
 
 
 
 
 
 
 
 
1
  <h1>Train Dataset Means and stds</h1>
 
2
 
3
+ ```
4
+ lat_mean = 39.951648580775
5
+ lat_std = 0.0006491166433423773
6
+ lon_mean = -75.19144282374886
7
+ lon_std = 0.0006635364490202568
8
+ ```
9
 
10
+ <h1>Custom Class Model</h1>
11
 
12
+ ```
13
+ from transformers import ViTModel
14
+ class ViTGPSModel(nn.Module):
15
+ def __init__(self, output_size=2):
16
+ super().__init__()
17
+ self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
18
+ self.regression_head = nn.Linear(self.vit.config.hidden_size, output_size)
19
+
20
+ def forward(self, x):
21
+ cls_embedding = self.vit(x).last_hidden_state[:, 0, :]
22
+ return self.regression_head(cls_embedding)
23
+
24
+ ```
25
 
26
  <h1>Running Inference</h1>
 
27
 
28
+ ```
29
+ model_path = hf_hub_download(repo_id="Latitude-Attitude/vit-gps-coordinates-predictor", filename="vit-gps-coordinates-predictor.pth")
30
+ model = torch.load(model_path)
31
+ model.eval()
32
+
33
+ with torch.no_grad():
34
+ for images in dataloader:
35
+ images = images.to(device)
36
+ outputs = model(images)
37
+ preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
38
+
39
+ ```