import argparse

import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger

from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule
from oml.losses.triplet import TripletLossWithMiner
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models.vit.vit import ViTExtractor
from oml.samplers.balance import BalanceSampler
parser = argparse.ArgumentParser(description="Train model with OML",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--root-dir", help="root directory for train data", default="data/train_val/")
parser.add_argument("--train-dataframe-name", help="name of dataframe in root directory", default="df_stamps.csv")
parser.add_argument("--train-images", help="name of directory with images", default="images/")
parser.add_argument("--num-epochs", type=int, help="number of epochs to train model", default=100)
parser.add_argument("--model-arch", help="which model architecture to use, check model zoo", default="vits8")
parser.add_argument("--weights",
                    help="""
                    pretrained weights for model, choose from model zoo
                    https://open-metric-learning.readthedocs.io/en/latest/feature_extraction/zoo.html
                    """,
                    default="vits8_dino")
parser.add_argument("--checkpoint", help="resume training from checkpoint, provide path", default=None)
parser.add_argument("--num-labels", type=int, help="number of unique labels per batch, set less if cuda out of memory", default=6)
parser.add_argument("--num-instances", type=int, help="number of instances per label in a batch, set less if cuda out of memory", default=2)
parser.add_argument("--val-batch-size", type=int, help="batch size for validation", default=4)
parser.add_argument("--log-data", action="store_true", help="whether to log data to TensorBoard")
args = parser.parse_args()
config = vars(args)
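# Example invocation (a sketch: the script name "train.py" and the data layout are
# assumptions, not fixed by this file):
#   python train.py --root-dir data/train_val/ --train-dataframe-name df_stamps.csv \
#       --num-epochs 50 --num-labels 4 --num-instances 2 --log-data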
dataset_root = config['root_dir']

# The dataframe follows the OML dataset format: it must contain "label", "path",
# "split", "is_query" and "is_gallery" columns.
df = pd.read_csv(f"{dataset_root}{config['train_dataframe_name']}")
df_train = df[df["split"] == "train"].reset_index(drop=True)
df_val = df[df["split"] == "validation"].reset_index(drop=True)
df_val["is_query"] = df_val["is_query"].astype(bool)
df_val["is_gallery"] = df_val["is_gallery"].astype(bool)
# Feature extractor from the OML model zoo and a plain SGD optimizer.
extractor = ViTExtractor(config['weights'], arch=config['model_arch'], normalise_features=False)
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

# Training: triplet loss over all triplets mined inside each balanced batch.
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
# Each batch holds num_labels unique labels with num_instances samples per label,
# so the effective train batch size is num_labels * num_instances.
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=config['num_labels'], n_instances=config['num_instances'])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# Validation: query/gallery retrieval metrics (CMC@k) computed on embeddings.
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config['val_batch_size'])
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key], cmc_top_k=(5, 3, 1)), log_images=True)
# Use a TensorBoard logger only when requested; otherwise disable logging.
logger = TensorBoardLogger(".") if config['log_data'] else False

pl_model = ExtractorModule(extractor, criterion, optimizer)
trainer = pl.Trainer(
    max_epochs=config['num_epochs'],
    callbacks=[metric_callback],
    num_sanity_val_steps=0,
    accelerator='gpu',
    devices=1,
    logger=logger,
    resume_from_checkpoint=config['checkpoint'],  # PyTorch Lightning 1.x API
)
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
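
# --- Inference sketch (not part of the training run above) ---
# A minimal example of reusing the trained extractor for embedding extraction.
# Assumptions not fixed by this script: the checkpoint path "checkpoints/last.ckpt",
# the "extractor." key prefix used by the Lightning module's state dict, and a
# 224x224, already-normalised input batch.
def extract_embeddings(ckpt_path="checkpoints/last.ckpt"):
    model = ViTExtractor(None, arch=config['model_arch'], normalise_features=False)
    state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict({k.split("extractor.", 1)[1]: v
                           for k, v in state.items() if k.startswith("extractor.")})
    model.eval()
    with torch.no_grad():
        dummy_batch = torch.randn(2, 3, 224, 224)  # replace with real preprocessed images
        return model(dummy_batch)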