
Image2GPS Model Overview

Datasets

  • Training Dataset: image2gpsLLH/image_data
  • Evaluation Metrics:
    • Accuracy of the predicted GPS coordinates, reported as mean absolute error (MAE) and root mean squared error (RMSE) in degrees (see Example Execution below)

Model Statistics

  • Latitude Mean: 39.95150678400655
  • Latitude Standard Deviation: 0.0007344790486223371
  • Longitude Mean: -75.19146715269915
  • Longitude Standard Deviation: 0.0007342464795497821
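
These means and standard deviations are computed from the training labels; the GPS coordinates are z-score-normalized with them during training, so predictions must be denormalized with the same values at inference. A minimal sketch of both directions:

# Normalize a GPS label for training; denormalize a model prediction.
lat_mean, lat_std = 39.95150678400655, 0.0007344790486223371
lon_mean, lon_std = -75.19146715269915, 0.0007342464795497821

def normalize(lat, lon):
    # z-score each coordinate with the training-set statistics
    return (lat - lat_mean) / lat_std, (lon - lon_mean) / lon_std

def denormalize(lat_z, lon_z):
    # invert the z-scoring to recover coordinates in degrees
    return lat_z * lat_std + lat_mean, lon_z * lon_std + lon_mean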

Model Description

  • Model Type: Vision Transformer (ViT): a vit_base_patch16_224 backbone from timm with a small MLP regression head that outputs [latitude, longitude]; see the class definition under Example Execution

Training Data

  • Dataset Size: 1,325 images
  • Location: Penn Engineering walkways
  • Data Collection Method:
    • Images captured at various points facing eight compass directions:
      • North, Northeast, East, Southeast, South, Southwest, West, Northwest

Testing Data

  • Dataset Size: 441 images
  • Location: Penn Engineering walkways

Factors Affecting Model Performance

  • Environmental Conditions: Lighting, weather, time of day
  • Image Variability: Different camera angles and perspectives (see the augmentation sketch below)
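
The card does not document the training-time augmentations, so the following is only a hypothetical torchvision pipeline illustrating how photometric jitter (lighting, weather) and geometric jitter (camera angle, perspective) are commonly simulated; it is not taken from the model's actual training code:

# Hypothetical augmentation pipeline (assumption, not from the model card).
import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),   # viewpoint/scale jitter
    transforms.ColorJitter(brightness=0.3, contrast=0.3,
                           saturation=0.2),                # lighting/weather jitter
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])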

Training Result

[Figure: example of an image used during training/testing]


Example Execution

A runnable end-to-end example is available as a Colab notebook:
https://colab.research.google.com/drive/12mQAu1m65EV5kJlVkigkEOxH8NaLULTS?usp=sharing

!pip install datasets timm  # timm provides the ViT backbone used below

# Imports (torchvision.models and os from the original listing were unused and are dropped)
from huggingface_hub import login, hf_hub_download
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
import timm

class ViTGeoLocator(nn.Module):
    def __init__(self, freeze_backbone=True):
        super(ViTGeoLocator, self).__init__()
        # Load pretrained ViT
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True)

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Get the dimension of the ViT's output
        embed_dim = self.backbone.num_features

        # Remove the original classification head
        self.backbone.head = nn.Identity()

        # New regression head
        self.regressor = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)  # Output: [latitude, longitude]
        )

    def forward(self, x):
        x = self.backbone(x)
        return self.regressor(x)
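
# Optional sanity check (an addition, not in the original script): the frozen
# backbone emits a 768-dim pooled embedding, and the regression head maps it
# to two values, so a dummy batch of shape (N, 3, 224, 224) should produce an
# (N, 2) tensor of normalized [latitude, longitude] predictions.
with torch.no_grad():
    sanity_model = ViTGeoLocator(freeze_backbone=True)
    print(sanity_model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 2])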

# Log in to Hugging Face (paste your own access token)
login("replace with huggingface token")

# Specify the repository and model file
repo_id = "image2gpsLLH/vit"
filename = "vit.pth"

# Download the model from Hugging Face
model_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Initialize the model
model_test = ViTGeoLocator(freeze_backbone=True)

# Load the checkpoint (map to CPU so this also works on machines without a GPU);
# the file stores a dict whose 'model_state_dict' entry holds the weights
checkpoint = torch.load(model_path, map_location="cpu")

# Load state dict
model_test.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode (disables the dropout layers in the head)
model_test.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_test = model_test.to(device)

# Inference is run below, once the sample dataset and its DataLoader have been created.


class GPSImageDataset(Dataset):
    def __init__(self, hf_dataset, transform=None, lat_mean=None, lat_std=None, lon_mean=None, lon_std=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

        # Compute mean and std from the dataset if not provided
        self.latitude_mean = lat_mean if lat_mean is not None else np.mean(np.array(self.hf_dataset['Latitude']))
        self.latitude_std = lat_std if lat_std is not None else np.std(np.array(self.hf_dataset['Latitude']))
        self.longitude_mean = lon_mean if lon_mean is not None else np.mean(np.array(self.hf_dataset['Longitude']))
        self.longitude_std = lon_std if lon_std is not None else np.std(np.array(self.hf_dataset['Longitude']))

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        # Extract data
        example = self.hf_dataset[idx]

        # Load and process the image
        image = example['image']
        latitude = example['Latitude']
        longitude = example['Longitude']
        if self.transform:
            image = self.transform(image)

        # Normalize GPS coordinates
        latitude = (latitude - self.latitude_mean) / self.latitude_std
        longitude = (longitude - self.longitude_mean) / self.longitude_std
        gps_coords = torch.tensor([latitude, longitude], dtype=torch.float32)

        return image, gps_coords

# Load sample data (replace with the path to your sample data)
data_sample = load_dataset("gydou/released_img", split="train")

# Training-set statistics (see Model Statistics above), used to denormalize predictions
lat_mean = 39.95150678400655
lat_std = 0.0007344790486223371
lon_mean = -75.19146715269915
lon_std = 0.0007342464795497821

# Specify transform
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Create dataloader
sample_dataset = GPSImageDataset(
    hf_dataset=data_sample,
    transform=inference_transform,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std
)
sample_dataloader = DataLoader(sample_dataset, batch_size=32, shuffle=False)

# Run model
all_preds = []
all_actuals = []
with torch.no_grad():
    for images, gps_coords in sample_dataloader:
        images, gps_coords = images.to(device), gps_coords.to(device)

        outputs = model_test(images)

        # Denormalize predictions and actual values
        preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
        actuals = gps_coords.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])

        all_preds.append(preds)
        all_actuals.append(actuals)

# Concatenate all batches
all_preds = torch.cat(all_preds).numpy()
all_actuals = torch.cat(all_actuals).numpy()

# Compute error metrics (in degrees, averaged over both coordinates);
# np.sqrt(mse) avoids the squared=False argument removed in recent scikit-learn
mae = mean_absolute_error(all_actuals, all_preds)
rmse = np.sqrt(mean_squared_error(all_actuals, all_preds))

print(f'Mean Absolute Error (degrees): {mae}')
print(f'Root Mean Squared Error (degrees): {rmse}')
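
Because latitude and longitude are in degrees, the raw errors above are hard to interpret. A rough conversion to meters (an addition to the original script, using the standard approximation of about 111,320 m per degree of latitude and a cos(latitude) scaling for longitude):

# Approximate positional error in meters at the dataset's latitude (~39.95 N)
lat_err_deg = all_preds[:, 0] - all_actuals[:, 0]
lon_err_deg = all_preds[:, 1] - all_actuals[:, 1]
meters_per_deg_lat = 111_320.0
meters_per_deg_lon = 111_320.0 * np.cos(np.radians(lat_mean))
error_m = np.sqrt((lat_err_deg * meters_per_deg_lat) ** 2 +
                  (lon_err_deg * meters_per_deg_lon) ** 2)
print(f'Mean positional error: {error_m.mean():.1f} m')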