Modified readme
Browse files
README.md
CHANGED
@@ -422,6 +422,34 @@ inference_transform = transforms.Compose([
|
|
422 |
])
|
423 |
```
|
424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
Here's an exmaple of us testing the ensemble on the release test set. You can just change the load release_data line below and run the rest of the code to obtain rMSE.
|
426 |
|
427 |
```python
|
@@ -442,10 +470,6 @@ rel_dataset = GPSImageDataset(
|
|
442 |
rel_dataloader = DataLoader(rel_dataset, batch_size=32, shuffle=False)
|
443 |
```
|
444 |
|
445 |
-
```python
|
446 |
-
models = [convnext, resnet, vit, efficientnet]
|
447 |
-
weights = [0.28, 0.26, 0.20, 0.27] # based on val 1/RMSE
|
448 |
-
```
|
449 |
|
450 |
```python
|
451 |
# Release
|
|
|
422 |
])
|
423 |
```
|
424 |
|
425 |
+
### Ensemble
|
426 |
+
Define Ensemble (weighted average) and prepare model
|
427 |
+
```python
|
428 |
+
models = [convnext, resnet, vit, efficientnet]
|
429 |
+
weights = [0.28, 0.26, 0.20, 0.27] # based on val 1/RMSE
|
430 |
+
```
|
431 |
+
|
432 |
+
```python
|
433 |
+
# Weighted ensemble prediction function
|
434 |
+
def weighted_ensemble_predict(models, weights, images):
|
435 |
+
"""
|
436 |
+
Generate weighted ensemble predictions by averaging logits using model weights.
|
437 |
+
"""
|
438 |
+
ensemble_logits = torch.zeros((images.size(0), 2)).to(images.device) # Initialize logits for ensemble
|
439 |
+
for model, weight in zip(models, weights):
|
440 |
+
outputs = model(images)
|
441 |
+
logits = outputs.logits if hasattr(outputs, "logits") else outputs # Extract logits
|
442 |
+
ensemble_logits += weight * logits # Weighted sum of logits
|
443 |
+
return ensemble_logits # Return the weighted logits sum (no division since weights sum to 1)
|
444 |
+
|
445 |
+
|
446 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
447 |
+
for model in models:
|
448 |
+
model.to(device)
|
449 |
+
model.eval()
|
450 |
+
|
451 |
+
```
|
452 |
+
|
453 |
Here's an exmaple of us testing the ensemble on the release test set. You can just change the load release_data line below and run the rest of the code to obtain rMSE.
|
454 |
|
455 |
```python
|
|
|
470 |
rel_dataloader = DataLoader(rel_dataset, batch_size=32, shuffle=False)
|
471 |
```
|
472 |
|
|
|
|
|
|
|
|
|
473 |
|
474 |
```python
|
475 |
# Release
|