import argparse
import logging
import math
import os
from functools import partial

from fvcore.common.checkpoint import PeriodicCheckpointer
import torch

from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger
from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler

from dinov2.train.ssl_meta_arch import SSLMetaArch


# re-enable TF32 matmuls (disabled by default since PyTorch 1.12)
torch.backends.cuda.matmul.allow_tf32 = True

logger = logging.getLogger("dinov2")


def get_args_parser(add_help: bool = True):
    parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Do not attempt to resume from the checkpoint directory.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--eval", type=str, default="", help="Eval type to perform")
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        default="",
        type=str,
        help="Output directory to save logs and checkpoints",
    )

    return parser


def build_optimizer(cfg, params_groups):
    return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2))


def build_schedulers(cfg):
    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    lr = dict(
        base_value=cfg.optim["lr"],
        final_value=cfg.optim["min_lr"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
        warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
        start_warmup_value=0,
    )
    wd = dict(
        base_value=cfg.optim["weight_decay"],
        final_value=cfg.optim["weight_decay_end"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
    )
    momentum = dict(
        base_value=cfg.teacher["momentum_teacher"],
        final_value=cfg.teacher["final_momentum_teacher"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
    )
    teacher_temp = dict(
        base_value=cfg.teacher["teacher_temp"],
        final_value=cfg.teacher["teacher_temp"],
        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
        start_warmup_value=cfg.teacher["warmup_teacher_temp"],
    )

    lr_schedule = CosineScheduler(**lr)
    wd_schedule = CosineScheduler(**wd)
    momentum_schedule = CosineScheduler(**momentum)
    teacher_temp_schedule = CosineScheduler(**teacher_temp)
    last_layer_lr_schedule = CosineScheduler(**lr)

    # keep the last layer frozen (lr = 0) for the first freeze_last_layer_epochs epochs
    last_layer_lr_schedule.schedule[: cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH] = 0

    logger.info("Schedulers ready.")

    return (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    )


def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
    for param_group in optimizer.param_groups:
        is_last_layer = param_group["is_last_layer"]
        lr_multiplier = param_group["lr_multiplier"]
        wd_multiplier = param_group["wd_multiplier"]
        param_group["weight_decay"] = wd * wd_multiplier
        param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier


def do_test(cfg, model, iteration):
    new_state_dict = model.teacher.state_dict()

    if distributed.is_main_process():
        iterstring = str(iteration)
        eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
        os.makedirs(eval_dir, exist_ok=True)
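        # save teacher checkpoint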
        teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
        torch.save({"teacher": new_state_dict}, teacher_ckp_path)


def do_train(cfg, model, resume=False):
    model.train()
    inputs_dtype = torch.half
    fp16_scaler = model.fp16_scaler  # for mixed-precision training
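    # setup optimizer and schedulers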
    optimizer = build_optimizer(cfg, model.get_params_groups())
    (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    ) = build_schedulers(cfg)
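    # checkpointer: resume from the last checkpoint in output_dir if available, else load cfg.MODEL.WEIGHTS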
    checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)

    start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        period=3 * OFFICIAL_EPOCH_LENGTH,
        max_iter=max_iter,
        max_to_keep=3,
    )
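    # setup data preprocessing: multi-crop augmentation, iBOT mask generation and collation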
    img_size = cfg.crops.global_crops_size
    patch_size = cfg.student.patch_size
    n_tokens = (img_size // patch_size) ** 2
    mask_generator = MaskingGenerator(
        input_size=(img_size // patch_size, img_size // patch_size),
        max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
    )

    data_transform = DataAugmentationDINO(
        cfg.crops.global_crops_scale,
        cfg.crops.local_crops_scale,
        cfg.crops.local_crops_number,
        global_crops_size=cfg.crops.global_crops_size,
        local_crops_size=cfg.crops.local_crops_size,
    )

    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
        mask_probability=cfg.ibot.mask_sample_probability,
        n_tokens=n_tokens,
        mask_generator=mask_generator,
        dtype=inputs_dtype,
    )
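    # setup data loader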
    dataset = make_dataset(
        dataset_str=cfg.train.dataset_path,
        transform=data_transform,
        target_transform=lambda _: (),
    )
    sampler_type = SamplerType.SHARDED_INFINITE
    data_loader = make_data_loader(
        dataset=dataset,
        batch_size=cfg.train.batch_size_per_gpu,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        seed=start_iter,
        sampler_type=sampler_type,
        sampler_advance=0,
        drop_last=True,
        collate_fn=collate_fn,
    )
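    # training loop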
    iteration = start_iter

    logger.info("Starting training from iteration {}".format(start_iter))
    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
    metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
    header = "Training"

    for data in metric_logger.log_every(
        data_loader,
        10,
        header,
        max_iter,
        start_iter,
    ):
        current_batch_size = data["collated_global_crops"].shape[0] / 2
        if iteration > max_iter:
            return
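        # apply schedules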
        lr = lr_schedule[iteration]
        wd = wd_schedule[iteration]
        mom = momentum_schedule[iteration]
        teacher_temp = teacher_temp_schedule[iteration]
        last_layer_lr = last_layer_lr_schedule[iteration]
        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
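        # compute losses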
        optimizer.zero_grad(set_to_none=True)
        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
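        # clip gradients (if configured) and step the optimizer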
        if fp16_scaler is not None:
            if cfg.optim.clip_grad:
                fp16_scaler.unscale_(optimizer)
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        else:
            if cfg.optim.clip_grad:
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            optimizer.step()
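        # perform teacher EMA update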
        model.update_teacher(mom)
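        # logging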
        if distributed.get_global_size() > 1:
            for v in loss_dict.values():
                torch.distributed.all_reduce(v)
        loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}

        if math.isnan(sum(loss_dict_reduced.values())):
            logger.info("NaN detected")
            raise AssertionError
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        metric_logger.update(lr=lr)
        metric_logger.update(wd=wd)
        metric_logger.update(mom=mom)
        metric_logger.update(last_layer_lr=last_layer_lr)
        metric_logger.update(current_batch_size=current_batch_size)
        metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
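        # checkpointing and testing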
        if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
            do_test(cfg, model, f"training_{iteration}")
            torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        iteration = iteration + 1
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def main(args):
    cfg = setup(args)

    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()

    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        iteration = (
            FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
            .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
            .get("iteration", -1)
            + 1
        )
        return do_test(cfg, model, f"manual_{iteration}")

    do_train(cfg, model, resume=not args.no_resume)


if __name__ == "__main__":
    args = get_args_parser(add_help=True).parse_args()
    main(args)
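# Example launch (illustrative only; the config path, dataset string and launcher
# depend on your setup — the script reads the usual torch.distributed environment variables):
#   torchrun --nproc_per_node=8 dinov2/train/train.py \
#       --config-file dinov2/configs/train/vitl16_short.yaml \
#       --output-dir <PATH/TO/OUTPUT/DIR> \
#       train.dataset_path=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>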