|
import argparse
import copy
import datetime
import json
import math
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Sized

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import mast3r.utils.path_to_dust3r  # noqa: F401 -- puts the bundled dust3r on sys.path
from mast3r.model import AsymmetricMASt3R  # noqa: F401 -- available to eval() of --model
from mast3r.losses import MeshOutput  # noqa: F401 -- available to eval() of --test_criterion

from dust3r.datasets import get_data_loader
from dust3r.inference import loss_of_one_batch
import dust3r.utils.path_to_croco  # noqa: F401 -- puts the bundled croco on sys.path
import croco.utils.misc as misc

# Allow TF32 matmuls on Ampere and newer GPUs: faster, with negligible accuracy loss here.
torch.backends.cuda.matmul.allow_tf32 = True

# 'inf' may appear inside the eval()'d --model / --test_criterion strings.
inf = float('inf')
|
|
def get_args_parser():
    parser = argparse.ArgumentParser('MASt3R test', add_help=False)

    # model and criterion (both strings are eval()'d in main)
    parser.add_argument('--model', default="AsymmetricMASt3R(patch_embed_cls='ManyAR_PatchEmbed')",
                        type=str, help="string containing the model to build")
    parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint')
    parser.add_argument('--test_criterion', default=None, type=str, help="test criterion")
|
    parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set")

    parser.add_argument('--seed', default=0, type=int, help="Random seed")
    parser.add_argument('--batch_size', default=1, type=int,
                        help="Batch size per GPU (effective batch size is batch_size * accum_iter * #gpus)")
    parser.add_argument('--accum_iter', default=1, type=int,
                        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
    parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")

    parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
    parser.add_argument('--amp', type=int, default=0,
                        choices=[0, 1], help="Use Automatic Mixed Precision")

    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency')
    parser.add_argument('--save_freq', default=1, type=int,
                        help='frequency (number of epochs) at which to save checkpoint in checkpoint-last.pth')
    parser.add_argument('--keep_freq', default=20, type=int,
                        help='frequency (number of epochs) at which to save checkpoint in checkpoint-%d.pth')
    parser.add_argument('--print_freq', default=20, type=int,
                        help='frequency (number of iterations) at which to print info during evaluation')
    parser.add_argument('--noise_trans', default=0.05, type=float, help='translation noise')
    parser.add_argument('--noise_rot', default=10, type=float, help='rotation noise (degrees)')
    parser.add_argument('--noise_prob', default=0.5, type=float, help='probability of applying pose noise')
    # argparse's type=bool parses any non-empty string as True, so use a real flag instead
    parser.add_argument('--save_input_image', default=False, action='store_true')

    parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output")
    return parser
|
|
def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # one loader per '+'-separated test dataset, keyed by the dataset class name
    data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True)
                        for dataset in args.test_dataset.split('+')}

    print('Loading model: {:s}'.format(args.model))
    model = eval(args.model)
    # the criterion string is eval()'d too; without one, loss_of_one_batch only runs inference
    test_criterion = eval(args.test_criterion) if args.test_criterion else None
    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    if args.pretrained:
        print('Loading pretrained: ', args.pretrained)
        ckpt = torch.load(args.pretrained, map_location=device)
        print(model.load_state_dict(ckpt['model'], strict=False))
        del ckpt

    global_rank = misc.get_rank()
    if global_rank == 0 and args.output_dir is not None:
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    for test_name, testset in data_loader_test.items():
        test_one_epoch(model, test_criterion, testset,
                       device, 0, log_writer=log_writer, args=args, prefix=test_name)
|
|
def build_dataset(dataset, batch_size, num_workers, test=False):
    split = ['Train', 'Test'][test]
    print(f'Building {split} data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not test,
                             drop_last=not test)
    print(f"{split} dataset length: ", len(loader))
    return loader
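# Note: get_data_loader() builds the dataset from the string passed in
# --test_dataset, so an illustrative value (a placeholder, not one shipped with
# this script) would look like "1000 @ Co3d(split='test', resolution=512, seed=777)".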
|
|
|
def pad_to_square(reshaped_image):
    """Zero-pad a (B, C, H, W) batch symmetrically to (B, C, S, S) with S = max(H, W)."""
    B, C, H, W = reshaped_image.shape
    max_dim = max(H, W)
    pad_height = max_dim - H
    pad_width = max_dim - W
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    padding = (pad_width // 2, pad_width - pad_width // 2,
               pad_height // 2, pad_height - pad_height // 2)
    padded_image = F.pad(reshaped_image, padding, mode='constant', value=0)
    return padded_image
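# Worked example: a (B, 3, 480, 640) batch has max_dim = 640, so 160 rows of
# zero padding are needed and the result is (B, 3, 640, 640), with 80 rows
# added above the image and 80 below.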
|
|
|
def generate_rank_by_dino(
    reshaped_image, backbone, query_frame_num, image_size=336
):
    """Rank frames with DINOv2 patch features and return the index of the most
    representative (most similar-to-all) frame.

    query_frame_num is kept for API compatibility; only the top frame is returned.
    """
    # pad to square, then downsample: high resolution is unnecessary for ranking
    rgbs = pad_to_square(reshaped_image)
    rgbs = F.interpolate(
        rgbs,  # interpolate the padded images, not the raw input
        (image_size, image_size),
        mode="bilinear",
        align_corners=True,
    )
    rgbs = _resnet_normalize_image(rgbs.cuda())

    # per-frame DINOv2 patch tokens: (S, P, C)
    frame_feat = backbone(rgbs, is_training=True)
    frame_feat = frame_feat["x_norm_patchtokens"]
    frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

    # frame-to-frame similarity, averaged over patch positions
    frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
    similarity_matrix = torch.bmm(
        frame_feat_norm, frame_feat_norm.transpose(-1, -2)
    )
    similarity_matrix = similarity_matrix.mean(dim=0)

    # ignore self-similarity, then pick the frame closest to all the others
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)
    most_common_frame_index = torch.argmax(similarity_sum).item()
    return most_common_frame_index
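# In this script the input is the whole test sequence, shape (S, 3, H, W),
# already rescaled to [0, 1]; the returned index is the most "central" view,
# i.e. the frame whose DINOv2 patch tokens best match all the other frames.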
|
|
|
# ImageNet (ResNet) normalization statistics
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
_resnet_mean = torch.tensor(_RESNET_MEAN).view(1, 3, 1, 1)
_resnet_std = torch.tensor(_RESNET_STD).view(1, 3, 1, 1)


def _resnet_normalize_image(img: torch.Tensor) -> torch.Tensor:
    """Normalize an image batch in [0, 1] with ImageNet statistics.

    The statistics are moved to img's device lazily, so importing this module
    no longer requires a CUDA device.
    """
    return (img - _resnet_mean.to(img.device)) / _resnet_std.to(img.device)
|
|
|
def calculate_index_mappings(query_index, S, device=None):
    """Return an order over S frames that swaps positions [query_index] and [0],
    so that the content of query_index ends up at position [0]."""
    new_order = torch.arange(S)
    new_order[0] = query_index
    new_order[query_index] = 0
    if device is not None:
        new_order = new_order.to(device)
    return new_order
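# Example: with S=4 and query_index=2 the returned order is [2, 1, 0, 3];
# frame 2 moves to the front and frame 0 takes its old slot.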
|
|
|
@torch.no_grad()
def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                   data_loader: Sized, device: torch.device, epoch: int,
                   args, log_writer=None, prefix='test'):
    model.eval()
    metric_logger = misc.MetricLogger(delimiter=" ")
    # a huge window so SmoothedValue effectively averages over the whole epoch
    metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))
    header = 'Test Epoch: [{}]'.format(epoch)

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    # the dataset may be wrapped one level deep, so try the inner attribute first
    try:
        gt_num_image = data_loader.dataset.dataset.gt_num_image
    except AttributeError:
        gt_num_image = data_loader.dataset.gt_num_image

    # DINOv2 ViT-B/14 with registers, used only to rank the input frames
    backbone = torch.hub.load(
        "facebookresearch/dinov2", "dinov2_vitb14_reg"
    )
    backbone = backbone.eval().cuda()

    for batch in metric_logger.log_every(data_loader, args.print_freq, header):
        # gather the views' original images and map them from [-1, 1] to [0, 1]
        images = [gt['img_org'] for gt in batch]
        images = torch.cat(images, dim=0)
        images = images / 2 + 0.5

        # reorder the views so the most representative frame comes first
        index = generate_rank_by_dino(images, backbone, query_frame_num=1)
        sorted_order = calculate_index_mappings(index, len(images), device=device)
        batch = [batch[sorted_order[j].item()] for j in range(len(batch))]

        # run inference (and the criterion, if one was given) on the reordered batch
        loss_tuple = loss_of_one_batch(gt_num_image, batch, model, criterion, device,
                                       symmetrize_batch=True,
                                       use_amp=bool(args.amp))
|
|
|
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)
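# Illustrative invocation (the script name, dataset string, criterion, and
# checkpoint path below are placeholders, not values shipped with this file):
#   python this_script.py \
#       --test_dataset "1000 @ Co3d(split='test', resolution=512, seed=777)" \
#       --test_criterion "MeshOutput()" \
#       --pretrained checkpoints/checkpoint-last.pth \
#       --output_dir ./output/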
|
|