# ECON / lib/common/train_util.py
# Last change: commit 487ee6d "Support TEXTure" (author: Yuliang), file size 4.31 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import pytorch_lightning as pl
import torch
from termcolor import colored
from ..dataset.mesh_util import *
from ..net.geometry import orthogonal
class Format:
    """ANSI escape sequences for terminal text styling.

    ``start`` enables underlining (SGR 4); ``end`` resets all attributes (SGR 0).
    """

    start = "\033[4m"
    end = "\033[0m"
def init_loss():
    """Build a fresh table of loss terms, each with its weight and a zeroed value.

    Returns:
        dict mapping loss name -> {"weight": <weight>, "value": 0.0}. A new
        dict (with new inner dicts) is created on every call.
    """
    # (name, weight) pairs, in insertion order:
    #   cloth      - cloth: chamfer distance
    #   stiff      - stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
    #   rigid      - cloth: det(R) = 1
    #   edge       - cloth: edge length (weight 0)
    #   nc         - cloth: normal consistency (weight 0)
    #   lapla      - cloth: laplacian smoothing
    #   normal     - body: Normal_pred - Normal_smpl
    #   silhouette - body: Silhouette_pred - Silhouette_smpl
    #   joint      - joint: reprojected joints difference
    weight_table = (
        ("cloth", 1e3),
        ("stiff", 1e5),
        ("rigid", 1e5),
        ("edge", 0),
        ("nc", 0),
        ("lapla", 1e2),
        ("normal", 1e0),
        ("silhouette", 1e0),
        ("joint", 5e0),
    )
    return {name: {"weight": weight, "value": 0.0} for name, weight in weight_table}
class SubTrainer(pl.Trainer):
    """Lightning ``Trainer`` that strips selected sub-modules from saved checkpoints."""

    # Any state_dict key containing one of these substrings is dropped before saving.
    # Subclasses may override this tuple.
    _ckpt_ignore_keys = ("normal_filter", "voxelization", "reconEngine")

    def save_checkpoint(self, filepath, weights_only=False):
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Entries of ``state_dict`` whose key contains any substring from
        ``_ckpt_ignore_keys`` are removed before the file is written.

        Args:
            filepath: write-target file's path
            weights_only: saving model weights only
        """
        _checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
        state_dict = _checkpoint["state_dict"]
        # Collect first and delete afterwards, because a dict cannot change size
        # while it is being iterated. ``any`` lists each key at most once, even if
        # it matches several patterns; a repeated ``del`` would raise KeyError.
        del_keys = [
            key for key in state_dict
            if any(ignore_key in key for ignore_key in self._ckpt_ignore_keys)
        ]
        for key in del_keys:
            del state_dict[key]
        pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
def query_func(opt, netG, features, points, proj_matrix=None):
    """Query ``netG`` at ``points``, replicated once per view.

    Args:
        opt: config object; only ``opt.num_views`` is read here.
        netG: network exposing ``query(features=, points=, calibs=, regressor=)``
            and an ``if_regressor`` attribute.
        features: passed through unchanged to ``netG.query``.
        points: tensor of shape (1, N, 3); only a single batch element is accepted.
        proj_matrix: optional (bz, 4, 4) matrix; when given, the samples are
            transformed with ``orthogonal`` before querying.

    Returns:
        The result of ``netG.query``, or its first element if it returns a list.
        Documented as shape (bz, 1, N); confirm against ``netG.query``.

    Raises:
        ValueError: if ``points`` does not have exactly one batch element.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if len(points) != 1:
        raise ValueError(
            f"query_func expects points with batch size 1, got {len(points)}"
        )

    # Repeat the single point set for every view, then go channel-first: [views, 3, N].
    samples = points.repeat(opt.num_views, 1, 1).permute(0, 2, 1)

    # view specific query
    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    # A batch of one 4x4 identity calibration, matching samples' dtype/device.
    calib_tensor = torch.eye(4).float().unsqueeze(0).type_as(samples)

    preds = netG.query(
        features=features,
        points=samples,
        calibs=calib_tensor,
        regressor=netG.if_regressor,
    )
    if isinstance(preds, list):
        preds = preds[0]
    return preds
def query_func_IF(batch, netG, points):
    """Run ``netG`` on ``batch`` after injecting the query points.

    ``batch`` is modified in place: ``samples_geo`` is set to ``points`` and
    ``calib`` to a batch of one 4x4 identity matrix with ``points``' type.

    Args:
        batch: mutable mapping passed to ``netG``.
        netG: callable taking the batch.
        points: documented as shape (bz, N, 3).

    Returns:
        ``netG(batch)`` with a singleton axis inserted at dim 1. Documented as
        (bz, 1, N); confirm against ``netG``.
    """
    batch["samples_geo"] = points
    batch["calib"] = torch.eye(4).float().unsqueeze(0).type_as(points)
    return netG(batch).unsqueeze(1)
def batch_mean(res, key):
    """Return the mean of ``record[key]`` over every record in ``res``.

    Non-tensor values are converted with ``torch.as_tensor``. All values are
    stacked into one tensor before averaging.
    """
    collected = []
    for record in res:
        entry = record[key]
        if not torch.is_tensor(entry):
            entry = torch.as_tensor(entry)
        collected.append(entry)
    return torch.stack(collected).mean()
def accumulate(outputs, rot_num, split):
    """Average each metric over every dataset's slice of ``outputs`` and print the result.

    Args:
        outputs: sequence of dicts mapping metric name -> scalar value. Values
            may be one-element tensors or plain Python numbers. Metric names are
            taken from ``outputs[0]``.
        rot_num: multiplier applied to the ``split`` bounds to get indices into
            ``outputs`` (presumably the number of rotations per sample; confirm
            with the caller).
        split: dict mapping dataset name -> (start, end) bounds, with end
            exclusive, before scaling by ``rot_num``.

    Returns:
        dict mapping "<dataset>/<metric>" to the mean value over that slice.
    """
    hparam_log_dict = {}
    metrics = outputs[0].keys()

    for dataset, bounds in split.items():
        first, last = bounds[0] * rot_num, bounds[1] * rot_num
        for metric in metrics:
            # ``float`` accepts one-element tensors as well as plain numbers,
            # whereas ``.item()`` would fail on a Python float/int.
            total = sum(float(outputs[idx][metric]) for idx in range(first, last))
            hparam_log_dict[f"{dataset}/{metric}"] = total / (last - first)

    print(colored(hparam_log_dict, "green"))
    return hparam_log_dict