Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
from enum import Enum | |
from typing import Any, Callable, List, Optional, TypeVar | |
import torch | |
from torch.utils.data import Sampler | |
from .datasets import ImageNet, ImageNet22k | |
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler | |
# Module-level logger; every helper in this file logs through the "dinov2" logger.
logger = logging.getLogger("dinov2")
class SamplerType(Enum):
    """Sampler strategies recognized by ``_make_sampler`` in this module."""

    DISTRIBUTED = 0  # torch.utils.data.DistributedSampler over the dataset itself
    EPOCH = 1  # EpochSampler: a fixed number of samples per epoch
    INFINITE = 2  # InfiniteSampler
    SHARDED_INFINITE = 3  # ShardedInfiniteSampler with the old shuffling
    SHARDED_INFINITE_NEW = 4  # ShardedInfiniteSampler with use_new_shuffle_tensor_slice=True
def _make_bool_str(b: bool) -> str: | |
return "yes" if b else "no" | |
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): | |
def transform(sample): | |
image, target = sample | |
if image_transform is not None: | |
image = image_transform(image) | |
if target_transform is not None: | |
target = target_transform(target) | |
return image, target | |
return transform | |
def _parse_dataset_str(dataset_str: str): | |
tokens = dataset_str.split(":") | |
name = tokens[0] | |
kwargs = {} | |
for token in tokens[1:]: | |
key, value = token.split("=") | |
assert key in ("root", "extra", "split") | |
kwargs[key] = value | |
if name == "ImageNet": | |
class_ = ImageNet | |
if "split" in kwargs: | |
kwargs["split"] = ImageNet.Split[kwargs["split"]] | |
elif name == "ImageNet22k": | |
class_ = ImageNet22k | |
else: | |
raise ValueError(f'Unsupported dataset "{name}"') | |
return class_, kwargs | |
def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Build a dataset from a textual description.

    Args:
        dataset_str: Description of the form ``Name[:key=value]...`` (e.g. ImageNet:split=TRAIN).
        transform: Callable applied to images; passed to the dataset constructor.
        target_transform: Callable applied to targets; passed to the dataset constructor.

    Returns:
        The dataset instance. It is guaranteed to carry ``transform`` and ``target_transform``
        attributes (set here if the dataset class did not define them).
    """
    logger.info(f'using dataset: "{dataset_str}"')

    dataset_cls, dataset_kwargs = _parse_dataset_str(dataset_str)
    dataset = dataset_cls(transform=transform, target_transform=target_transform, **dataset_kwargs)
    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Aggregated datasets do not expose (yet) these attributes, so add them.
    for attr_name, attr_value in (("transform", transform), ("target_transform", target_transform)):
        if not hasattr(dataset, attr_name):
            setattr(dataset, attr_name, attr_value)

    return dataset
def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    """Build the sampler selected by ``type`` for ``dataset``.

    Args:
        dataset: Object supporting ``len()``. Only its length is used, except for DISTRIBUTED,
            which receives the dataset object itself.
        type: Which sampler to build; ``None`` (or any value matching no member) yields no sampler.
        shuffle: Forwarded to the sampler.
        seed: Forwarded to the sampler.
        size: For EPOCH, samples per epoch (``<= 0`` means the whole dataset). Must be ``<= 0``
            for the infinite and distributed samplers.
        advance: Forwarded to the infinite samplers. Must be ``<= 0`` for EPOCH and DISTRIBUTED.

    Returns:
        The sampler, or ``None`` when no sampler type matched.

    Raises:
        ValueError: ``size > 0`` with an infinite/distributed sampler, or ``advance > 0`` with DISTRIBUTED.
        NotImplementedError: ``advance > 0`` with EPOCH.
    """
    # Evaluated up front for every sampler type, including the "none" fallthrough.
    n_samples = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(sample_count=n_samples, shuffle=shuffle, seed=seed, advance=advance)

    if type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # TODO: Remove support for old shuffling
        return ShardedInfiniteSampler(
            sample_count=n_samples,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=(type == SamplerType.SHARDED_INFINITE_NEW),
        )

    if type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        epoch_size = size if size > 0 else n_samples
        logger.info(f"# of samples / epoch: {epoch_size:,d}")
        return EpochSampler(size=epoch_size, sample_count=n_samples, shuffle=shuffle, seed=seed)

    if type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(dataset=dataset, shuffle=shuffle, seed=seed, drop_last=False)

    logger.info("sampler: none")
    return None
# Element type of the list handed to a custom ``collate_fn`` in make_data_loader.
T = TypeVar("T")
def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
    pin_memory: bool = True,
):
    """
    Creates a data loader with the specified parameters.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples. Only forwarded to ``_make_sampler``; with
            ``sampler_type=None`` no sampler is built and neither ``shuffle`` nor ``seed``
            is passed to the DataLoader.
        seed: The random seed to use (forwarded to ``_make_sampler``).
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
        collate_fn: Function that performs batch collation
        pin_memory: Forwarded to the DataLoader. Defaults to True, the previously hard-coded
            value; pass False to skip pinned host memory (e.g. on CPU-only machines).

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    # Keep the try body to the one call that can raise: len() fails with TypeError
    # when the loader (e.g. via a length-less sampler) has no defined length.
    try:
        num_batches = len(data_loader)
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    else:
        logger.info(f"# of batches: {num_batches:,d}")
    return data_loader