|
|
|
|
|
|
|
|
|
|
|
|
|
import csv |
|
from enum import Enum |
|
import logging |
|
import os |
|
from typing import Callable, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
|
|
from .extended import ExtendedVisionDataset |
|
|
|
|
|
# Type of a dataset target: the integer class index (see ImageNet.get_target).
_Target = int

# Named logger; logging.getLogger returns the same object for the same name,
# so every module using "dinov2" shares it.
logger = logging.getLogger("dinov2")
|
|
|
|
|
class _Split(Enum): |
|
TRAIN = "train" |
|
VAL = "val" |
|
TEST = "test" |
|
|
|
@property |
|
def length(self) -> int: |
|
split_lengths = { |
|
_Split.TRAIN: 1_281_167, |
|
_Split.VAL: 50_000, |
|
_Split.TEST: 100_000, |
|
} |
|
return split_lengths[self] |
|
|
|
def get_dirname(self, class_id: Optional[str] = None) -> str: |
|
return self.value if class_id is None else os.path.join(self.value, class_id) |
|
|
|
def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: |
|
dirname = self.get_dirname(class_id) |
|
if self == _Split.TRAIN: |
|
basename = f"{class_id}_{actual_index}" |
|
else: |
|
basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" |
|
return os.path.join(dirname, basename + ".JPEG") |
|
|
|
def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: |
|
assert self != _Split.TEST |
|
dirname, filename = os.path.split(image_relpath) |
|
class_id = os.path.split(dirname)[-1] |
|
basename, _ = os.path.splitext(filename) |
|
actual_index = int(basename.split("_")[-1]) |
|
return class_id, actual_index |
|
|
|
|
|
class ImageNet(ExtendedVisionDataset):
    """ImageNet-1k dataset backed by precomputed ``.npy`` metadata ("extra") files.

    Image files are read from ``self.root`` using the layout given by
    ``_Split.get_image_relpath``. Per-sample metadata ("entries") and per-class
    metadata (class IDs / class names) live as ``.npy`` files in the ``extra``
    directory. They are written by :meth:`dump_extra` and later loaded lazily
    as read-only memory maps.
    """

    Target = Union[_Target]
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "ImageNet.Split",
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Create the dataset.

        Args:
            split: Split to serve (TRAIN, VAL or TEST).
            root: Image root. This class reads ``self.root`` (presumably set from
                ``root`` by the base class): images under ``<root>/<split>/`` and
                labels from ``<root>/labels.txt``.
            extra: Directory holding (or receiving) the ``.npy`` metadata files.
            transforms, transform, target_transform: Forwarded unchanged to the
                ``ExtendedVisionDataset`` constructor.
        """
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra
        self._split = split

        # Metadata arrays are loaded on first access (see the _get_* helpers).
        self._entries = None
        self._class_ids = None
        self._class_names = None

    @property
    def split(self) -> "ImageNet.Split":
        """The split served by this instance."""
        return self._split

    def _get_extra_full_path(self, extra_path: str) -> str:
        """Resolve a metadata file name inside the ``extra`` directory."""
        return os.path.join(self._extra_root, extra_path)

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Load a metadata ``.npy`` file as a read-only memory map."""
        extra_full_path = self._get_extra_full_path(extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save ``extra_array`` into the ``extra`` directory, creating it if missing."""
        extra_full_path = self._get_extra_full_path(extra_path)
        os.makedirs(self._extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _entries_path(self) -> str:
        """File name of the per-sample entries array for this split."""
        return f"entries-{self._split.value.upper()}.npy"

    @property
    def _class_ids_path(self) -> str:
        """File name of the class-index -> class-ID array for this split."""
        return f"class-ids-{self._split.value.upper()}.npy"

    @property
    def _class_names_path(self) -> str:
        """File name of the class-index -> class-name array for this split."""
        return f"class-names-{self._split.value.upper()}.npy"

    def _get_entries(self) -> np.ndarray:
        """Return the structured entries array, loading it on first call.

        Fields (as written by ``_dump_entries``): ``actual_index`` and
        ``class_index`` (``<u4``), ``class_id`` and ``class_name`` (unicode).
        """
        if self._entries is None:
            self._entries = self._load_extra(self._entries_path)
        assert self._entries is not None
        return self._entries

    def _get_class_ids(self) -> np.ndarray:
        """Return the class-index -> class-ID array, loading it on first call.

        Raises:
            AssertionError: for the TEST split, which has no class metadata.
                Raised explicitly, unlike ``assert False``, so the check still
                applies under ``python -O``.
        """
        if self._split == _Split.TEST:
            raise AssertionError("Class IDs are not available in TEST split")
        if self._class_ids is None:
            self._class_ids = self._load_extra(self._class_ids_path)
        assert self._class_ids is not None
        return self._class_ids

    def _get_class_names(self) -> np.ndarray:
        """Return the class-index -> class-name array, loading it on first call.

        Raises:
            AssertionError: for the TEST split, which has no class metadata.
                Raised explicitly, unlike ``assert False``, so the check still
                applies under ``python -O``.
        """
        if self._split == _Split.TEST:
            raise AssertionError("Class names are not available in TEST split")
        if self._class_names is None:
            self._class_names = self._load_extra(self._class_names_path)
        assert self._class_names is not None
        return self._class_names

    def find_class_id(self, class_index: int) -> str:
        """Return the class ID string for ``class_index``."""
        class_ids = self._get_class_ids()
        return str(class_ids[class_index])

    def find_class_name(self, class_index: int) -> str:
        """Return the class name string for ``class_index``."""
        class_names = self._get_class_names()
        return str(class_names[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file for sample ``index``."""
        entries = self._get_entries()
        actual_index = entries[index]["actual_index"]

        # None for TEST, whose images are not stored in class sub-directories.
        class_id = self.get_class_id(index)

        image_relpath = self.split.get_image_relpath(actual_index, class_id)
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        return image_data

    def get_target(self, index: int) -> Optional[Target]:
        """Return the integer class index of sample ``index`` (None for TEST)."""
        entries = self._get_entries()
        class_index = entries[index]["class_index"]
        return None if self.split == _Split.TEST else int(class_index)

    def get_targets(self) -> Optional[np.ndarray]:
        """Return the class indices of all samples (None for TEST)."""
        entries = self._get_entries()
        return None if self.split == _Split.TEST else entries["class_index"]

    def get_class_id(self, index: int) -> Optional[str]:
        """Return the class ID of sample ``index`` (None for TEST)."""
        entries = self._get_entries()
        class_id = entries[index]["class_id"]
        return None if self.split == _Split.TEST else str(class_id)

    def get_class_name(self, index: int) -> Optional[str]:
        """Return the class name of sample ``index`` (None for TEST)."""
        entries = self._get_entries()
        class_name = entries[index]["class_name"]
        return None if self.split == _Split.TEST else str(class_name)

    def __len__(self) -> int:
        """Number of samples; asserted to match the split's expected size."""
        entries = self._get_entries()
        assert len(entries) == self.split.length
        return len(entries)

    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
        """Read ``(class_id, class_name)`` rows from a two-column CSV under ``root``.

        Blank lines (e.g. a trailing empty line) are skipped.

        Raises:
            RuntimeError: if the file cannot be opened or read.
            ValueError: if a non-blank row does not have exactly two columns.
        """
        labels_full_path = os.path.join(self.root, labels_path)
        labels = []

        try:
            # newline="" is required by the csv module for correct line handling;
            # an explicit encoding avoids depending on the platform locale.
            with open(labels_full_path, "r", encoding="utf-8", newline="") as f:
                for row in csv.reader(f):
                    if not row:
                        continue
                    class_id, class_name = row
                    labels.append((class_id, class_name))
        except OSError as e:
            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e

        return labels

    @staticmethod
    def _log_progress(index: int, sample_count: int, old_percent: int) -> int:
        """Log progress when the integer percentage advances past ``old_percent``.

        Returns the percentage to pass back in as ``old_percent`` on the next call.
        """
        percent = 100 * (index + 1) // sample_count
        if percent > old_percent:
            logger.info(f"creating entries: {percent}%")
            return percent
        return old_percent

    def _dump_entries(self) -> None:
        """Build the per-sample entries array for this split and save it to ``extra``.

        TRAIN/VAL: samples are enumerated with torchvision's ``ImageFolder`` over
        ``<root>/<split>``. Row ``i`` of ``labels.txt`` must describe ImageFolder
        class index ``i``, because class metadata is looked up as ``labels[class_index]``.
        TEST: no labels are read. Entries are synthesized with actual indices
        ``1..split.length``, a sentinel class index, and an empty class ID and name.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            dataset = None
            sample_count = split.length
            max_class_id_length, max_class_name_length = 0, 0
        else:
            labels_path = "labels.txt"
            logger.info(f'loading labels from "{labels_path}"')
            labels = self._load_labels(labels_path)

            # Local import: ImageFolder is only used by this method.
            from torchvision.datasets import ImageFolder

            dataset_root = os.path.join(self.root, split.get_dirname())
            dataset = ImageFolder(dataset_root)
            sample_count = len(dataset)
            # Start at 0, not -1: an empty sample list then gives a valid "U0"
            # dtype instead of the invalid "U-1"; non-empty results are unchanged.
            max_class_id_length, max_class_name_length = 0, 0
            for _, class_index in dataset.samples:
                class_id, class_name = labels[class_index]
                max_class_id_length = max(len(class_id), max_class_id_length)
                max_class_name_length = max(len(class_name), max_class_name_length)

        dtype = np.dtype(
            [
                ("actual_index", "<u4"),
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("class_name", f"U{max_class_name_length}"),
            ]
        )
        entries_array = np.empty(sample_count, dtype=dtype)

        if split == ImageNet.Split.TEST:
            # Sentinel for "no label": the largest uint32 (the value np.uint32(-1)
            # used to wrap to). np.uint32(-1) raises OverflowError on NumPy >= 2.
            missing_class_index = np.iinfo(np.uint32).max
            old_percent = -1
            for index in range(sample_count):
                old_percent = self._log_progress(index, sample_count, old_percent)
                actual_index = index + 1  # image file numbering starts at 1
                entries_array[index] = (actual_index, missing_class_index, "", "")
        else:
            class_names = dict(labels)

            assert dataset is not None
            old_percent = -1
            for index, (image_full_path, class_index) in enumerate(dataset.samples):
                old_percent = self._log_progress(index, sample_count, old_percent)
                image_relpath = os.path.relpath(image_full_path, self.root)
                class_id, actual_index = split.parse_image_relpath(image_relpath)
                class_name = class_names[class_id]
                entries_array[index] = (actual_index, class_index, class_id, class_name)

        logger.info(f'saving entries to "{self._entries_path}"')
        self._save_extra(entries_array, self._entries_path)

    def _dump_class_ids_and_names(self) -> None:
        """Derive and save class-index -> class-ID / class-name arrays from the entries.

        Reads the entries file written by ``_dump_entries``, so that method must
        run first. Does nothing for the TEST split.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            return

        entries_array = self._load_extra(self._entries_path)

        # Lengths start at 0 so an empty entries file still gives valid dtypes.
        max_class_id_length, max_class_name_length, max_class_index = 0, 0, -1
        for entry in entries_array:
            max_class_index = max(int(entry["class_index"]), max_class_index)
            max_class_id_length = max(len(str(entry["class_id"])), max_class_id_length)
            max_class_name_length = max(len(str(entry["class_name"])), max_class_name_length)

        class_count = max_class_index + 1
        # np.zeros rather than np.empty: class indices without samples then hold ""
        # instead of uninitialized memory.
        class_ids_array = np.zeros(class_count, dtype=f"U{max_class_id_length}")
        class_names_array = np.zeros(class_count, dtype=f"U{max_class_name_length}")
        for entry in entries_array:
            class_index = entry["class_index"]
            class_ids_array[class_index] = entry["class_id"]
            class_names_array[class_index] = entry["class_name"]

        logger.info(f'saving class IDs to "{self._class_ids_path}"')
        self._save_extra(class_ids_array, self._class_ids_path)

        logger.info(f'saving class names to "{self._class_names_path}"')
        self._save_extra(class_names_array, self._class_names_path)

    def dump_extra(self) -> None:
        """Generate and save all metadata files for this split (entries first)."""
        self._dump_entries()
        self._dump_class_ids_and_names()
|
|