Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
from functools import partial | |
import json | |
import logging | |
import os | |
import sys | |
from typing import List, Optional | |
import torch | |
from torch.nn.functional import one_hot, softmax | |
import dinov2.distributed as distributed | |
from dinov2.data import SamplerType, make_data_loader, make_dataset | |
from dinov2.data.transforms import make_classification_eval_transform | |
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric | |
from dinov2.eval.setup import get_args_parser as get_setup_args_parser | |
from dinov2.eval.setup import setup_and_build_model | |
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features | |
logger = logging.getLogger("dinov2") | |
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the k-NN evaluation.

    The parser inherits every option of the parser returned by
    ``get_setup_args_parser`` and adds the k-NN specific options on top.

    Args:
        description: Description string given to ``argparse.ArgumentParser``.
        parents: Optional parent parsers, forwarded to ``get_setup_args_parser``.
        add_help: Whether the returned parser exposes ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    # The setup parser is built without its own help option so that it can act as a parent.
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # Adjacent literals are concatenated verbatim: the trailing space after
        # "slower" keeps the two halves from running together in the help output.
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser
class KnnModule(torch.nn.Module):
    """
    Distributed k-NN voting over a sharded bank of train features.

    Every rank keeps one chunk of the train features (and the matching labels) and
    receives its own chunk of test features. In `compute_neighbors`, the ranks take
    turns: the test features of the current rank are broadcast to everyone, each rank
    computes a partial top-k against its local train chunk, and the partial results
    are gathered back on the rank that owns the test features.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        super().__init__()

        rank = distributed.get_global_rank()
        world_size = distributed.get_global_size()
        self.global_rank = rank
        self.global_size = world_size
        self.device = device

        # Keep only this rank's shard; features are stored transposed, ready for `torch.mm`.
        local_features = train_features.chunk(world_size)[rank]
        self.train_features_rank_T = local_features.T.to(device)
        local_labels = train_labels.chunk(world_size)[rank]
        self.candidates = local_labels.view(1, -1).to(device)

        self.nb_knn = nb_knn
        self.max_k = max(nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        """Return the `max_k` largest similarities of each row and the labels at those positions."""
        best = similarity.topk(self.max_k, largest=True, sorted=True)
        return best.values, train_labels.gather(1, best.indices)

    def _similarity_for_rank(self, features_rank, source_rank):
        """Broadcast the features of `source_rank` and compute their top-k against the local train chunk."""
        # The shape goes first so that receiving ranks can allocate a buffer of the right size
        shape = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(shape, source_rank)

        if self.global_rank == source_rank:
            payload = features_rank
        else:
            payload = torch.zeros(*shape, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(payload, source_rank)

        # Partial neighbors of `source_rank` among `train_features_rank_T`
        scores = torch.mm(payload, self.train_features_rank_T)
        labels = self.candidates.expand(len(scores), -1)
        return self._get_knn_sims_and_labels(scores, labels)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        """Gather the partial top-k of all ranks on `target_rank`; returns None on every other rank."""
        is_target = self.global_rank == target_rank

        sims_buffers = labels_buffers = None
        if is_target:
            world = range(self.global_size)
            sims_buffers = [torch.zeros_like(topk_sims) for _ in world]
            labels_buffers = [torch.zeros_like(neighbors_labels) for _ in world]

        torch.distributed.gather(topk_sims, sims_buffers, dst=target_rank)
        torch.distributed.gather(neighbors_labels, labels_buffers, dst=target_rank)

        if not is_target:
            return None

        # Second top-k, this time over the k * global_size gathered candidates
        merged_sims = torch.cat(sims_buffers, dim=1)
        merged_labels = torch.cat(labels_buffers, dim=1)
        return self._get_knn_sims_and_labels(merged_sims, merged_labels)

    def compute_neighbors(self, features_rank):
        """Run one broadcast/gather round per rank and return the neighbors found for this rank."""
        for rank in range(self.global_size):
            partial_knn = self._similarity_for_rank(features_rank, rank)
            gathered = self._gather_all_knn_for_rank(*partial_knn, rank)
            if gathered is not None:
                topk_sims_rank, neighbors_labels_rank = gathered
        return topk_sims_rank, neighbors_labels_rank

    def forward(self, features_rank):
        """
        Return, for each k in `self.nb_knn`, the class scores obtained from the k nearest
        neighbors (all derived from a single search for `self.max_k` neighbors).
        """
        assert not any(k > self.max_k for k in self.nb_knn)

        sims, labels = self.compute_neighbors(features_rank)
        # Temperature-scaled softmax over the neighbors gives one voting weight per neighbor
        weights = softmax(sims / self.T, 1).view(labels.shape[0], -1, 1)
        votes = one_hot(labels, num_classes=self.num_classes) * weights

        probas_for_k = {}
        for k in self.nb_knn:
            probas_for_k[k] = votes[:, :k, :].sum(1)
        return probas_for_k
class DictKeysModule(torch.nn.Module):
    """Select one entry of a nested dictionary by following a fixed sequence of keys."""

    def __init__(self, keys):
        super().__init__()
        self.keys = keys

    def forward(self, features_dict, targets):
        """Descend through `features_dict` along `self.keys` and pair the result with `targets`."""
        selected = features_dict
        for key in self.keys:
            selected = selected[key]
        return {"preds": selected, "target": targets}
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    """Instantiate one k-NN module per (samples-per-class, try) configuration.

    Args:
        module: Callable building a module from the keyword arguments
            `train_features`, `train_labels` and `nb_knn`.
        n_per_class_list: Numbers of train samples to keep per class. A negative
            value means "use the full train set".
        n_tries: Number of random subsets drawn for each non-negative entry.
        nb_knn: Iterable of neighbor counts (list or tuple).
        train_features: Train features, indexable along the first dimension.
        train_labels: Train labels aligned with `train_features`.

    Returns:
        A `ModuleDictWithForward` keyed by `"full"` or `"<n> per class"`, whose values
        are `ModuleDictWithForward` instances keyed by the try identifier.
    """
    modules = {}
    mapping = create_class_indices_mapping(train_labels)
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue

        # Build the candidate set without `nb_knn + [npc]`, which raises a TypeError
        # when `nb_knn` is a tuple. Values of k above `npc` are dropped; `npc` itself
        # is always kept. The list does not depend on the try, so it is computed once.
        k_list = sorted(k for k in {*nb_knn, npc} if k <= npc)

        all_tries = {}
        for t in range(n_tries):
            # The try index doubles as the random seed of the subset selection
            final_indices = filter_train(mapping, npc, seed=t)
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=list(k_list),  # each module gets its own list object
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)
    return ModuleDictWithForward(modules)
def filter_train(mapping, n_per_class, seed):
    """Randomly pick at most `n_per_class` sample indices for every class.

    Args:
        mapping: Dictionary whose values are index tensors, one per class (as built
            by `create_class_indices_mapping`, i.e. of shape `(count, 1)`).
        n_per_class: Maximum number of indices kept per class.
        seed: Seed given to `torch.manual_seed`. Note that this reseeds the global
            torch RNG.

    Returns:
        A 1-D tensor concatenating the selected indices of all classes.
    """
    torch.manual_seed(seed)
    final_indices = []
    for class_indices in mapping.values():
        index = torch.randperm(len(class_indices))[:n_per_class]
        final_indices.append(class_indices[index])
    # `reshape(-1)` instead of a bare `squeeze()`: when a single index is selected in
    # total, `squeeze()` would return a 0-dim tensor and indexing the train features
    # with it would drop the batch dimension.
    return torch.cat(final_indices).reshape(-1)
def create_class_indices_mapping(labels):
    """Map every distinct label to the positions at which it occurs in `labels`.

    Keys are the 0-dim tensors taken from `torch.unique(labels)`; values are the
    `nonzero()` results, i.e. index tensors with one row per occurrence.
    """
    classes, class_position = torch.unique(labels, return_inverse=True)
    mapping = {}
    for position, label in enumerate(classes):
        mapping[label] = torch.nonzero(class_position == position)
    return mapping
class ModuleDictWithForward(torch.nn.ModuleDict):
    """A `ModuleDict` that can be called: every child module receives the same arguments."""

    def forward(self, *args, **kwargs):
        """Return a dictionary with the output of each child module, under the child's key."""
        outputs = {}
        for name, child in self._modules.items():
            outputs[name] = child(*args, **kwargs)
        return outputs
def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=(-1,),
    n_tries=1,
):
    """Run the k-NN evaluation of `model` and return the metrics averaged over tries.

    Train features are extracted once, k-NN modules are created for every entry of
    `n_per_class_list` (and every try), and the validation set is then evaluated
    through `evaluate`.

    Args:
        model: Feature extractor; it is wrapped in `ModelWithNormalize`.
        train_dataset: Dataset providing the train features and labels.
        val_dataset: Dataset on which the metrics are computed.
        accuracy_averaging: Averaging mode forwarded to `build_topk_accuracy_metric`.
        nb_knn: Neighbor counts to evaluate.
        temperature: Temperature of the softmax applied to neighbor similarities.
        batch_size: Batch size for feature extraction and validation.
        num_workers: Number of data loading workers.
        gather_on_cpu: Forwarded to `extract_features`.
        n_per_class_list: Train-set sizes per class to evaluate; a negative entry
            means the full train set.
        n_tries: Number of random subsets per non-negative entry of `n_per_class_list`.

    Returns:
        The results dictionary returned by `evaluate`, in which the per-try entries
        `(n_per_class, try, k)` have been replaced by their mean under `(n_per_class, k)`.
    """
    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    # Convert to a plain Python int: the class count is consumed as an integer
    # (e.g. by `one_hot`), not as a 0-dim tensor.
    num_classes = int(train_labels.max()) + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )

    # One postprocessor and one independent metric copy per (n_per_class, try, k)
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            for k in knn_try.nb_knn:
                key = (n_per_class, t, k)
                postprocessors[key] = DictKeysModule([n_per_class, t, k])
                metrics[key] = metric_collection.clone()

    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        first_try = list(knn_module.keys())[0]
        k_list = knn_module[first_try].nb_knn
        for k in k_list:
            keys = results_dict[(n_per_class, first_try, k)].keys()  # keys are e.g. `top-1` and `top-5`
            results_dict[(n_per_class, k)] = {
                key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
                for key in keys
            }
            # The per-try entries are no longer needed once averaged
            for t in knn_module.keys():
                del results_dict[(n_per_class, t, k)]

    return results_dict
def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=(-1,),
    n_tries=1,
):
    """Build the datasets, run `eval_knn` under autocast and log/persist the results.

    Args:
        model: Feature extractor to evaluate.
        output_dir: Directory in which `results_eval_knn.json` is appended to.
        train_dataset_str: Dataset string of the train set, given to `make_dataset`.
        val_dataset_str: Dataset string of the validation set, given to `make_dataset`.
        nb_knn: Neighbor counts to evaluate.
        temperature: Temperature of the softmax applied to neighbor similarities.
        autocast_dtype: dtype used by the CUDA autocast context.
        accuracy_averaging: Averaging mode of the accuracy metric.
        transform: Transform applied to both datasets; defaults to
            `make_classification_eval_transform()`.
        gather_on_cpu: Forwarded to `eval_knn`.
        batch_size: Batch size.
        num_workers: Number of data loading workers.
        n_per_class_list: Train-set sizes per class to evaluate; a negative entry
            means the full train set.
        n_tries: Number of random subsets per non-negative entry of `n_per_class_list`.

    Returns:
        On the main process, a dictionary mapping `"<key> Top 1"` / `"<key> Top 5"` to
        accuracies in percent; an empty dictionary on the other processes.
    """
    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_, knn_metrics in results_dict_knn.items():
            top1 = knn_metrics["top-1"].item() * 100.0
            top5 = knn_metrics["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

        # Append (not overwrite) one JSON object per line
        metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
        with open(metrics_file_path, "a") as f:
            f.writelines(json.dumps({k: v}) + "\n" for k, v in results_dict.items())

    # Every rank waits here until the main process has finished writing
    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
def main(args):
    """Build the model from the parsed arguments, run the k-NN evaluation and return 0."""
    model, autocast_dtype = setup_and_build_model(args)

    # Options taken from the command line; the remaining ones are fixed for the script
    knn_options = dict(
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        num_workers=5,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    eval_knn_with_model(model=model, **knn_options)
    return 0
if __name__ == "__main__":
    # Script entry point: parse the command line, then exit with the status returned by `main`
    args_parser = get_args_parser(description="DINOv2 k-NN evaluation")
    sys.exit(main(args_parser.parse_args()))