sqiud committed on
Commit
754c234
·
verified ·
1 Parent(s): a040736

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +24 -76
README.md CHANGED
@@ -1,91 +1,16 @@
1
- ---
2
- title: README
3
- emoji: 🐢
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: static
7
- pinned: false
8
- ---
9
-
10
  Dataset stats: \
11
  lat_mean = 39.951564548022596 \
12
  lat_std = 0.0006361722351128644 \
13
  lon_mean = -75.19150880602636 \
14
  lon_std = 0.000611411894337979
15
 
16
- The model implementation is found here:
17
- ```
18
- import torch
19
- import torch.nn as nn
20
- import torchvision.models as models
21
- import torchvision.transforms as transforms
22
- from torch.utils.data import DataLoader, Dataset
23
- from transformers import AutoImageProcessor, AutoModelForImageClassification
24
- from huggingface_hub import PyTorchModelHubMixin
25
- from PIL import Image
26
- import os
27
- import numpy as np
28
- from transformers import AutoModel
29
-
30
- class MultiModalModel(nn.Module):
31
- def __init__(self, image_model_name="google/vit-base-patch16-224", num_gps_features=2, output_dim=2):
32
- super(MultiModalModel, self).__init__()
33
- # Load Vision Transformer for feature extraction
34
- self.image_model = AutoModel.from_pretrained(image_model_name, output_hidden_states=True)
35
-
36
- # Reduce image features to match GPS features
37
- self.image_fc = nn.Sequential(
38
- nn.Linear(self.image_model.config.hidden_size, 256),
39
- nn.ReLU(),
40
- )
41
-
42
- # Process GPS features
43
- self.gps_fc = nn.Sequential(
44
- nn.Linear(num_gps_features, 128),
45
- nn.ReLU(),
46
- nn.Dropout(0.3),
47
- nn.Linear(128, 256),
48
- )
49
-
50
- # Combine image and GPS features for regression
51
- self.regressor = nn.Sequential(
52
- nn.Linear(256 + 256, 512), # 256 from image + 256 from GPS
53
- nn.ReLU(),
54
- nn.Dropout(0.4),
55
- nn.Linear(512, output_dim),
56
- )
57
-
58
- def forward(self, image, gps):
59
- # Extract image features from the last hidden state
60
- image_outputs = self.image_model(image)
61
- image_features = image_outputs.last_hidden_state[:, 0, :] # CLS token features
62
- image_features = self.image_fc(image_features)
63
-
64
- # Process GPS features
65
- gps_features = self.gps_fc(gps)
66
-
67
- # Concatenate image and GPS features
68
- combined_features = torch.cat([image_features, gps_features], dim=1)
69
-
70
- # Final regression
71
- return self.regressor(combined_features)
72
-
73
- def save_model(self, save_path):
74
- """Save model locally using the Hugging Face format."""
75
- self.save_pretrained(save_path)
76
-
77
- def push_model(self, repo_name):
78
- """Push the model to the Hugging Face Hub."""
79
- self.push_to_hub(repo_name)
80
- ```
81
-
82
  The model can be loaded using:
83
  ```
84
  from huggingface_hub import hf_hub_download
85
  import torch
86
 
87
  # Specify the repository and the filename of the model you want to load
88
- repo_id = "FinalProj5190/ImageToGPSproject-vit-base" # Replace with your repo name
89
  filename = "resnet_gps_regressor_complete.pth"
90
 
91
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
@@ -94,3 +19,26 @@ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
94
  model_test = torch.load(model_path)
95
  model_test.eval() # Set the model to evaluation mode
96
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  Dataset stats: \
2
  lat_mean = 39.951564548022596 \
3
  lat_std = 0.0006361722351128644 \
4
  lon_mean = -75.19150880602636 \
5
  lon_std = 0.000611411894337979
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  The model can be loaded using:
8
  ```
9
  from huggingface_hub import hf_hub_download
10
  import torch
11
 
12
  # Specify the repository and the filename of the model you want to load
13
+ repo_id = "FinalProj5190/ImageToGPSproject_new_vit" # Replace with your repo name
14
  filename = "resnet_gps_regressor_complete.pth"
15
 
16
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
 
19
  model_test = torch.load(model_path)
20
  model_test.eval() # Set the model to evaluation mode
21
  ```
22
+
23
+ The model implementation is here:
24
+ ```
25
+ class MultiModalModel(nn.Module):
26
+ def __init__(self, num_classes=2):
27
+ super(MultiModalModel, self).__init__()
28
+ self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
29
+
30
+ # Regression head, used in place of a classification head
31
+ self.regression_head = nn.Sequential(
32
+ nn.Linear(self.vit.config.hidden_size, 512),
33
+ nn.ReLU(),
34
+ nn.Linear(512, num_classes)
35
+ )
36
+
37
+ def forward(self, x):
38
+ outputs = self.vit(pixel_values=x)
39
+ # Take the CLS token embedding from the last hidden state
40
+ cls_output = outputs.last_hidden_state[:, 0, :]
41
+ # Pass through the regression head
42
+ gps_coordinates = self.regression_head(cls_output)
43
+ return gps_coordinates
44
+ ```