sqiud committed on
Commit
456a7a9
·
verified ·
1 Parent(s): ee72481

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +25 -29
README.md CHANGED
@@ -10,44 +10,40 @@ from huggingface_hub import hf_hub_download
10
  import torch
11
 
12
  # Specify the repository and the filename of the model you want to load
13
- repo_id = "FinalProj5190/ImageToGPSproject-resnet_vit-base" # Replace with your repo name
14
- filename = "resnet_vit_gps_regressor_complete.pth"
15
 
16
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
17
 
18
- # Load the model using torch
19
- model_test = torch.load(model_path)
20
- model_test.eval() # Set the model to evaluation mode
21
  ```
22
 
23
  The model implementation (the class must be defined or importable before running the loading code above) is here:
24
  ```
25
- from transformers import ViTModel
26
- class HybridGPSModel(nn.Module):
27
- def __init__(self, num_classes=2):
28
- super(HybridGPSModel, self).__init__()
29
- # Pre-trained ResNet for feature extraction
30
- self.resnet = resnet18(pretrained=True)
31
- self.resnet.fc = nn.Identity()
32
-
33
- # Pre-trained Vision Transformer
34
- self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
35
-
36
- # Combined regression head
37
- self.regression_head = nn.Sequential(
38
- nn.Linear(512 + self.vit.config.hidden_size, 128),
39
  nn.ReLU(),
40
- nn.Linear(128, num_classes)
 
41
  )
42
 
43
- def forward(self, x):
44
- resnet_features = self.resnet(x)
45
- vit_outputs = self.vit(pixel_values=x)
46
- vit_features = vit_outputs.last_hidden_state[:, 0, :] # CLS token
47
-
48
- combined_features = torch.cat((resnet_features, vit_features), dim=1)
49
 
50
- # Predict GPS coordinates
51
- gps_coordinates = self.regression_head(combined_features)
52
- return gps_coordinates
53
  ```
 
10
  import torch
11
 
12
  # Specify the repository and the filename of the model you want to load
13
+ repo_id = "FinalProj5190/vit_base_72" # Replace with your repo name
14
+ filename = "resnet_gps_regressor_complete.pth"
15
 
16
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
17
 
18
+ model_test = MultiModalModel()
19
+ model_test.load_state_dict(torch.load(model_path))
20
+ model_test.eval()
21
  ```
22
 
23
  The model implementation (the class must be defined or importable before running the loading code above) is here:
24
  ```
25
+ from transformers import AutoModel
26
+
27
+ class MultiModalModel(nn.Module):
28
+ def __init__(self, image_model_name='google/vit-base-patch16-224-in21k', output_dim=2):
29
+ super(MultiModalModel, self).__init__()
30
+
31
+ # Load Vision Transformer for feature extraction
32
+ self.image_model = AutoModel.from_pretrained(image_model_name, output_hidden_states=True)
33
+
34
+ # Regression head applied to the image features
35
+ self.regressor = nn.Sequential(
36
+ nn.Linear(self.image_model.config.hidden_size, 128),
 
 
37
  nn.ReLU(),
38
+ nn.Dropout(0.3),
39
+ nn.Linear(128, output_dim),
40
  )
41
 
42
+ def forward(self, image):
43
+ # Extract image features from the last hidden state
44
+ image_outputs = self.image_model(image)
45
+ image_features = image_outputs.last_hidden_state[:, 0, :] # CLS token features
 
 
46
 
47
+ # Final regression
48
+ return self.regressor(image_features)
 
49
  ```