lling0212 commited on
Commit
c4a5ad0
·
1 Parent(s): 42f6af3

Modified readme

Browse files
Files changed (1) hide show
  1. README.md +28 -4
README.md CHANGED
@@ -422,6 +422,34 @@ inference_transform = transforms.Compose([
422
  ])
423
  ```
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  Here's an exmaple of us testing the ensemble on the release test set. You can just change the load release_data line below and run the rest of the code to obtain rMSE.
426
 
427
  ```python
@@ -442,10 +470,6 @@ rel_dataset = GPSImageDataset(
442
  rel_dataloader = DataLoader(rel_dataset, batch_size=32, shuffle=False)
443
  ```
444
 
445
- ```python
446
- models = [convnext, resnet, vit, efficientnet]
447
- weights = [0.28, 0.26, 0.20, 0.27] # based on val 1/RMSE
448
- ```
449
 
450
  ```python
451
  # Release
 
422
  ])
423
  ```
424
 
425
+ ### Ensemble
426
+ Define Ensemble (weighted average) and prepare model
427
+ ```python
428
+ models = [convnext, resnet, vit, efficientnet]
429
+ weights = [0.28, 0.26, 0.20, 0.27] # based on val 1/RMSE
430
+ ```
431
+
432
+ ```python
433
+ # Weighted ensemble prediction function
434
+ def weighted_ensemble_predict(models, weights, images):
435
+ """
436
+ Generate weighted ensemble predictions by averaging logits using model weights.
437
+ """
438
+ ensemble_logits = torch.zeros((images.size(0), 2)).to(images.device) # Initialize logits for ensemble
439
+ for model, weight in zip(models, weights):
440
+ outputs = model(images)
441
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs # Extract logits
442
+ ensemble_logits += weight * logits # Weighted sum of logits
443
+ return ensemble_logits # Return the weighted logits sum (no division since weights sum to 1)
444
+
445
+
446
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
447
+ for model in models:
448
+ model.to(device)
449
+ model.eval()
450
+
451
+ ```
452
+
453
  Here's an exmaple of us testing the ensemble on the release test set. You can just change the load release_data line below and run the rest of the code to obtain rMSE.
454
 
455
  ```python
 
470
  rel_dataloader = DataLoader(rel_dataset, batch_size=32, shuffle=False)
471
  ```
472
 
 
 
 
 
473
 
474
  ```python
475
  # Release