sqiud commited on
Commit
a040736
·
verified ·
1 Parent(s): 6524ea5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +66 -0
README.md CHANGED
@@ -13,6 +13,72 @@ lat_std = 0.0006361722351128644 \
13
  lon_mean = -75.19150880602636 \
14
  lon_std = 0.000611411894337979
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  The model can be loaded using:
17
  ```
18
  from huggingface_hub import hf_hub_download
 
13
  lon_mean = -75.19150880602636 \
14
  lon_std = 0.000611411894337979
15
 
16
+ The model implementation is found here:
17
+ ```
18
+ import torch
19
+ import torch.nn as nn
20
+ import torchvision.models as models
21
+ import torchvision.transforms as transforms
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
24
+ from huggingface_hub import PyTorchModelHubMixin
25
+ from PIL import Image
26
+ import os
27
+ import numpy as np
28
+ from transformers import AutoModel
29
+
30
+ class MultiModalModel(nn.Module):
31
+ def __init__(self, image_model_name="google/vit-base-patch16-224", num_gps_features=2, output_dim=2):
32
+ super(MultiModalModel, self).__init__()
33
+ # Load Vision Transformer for feature extraction
34
+ self.image_model = AutoModel.from_pretrained(image_model_name, output_hidden_states=True)
35
+
36
+ # Reduce image features to match GPS features
37
+ self.image_fc = nn.Sequential(
38
+ nn.Linear(self.image_model.config.hidden_size, 256),
39
+ nn.ReLU(),
40
+ )
41
+
42
+ # Process GPS features
43
+ self.gps_fc = nn.Sequential(
44
+ nn.Linear(num_gps_features, 128),
45
+ nn.ReLU(),
46
+ nn.Dropout(0.3),
47
+ nn.Linear(128, 256),
48
+ )
49
+
50
+ # Combine image and GPS features for regression
51
+ self.regressor = nn.Sequential(
52
+ nn.Linear(256 + 256, 512), # 256 from image + 256 from GPS
53
+ nn.ReLU(),
54
+ nn.Dropout(0.4),
55
+ nn.Linear(512, output_dim),
56
+ )
57
+
58
+ def forward(self, image, gps):
59
+ # Extract image features from the last hidden state
60
+ image_outputs = self.image_model(image)
61
+ image_features = image_outputs.last_hidden_state[:, 0, :] # CLS token features
62
+ image_features = self.image_fc(image_features)
63
+
64
+ # Process GPS features
65
+ gps_features = self.gps_fc(gps)
66
+
67
+ # Concatenate image and GPS features
68
+ combined_features = torch.cat([image_features, gps_features], dim=1)
69
+
70
+ # Final regression
71
+ return self.regressor(combined_features)
72
+
73
+ def save_model(self, save_path):
74
+ """Save model locally using the Hugging Face format."""
75
+ self.save_pretrained(save_path)
76
+
77
+ def push_model(self, repo_name):
78
+ """Push the model to the Hugging Face Hub."""
79
+ self.push_to_hub(repo_name)
80
+ ```
81
+
82
  The model can be loaded using:
83
  ```
84
  from huggingface_hub import hf_hub_download