jiayicccc committed on
Commit
eaa16ad
·
verified ·
1 Parent(s): 73eef36

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +187 -0
README.md ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image to GPS Project - ConvNeXt, MobileNet and EfficientNet Ensemble
2
+ ## Training Data Statistics
3
+ ```python
4
+ lat_mean = 39.951537011424264
5
+ lat_std = 0.0006940325318781937
6
+ lon_mean = -75.19152009539549
7
+ lon_std = 0.0007607716964655242
8
+ ```
9
+
10
+ ## How to Load the Model and Perform Inference
11
+ ```bash
12
+ # install dependencies
13
+ pip install numpy geopy datasets torch torchvision huggingface_hub
+ ```
+
+ ```python
14
+ # import packages
15
+ import numpy as np
16
+ from geopy.distance import geodesic
17
+ import torch
18
+ from torch.utils.data import DataLoader, Dataset
19
+ from torchvision import transforms
20
+ import torch.nn as nn
21
+ from torchvision.models import mobilenet_v2, MobileNet_V2_Weights, convnext_tiny, ConvNeXt_Tiny_Weights, efficientnet_b0, EfficientNet_B0_Weights
22
+ from datasets import load_dataset
23
+ from huggingface_hub import hf_hub_download
24
# Download the ensemble checkpoint from the Hugging Face Hub and keep its
# local file path for torch.load further down.
repo_id = "cis519projectA/Ensemble_ConvNeXt_MobileNet_EfficientNet"
filename = "ensemble_triple.pth"
model_path = hf_hub_download(repo_id, filename)
28
+ # define models
29
class CustomEfficientNetModel(nn.Module):
    """EfficientNet-B0 backbone with a custom two-layer output head.

    The backbone's ``classifier`` is replaced by
    ``Linear(in_features, 512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes)``
    and the parameters of ``features[:3]`` are frozen.

    Args:
        weights: Passed straight to ``efficientnet_b0(weights=...)``.
        num_classes: Width of the final linear layer (2 by default).
    """

    def __init__(self, weights=EfficientNet_B0_Weights.DEFAULT, num_classes=2):
        super().__init__()
        backbone = efficientnet_b0(weights=weights)

        # Input width of the stock head's linear layer, reused for the new head.
        head_width = backbone.classifier[1].in_features
        head_layers = [
            nn.Linear(head_width, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        # Freeze the first three entries of the feature extractor.
        for frozen in backbone.features[:3].parameters():
            frozen.requires_grad_(False)

        # Attribute name is part of the checkpoint's parameter keys.
        self.efficientnet = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped EfficientNet and return its output."""
        return self.efficientnet(x)
45
+
46
class CustomConvNeXtModel(nn.Module):
    """ConvNeXt-Tiny backbone with a custom pooled output head.

    The backbone's ``classifier`` is replaced by
    ``AdaptiveAvgPool2d(1) -> Flatten -> Linear(in_features, 512) ->
    BatchNorm1d(512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes)``
    and the parameters of ``features[:4]`` are frozen.

    Args:
        weights: Passed straight to ``convnext_tiny(weights=...)``.
        num_classes: Width of the final linear layer (2 by default).
    """

    def __init__(self, weights=ConvNeXt_Tiny_Weights.DEFAULT, num_classes=2):
        super().__init__()
        backbone = convnext_tiny(weights=weights)

        # Input width of the stock head's linear layer, reused for the new head.
        head_width = backbone.classifier[2].in_features
        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(head_width, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        # Freeze the first four entries of the feature extractor.
        for frozen in backbone.features[:4].parameters():
            frozen.requires_grad_(False)

        # Attribute name is part of the checkpoint's parameter keys.
        self.convnext = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ConvNeXt and return its output."""
        return self.convnext(x)
64
+
65
class CustomMobileNetModel(nn.Module):
    """MobileNetV2 backbone with a custom three-layer output head.

    The backbone's ``classifier`` is replaced by
    ``Linear(in_features, 1024) -> ReLU -> Dropout(0.5) -> Linear(1024, 512) ->
    ReLU -> Dropout(0.5) -> Linear(512, num_classes)`` and the parameters of
    ``features[:5]`` are frozen.

    Args:
        weights: Passed straight to ``mobilenet_v2(weights=...)``.
        num_classes: Width of the final linear layer (2 by default).
    """

    def __init__(self, weights=MobileNet_V2_Weights.DEFAULT, num_classes=2):
        super().__init__()
        backbone = mobilenet_v2(weights=weights)

        # Input width of the stock head's linear layer, reused for the new head.
        head_width = backbone.classifier[1].in_features
        head_layers = [
            nn.Linear(head_width, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        # Freeze the first five entries of the feature extractor.
        for frozen in backbone.features[:5].parameters():
            frozen.requires_grad_(False)

        # Attribute name is part of the checkpoint's parameter keys.
        self.mobilenet = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped MobileNetV2 and return its output."""
        return self.mobilenet(x)
84
+
85
class EnsembleModel(nn.Module):
    """Fuse three sub-model outputs with a small fully connected head.

    Each sub-model is called on the same input; their outputs are concatenated
    along ``dim=1`` and passed through
    ``Linear(num_classes * 3, 512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes)``.

    Args:
        convnext_model: Module stored as ``self.convnext``.
        mobilenet_model: Module stored as ``self.mobilenet``.
        efficientnet_model: Module stored as ``self.efficientnet``.
        num_classes: Output width of each sub-model and of the fused head.
    """

    def __init__(self, convnext_model, mobilenet_model, efficientnet_model, num_classes=2):
        super().__init__()
        self.convnext = convnext_model
        self.mobilenet = mobilenet_model
        self.efficientnet = efficientnet_model

        # These scalars do not influence the output of forward(). They are
        # kept as registered parameters so the module's state_dict keeps the
        # same keys as before and existing checkpoints load unchanged.
        self.weight_convnext = nn.Parameter(torch.tensor(1.0))
        self.weight_mobilenet = nn.Parameter(torch.tensor(1.0))
        self.weight_efficientnet = nn.Parameter(torch.tensor(1.0))

        self.fc = nn.Sequential(
            nn.Linear(num_classes * 3, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the fused prediction for the batch ``x``.

        The previous version also built a softmax-weighted sum of the three
        outputs and then discarded it; that dead computation is removed, the
        returned tensor is unchanged.
        """
        branch_outputs = (
            self.convnext(x),
            self.mobilenet(x),
            self.efficientnet(x),
        )
        return self.fc(torch.cat(branch_outputs, dim=1))
110
+
111
# Choose the device first: the original snippet called `.to(device)` before
# `device` was assigned, which raises NameError when run top to bottom.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# build the three backbones and the ensemble wrapper
convnext_model = CustomConvNeXtModel(weights=ConvNeXt_Tiny_Weights.DEFAULT, num_classes=2)
mobilenet_model = CustomMobileNetModel(weights=MobileNet_V2_Weights.DEFAULT, num_classes=2)
efficientnet_model = CustomEfficientNetModel(weights=EfficientNet_B0_Weights.DEFAULT, num_classes=2)
ensemble_model = EnsembleModel(convnext_model, mobilenet_model, efficientnet_model, num_classes=2)

# load the model weights
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# source you trust (recent PyTorch versions offer `weights_only=True`).
state_dict = torch.load(model_path, map_location=device)
ensemble_model.load_state_dict(state_dict)
ensemble_model.to(device)
ensemble_model.eval()
121
# load the dataset
dataset_test = load_dataset("gydou/released_img", split="train")

# define transforms: resize to 224x224, convert to a tensor, then normalize
# each channel with the mean/std values below
inference_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Parameters for denormalization (same values as "Training Data Statistics")
lat_mean, lat_std = 39.951537011424264, 0.0006940325318781937
lon_mean, lon_std = -75.19152009539549, 0.0007607716964655242
134
class GPSImageDataset(Dataset):
    """Wrap an indexable dataset of image/GPS records for use with a DataLoader.

    Each record must expose the keys ``'image'``, ``'Latitude'`` and
    ``'Longitude'``. Items are returned as ``(image, gps_coords)`` where
    ``gps_coords`` is a float32 tensor ``[latitude, longitude]`` standardized
    as ``(value - mean) / std``.

    Args:
        hf_dataset: Indexable collection of records; must support ``len()``.
        transform: Optional callable applied to the image.
        lat_mean, lat_std: Latitude statistics used for standardization.
        lon_mean, lon_std: Longitude statistics used for standardization.

    A statistic left as ``None`` is skipped (mean treated as 0, std as 1).
    Previously the ``None`` defaults made ``__getitem__`` fail with a
    ``TypeError``; results with all statistics supplied are unchanged.
    """

    def __init__(self, hf_dataset, transform=None, lat_mean=None, lat_std=None, lon_mean=None, lon_std=None):
        self.hf_dataset = hf_dataset
        self.transform = transform
        self.latitude_mean = lat_mean
        self.latitude_std = lat_std
        self.longitude_mean = lon_mean
        self.longitude_std = lon_std

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.hf_dataset)

    @staticmethod
    def _standardize(value, mean, std):
        """Return ``(value - mean) / std``, skipping any statistic that is None."""
        if mean is not None:
            value = value - mean
        if std is not None:
            value = value / std
        return value

    def __getitem__(self, idx):
        """Return ``(image, gps_coords)`` for the record at position ``idx``."""
        example = self.hf_dataset[idx]
        image = example['image']
        if self.transform:
            image = self.transform(image)

        latitude = self._standardize(example['Latitude'], self.latitude_mean, self.latitude_std)
        longitude = self._standardize(example['Longitude'], self.longitude_mean, self.longitude_std)
        gps_coords = torch.tensor([latitude, longitude], dtype=torch.float32)
        return image, gps_coords
155
# transform test data: wrap the raw dataset so every item is a transformed
# image paired with standardized GPS coordinates
test_dataset = GPSImageDataset(
    dataset_test,
    inference_transform,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)
165
+ # evaluate
166
def evaluate_model_single_batch(model, dataloader, lat_mean, lat_std, lon_mean, lon_std):
    """Compute geodesic error statistics on the first batch of ``dataloader``.

    Predictions and targets are mapped back to degrees with
    ``value * std + mean`` and compared pairwise with ``geodesic(...).meters``.

    Args:
        model: Module called on the image batch; switched to eval mode here.
        dataloader: Iterable yielding ``(images, gps_coords)`` batches.
        lat_mean, lat_std: Latitude statistics used to undo standardization.
        lon_mean, lon_std: Longitude statistics used to undo standardization.

    Returns:
        ``(mean_error, rmse_error)`` in meters over the first batch. Both are
        NaN when the dataloader yields no batch.
    """
    # Run on the device the model already lives on, so the function does not
    # depend on a module-level `device` variable being defined.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return float("nan"), float("nan")
    images, gps_coords = first_batch

    with torch.no_grad():
        outputs = model(images.to(run_device))

    scale = np.array([lat_std, lon_std])
    offset = np.array([lat_mean, lon_mean])
    preds_denorm = outputs.cpu().numpy() * scale + offset
    actuals_denorm = gps_coords.cpu().numpy() * scale + offset

    all_distances = [
        geodesic((actual[0], actual[1]), (pred[0], pred[1])).meters
        for pred, actual in zip(preds_denorm, actuals_denorm)
    ]
    mean_error = np.mean(all_distances)
    rmse_error = np.sqrt(np.mean(np.square(all_distances)))
    return mean_error, rmse_error
182
# Evaluate using only one batch and report both error metrics in meters
mean_error, rmse_error = evaluate_model_single_batch(
    model=ensemble_model,
    dataloader=test_dataloader,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std,
)
print(f"Mean Error (meters): {mean_error:.2f}, RMSE (meters): {rmse_error:.2f}")
187
+ ```