# FLARE / run_pose_pointcloud.py
# 聂如
# Add design file
# 91126af
import argparse
import datetime
import json
import os
import sys
import time
import math
from pathlib import Path
from typing import Sized
import mast3r.utils.path_to_dust3r # noqa
from collections import defaultdict
import copy
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
torch.backends.cuda.matmul.allow_tf32 = True
from mast3r.model import AsymmetricMASt3R
from dust3r.datasets import get_data_loader # noqa
from dust3r.inference import loss_of_one_batch # noqa
inf = float('inf')
from mast3r.losses import MeshOutput
import dust3r.utils.path_to_croco # noqa: F401
import croco.utils.misc as misc # noqa
import torch.nn.functional as F
def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter accepts the usual
    spellings and rejects anything else with a proper argparse error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under the old `type=bool`.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args_parser():
    """Build the command-line parser for this script.

    The parser is created with ``add_help=False``. It exposes the model /
    criterion / dataset expressions, the optimisation-style options defined
    below, logging frequencies, noise parameters, and the output directory.

    Returns:
        argparse.ArgumentParser: the configured parser (arguments not yet parsed).
    """
    parser = argparse.ArgumentParser('DUST3R training', add_help=False)

    # model and criterion
    parser.add_argument('--model', default="AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')",
                        type=str, help="string containing the model to build")
    parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint')
    parser.add_argument('--test_criterion', default=None, type=str, help="test criterion")

    # dataset
    parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set")

    # training
    parser.add_argument('--seed', default=0, type=int, help="Random seed")
    parser.add_argument('--batch_size', default=1, type=int,
                        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus")
    parser.add_argument('--accum_iter', default=1, type=int,
                        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
    parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")
    parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
    parser.add_argument('--amp', type=int, default=0,
                        choices=[0, 1], help="Use Automatic Mixed Precision for pretraining")

    # others
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency')
    parser.add_argument('--save_freq', default=1, type=int,
                        help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth')
    # argparse %-formats help strings, so a literal percent sign must be
    # doubled; a bare "%d" makes help expansion fail.
    parser.add_argument('--keep_freq', default=20, type=int,
                        help='frequence (number of epochs) to save checkpoint in checkpoint-%%d.pth')
    parser.add_argument('--print_freq', default=20, type=int,
                        help='frequence (number of iterations) to print infos while training')
    parser.add_argument('--noise_trans', default=0.05, type=float, help='translation noise')
    parser.add_argument('--noise_rot', default=10, type=float, help='rotation noise')
    parser.add_argument('--noise_prob', default=0.5, type=float, help='noise probability')
    # `type=bool` would turn "--save_input_image False" into True.
    parser.add_argument('--save_input_image', default=False, type=_str2bool)

    # output dir
    parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output")
    return parser
def main(args):
    """Build the test loaders, model and criterion, then run one test pass per dataset.

    Args:
        args: parsed command-line namespace. Reads ``test_dataset``,
            ``batch_size``, ``num_workers``, ``model``, ``test_criterion``
            (falling back to an optional ``criterion`` attribute),
            ``pretrained`` and ``output_dir``.

    Raises:
        ValueError: if no criterion expression is available. The original code
            read ``args.criterion`` unconditionally, which is not defined by
            this script's parser and failed with an opaque ``AttributeError``.
    """
    # Validate up front so a missing criterion fails before any expensive
    # dataset / model construction.
    criterion_expr = args.test_criterion or getattr(args, 'criterion', None)
    if not criterion_expr:
        raise ValueError('No test criterion given: pass --test_criterion')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One loader per '+'-separated dataset expression, keyed by the text
    # before the first '('.
    data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True)
                        for dataset in args.test_dataset.split('+')}

    print('Loading model: {:s}'.format(args.model))
    # NOTE(security): the model and criterion are Python expressions taken
    # from the command line and passed to eval(); only run this script with
    # trusted arguments. The names used in those expressions must be
    # importable in this module's namespace.
    model = eval(args.model)
    test_criterion = eval(criterion_expr)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    if args.pretrained:
        print('Loading pretrained: ', args.pretrained)
        ckpt = torch.load(args.pretrained, map_location=device)
        # strict=False: the printed result lists missing / unexpected keys.
        print(model.load_state_dict(ckpt['model'], strict=False))
        del ckpt  # in case it occupies memory

    global_rank = misc.get_rank()
    if global_rank == 0 and args.output_dir is not None:
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    try:
        for test_name, testset in data_loader_test.items():
            test_one_epoch(model, test_criterion, testset,
                           device, 0, log_writer=log_writer, args=args, prefix=test_name)
    finally:
        # Flush and release the event file even if evaluation raises.
        if log_writer is not None:
            log_writer.close()
def build_dataset(dataset, batch_size, num_workers, test=False):
    """Create a data loader for `dataset` via `get_data_loader`.

    Args:
        dataset: dataset specification forwarded unchanged to `get_data_loader`.
        batch_size: batch size forwarded to the loader.
        num_workers: worker count forwarded to the loader.
        test: when true the loader is neither shuffled nor asked to drop the
            last batch; when false both are enabled.

    Returns:
        The loader returned by `get_data_loader`.
    """
    split_names = ('Train', 'Test')
    split = split_names[test]
    print(f'Building {split} Data loader for dataset: ', dataset)

    # Shuffling and dropping the last batch are both enabled only when not testing.
    training = not test
    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_mem': True,
        'shuffle': training,
        'drop_last': training,
    }
    loader = get_data_loader(dataset, **loader_options)

    print(f"{split} dataset length: ", len(loader))
    return loader
def pad_to_square(reshaped_image):
    """Zero-pad a 4-D image tensor so its last two dimensions are equal.

    The input must have exactly four dimensions, unpacked as (B, C, H, W).
    The shorter of H / W is padded with zeros up to the longer one; when the
    padding is odd, the extra row / column goes to the bottom / right.

    Returns:
        The padded tensor, with max(H, W) as both trailing dimensions.
    """
    _, _, height, width = reshaped_image.shape
    side = max(height, width)

    extra_w = side - width
    extra_h = side - height
    left = extra_w // 2
    top = extra_h // 2

    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    pads = (left, extra_w - left, top, extra_h - top)
    return F.pad(reshaped_image, pads, mode='constant', value=0)
def generate_rank_by_dino(
    reshaped_image, backbone, query_frame_num, image_size=336
):
    """Return the index of the frame most similar to all other frames.

    Frames are padded to square, resized to ``image_size`` x ``image_size``,
    normalized, and passed through ``backbone``. Pairwise frame similarities
    are computed from the ``"x_norm_patchtokens"`` features; the frame whose
    summed similarity to the others is largest is returned.

    Args:
        reshaped_image: 4-D image batch, one entry per frame along dim 0.
        backbone: callable invoked as ``backbone(rgbs, is_training=True)`` that
            returns a mapping containing ``"x_norm_patchtokens"``.
        query_frame_num: unused; kept so existing callers keep working.
        image_size: side length the frames are resized to.

    Returns:
        int: index (along dim 0 of ``reshaped_image``) of the selected frame.
    """
    # Pad to a square first so the resize below does not distort the aspect
    # ratio. The original code computed this padding but then resized the
    # unpadded input, silently discarding it.
    rgbs = pad_to_square(reshaped_image)

    # Downsample image to image_size x image_size
    # because we found it is unnecessary to use high resolution
    rgbs = F.interpolate(
        rgbs,
        (image_size, image_size),
        mode="bilinear",
        align_corners=True,
    )
    rgbs = _resnet_normalize_image(rgbs.cuda())

    # Get the image features (patch level)
    frame_feat = backbone(rgbs, is_training=True)
    frame_feat = frame_feat["x_norm_patchtokens"]
    # presumably frame_feat is (frames, patches, channels) — TODO confirm
    frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

    # Compute the similarity matrix: swap the first two axes so bmm yields one
    # frame-by-frame matrix per entry of the new leading axis, then average.
    frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
    similarity_matrix = torch.bmm(
        frame_feat_norm, frame_feat_norm.transpose(-1, -2)
    )
    similarity_matrix = similarity_matrix.mean(dim=0)

    # Ignore self-pairing
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)

    # Find the most common frame
    return torch.argmax(similarity_sum).item()
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
# Kept on the CPU at import time: calling .cuda() here made merely importing
# this module fail on hosts without a GPU. The tensors are moved to the
# input's device inside _resnet_normalize_image instead.
_resnet_mean = torch.tensor(_RESNET_MEAN).view(1, 3, 1, 1)
_resnet_std = torch.tensor(_RESNET_STD).view(1, 3, 1, 1)


def _resnet_normalize_image(img: torch.Tensor) -> torch.Tensor:
    """Normalize `img` per channel with `_RESNET_MEAN` / `_RESNET_STD`.

    Args:
        img: tensor broadcastable against shape (1, 3, 1, 1), i.e. with three
            channels on dim 1 of a 4-D batch.

    Returns:
        ``(img - mean) / std``, computed on the device `img` lives on.
    """
    mean = _resnet_mean.to(img.device)
    std = _resnet_std.to(img.device)
    return (img - mean) / std
def calculate_index_mappings(query_index, S, device=None):
    """
    Build an index tensor of length S that exchanges positions [0] and
    [query_index] and leaves every other position in place, so that
    indexing with it moves the content of query_index to the front.

    The result is moved to `device` when one is given.
    """
    permutation = torch.arange(S)
    # Assigned left to right: slot 0 first, then slot query_index.
    permutation[0], permutation[query_index] = query_index, 0
    return permutation if device is None else permutation.to(device)
@torch.no_grad()
def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                   data_loader: Sized, device: torch.device, epoch: int,
                   args, log_writer=None, prefix='test'):
    """Run `loss_of_one_batch` over every batch of `data_loader` without gradients.

    For each batch, a reference view is chosen with `generate_rank_by_dino`
    and swapped to position 0 before the batch is handed to
    `loss_of_one_batch`.

    Args:
        model: network to evaluate; switched to eval mode.
        criterion: criterion forwarded to `loss_of_one_batch`.
        data_loader: iterable of batches; each batch is a list of dicts that
            contain an ``'img_org'`` tensor.
        device: device forwarded to `loss_of_one_batch`.
        epoch: epoch number, used for the log header and `set_epoch` hooks.
        args: namespace providing ``print_freq`` and ``amp``.
        log_writer: optional writer; only its ``log_dir`` is printed here.
        prefix: unused in this function; kept for interface compatibility.

    Returns:
        None. The value returned by `loss_of_one_batch` is not used here.
    """
    model.eval()
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))
    header = 'Test Epoch: [{}]'.format(epoch)

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    # The dataset may be wrapped in another dataset object. Only a missing
    # attribute should trigger the fallback; a bare `except:` would also hide
    # KeyboardInterrupt and unrelated errors.
    try:
        gt_num_image = data_loader.dataset.dataset.gt_num_image
    except AttributeError:
        gt_num_image = data_loader.dataset.gt_num_image

    backbone = torch.hub.load(
        "facebookresearch/dinov2", "dinov2_vitb14_reg"
    )
    backbone = backbone.eval().cuda()

    for batch in metric_logger.log_every(data_loader, args.print_freq, header):
        images = torch.cat([gt['img_org'] for gt in batch], dim=0)
        # presumably maps img_org from [-1, 1] to [0, 1] — TODO confirm
        images = images / 2 + 0.5

        index = generate_rank_by_dino(images, backbone, query_frame_num=1)

        # NOTE(review): `index` addresses rows of the concatenated `images`;
        # it matches positions in `batch` only when each 'img_org' holds a
        # single image (batch_size == 1) — confirm before using larger batches.
        # A plain list of ints avoids indexing the Python list with device
        # tensor elements one at a time.
        order = calculate_index_mappings(index, len(images)).tolist()
        batch = [batch[src] for src in order[:len(batch)]]

        loss_of_one_batch(gt_num_image, batch, model, criterion, device,
                          symmetrize_batch=True,
                          use_amp=bool(args.amp))
if __name__ == '__main__':
    # Script entry point: build the CLI parser, parse the command line, and
    # pass the resulting namespace to main().
    parser = get_args_parser()
    main(parser.parse_args())