# -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: ps-license@tuebingen.mpg.de import pytorch_lightning as pl import torch from termcolor import colored from ..dataset.mesh_util import * from ..net.geometry import orthogonal class Format: end = '\033[0m' start = '\033[4m' def init_loss(): losses = { # Cloth: chamfer distance "cloth": {"weight": 1e3, "value": 0.0}, # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2) "stiff": {"weight": 1e5, "value": 0.0}, # Cloth: det(R) = 1 "rigid": {"weight": 1e5, "value": 0.0}, # Cloth: edge length "edge": {"weight": 0, "value": 0.0}, # Cloth: normal consistency "nc": {"weight": 0, "value": 0.0}, # Cloth: laplacian smoonth "lapla": {"weight": 1e2, "value": 0.0}, # Body: Normal_pred - Normal_smpl "normal": {"weight": 1e0, "value": 0.0}, # Body: Silhouette_pred - Silhouette_smpl "silhouette": {"weight": 1e0, "value": 0.0}, # Joint: reprojected joints difference "joint": {"weight": 5e0, "value": 0.0}, } return losses class SubTrainer(pl.Trainer): def save_checkpoint(self, filepath, weights_only=False): """Save model/training states as a checkpoint file through state-dump and file-write. Args: filepath: write-target file's path weights_only: saving model weights only """ _checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only) del_keys = [] for key in _checkpoint["state_dict"].keys(): for ignore_key in ["normal_filter", "voxelization", "reconEngine"]: if ignore_key in key: del_keys.append(key) for key in del_keys: del _checkpoint["state_dict"][key] pl.utilities.cloud_io.atomic_save(_checkpoint, filepath) def query_func(opt, netG, features, points, proj_matrix=None): """ - points: size of (bz, N, 3) - proj_matrix: size of (bz, 4, 4) return: size of (bz, 1, N) """ assert len(points) == 1 samples = points.repeat(opt.num_views, 1, 1) samples = samples.permute(0, 2, 1) # [bz, 3, N] # view specific query if proj_matrix is not None: samples = orthogonal(samples, proj_matrix) calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples) preds = netG.query( features=features, points=samples, calibs=calib_tensor, regressor=netG.if_regressor, ) if type(preds) is list: preds = preds[0] return preds def query_func_IF(batch, netG, points): """ - points: size of (bz, N, 3) return: size of (bz, 1, N) """ batch["samples_geo"] = points batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points) preds = netG(batch) return preds.unsqueeze(1) def batch_mean(res, key): return torch.stack([ x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res ]).mean() def accumulate(outputs, rot_num, split): hparam_log_dict = {} metrics = outputs[0].keys() datasets = split.keys() for dataset in datasets: for metric in metrics: keyword = f"{dataset}/{metric}" if keyword not in hparam_log_dict.keys(): hparam_log_dict[keyword] = 0 for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num): hparam_log_dict[keyword] += outputs[idx][metric].item() hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num print(colored(hparam_log_dict, "green")) return hparam_log_dict