Spaces:
Runtime error
Runtime error
# Some parts of the code are adapted from https://github.com/benjiebob/WLDO and https://github.com/benjiebob/SMALify
import numpy as np | |
import torch | |
import pickle as pkl | |
class ShapePrior(torch.nn.Module):
    """Single-Gaussian prior over shape coefficients (betas).

    The prior's mean and covariance are read from a pickle file. ``forward``
    returns a scalar penalty that grows as the given betas move away from the
    prior mean, measured through a Cholesky factor of the regularised inverse
    covariance.
    """

    def __init__(self, prior_path):
        """Load the prior statistics and register them as buffers.

        Args:
            prior_path: path to a pickle file holding a dict with the keys
                ``'dog_cluster_mean'`` (length-D vector) and
                ``'dog_cluster_cov'`` (D x D matrix).

        Raises:
            OSError: if ``prior_path`` cannot be opened.
            KeyError: if one of the expected keys is missing.
        """
        super().__init__()
        # NOTE(security): unpickling executes arbitrary code; only load prior
        # files that come from a trusted source.
        # Pickles must always be read in binary mode. 'latin1' makes the
        # loader accept files written by Python 2 (e.g. with numpy arrays),
        # and is harmless for pickles written by Python 3.
        with open(prior_path, 'rb') as f:
            res = pkl.load(f, encoding='latin1')
        betas_mean = res['dog_cluster_mean']
        betas_cov = res['dog_cluster_cov']
        # A small diagonal term keeps the covariance invertible.
        single_gaussian_inv_covs = np.linalg.inv(
            betas_cov + 1e-5 * np.eye(betas_cov.shape[0]))
        # Lower-triangular L with L @ L.T == inverse covariance, so that
        # ||(x - mean) @ L||^2 is the squared Mahalanobis distance.
        single_gaussian_precs = torch.tensor(
            np.linalg.cholesky(single_gaussian_inv_covs)).float()
        single_gaussian_means = torch.tensor(betas_mean).float()
        self.register_buffer('single_gaussian_precs', single_gaussian_precs)  # (D, D)
        self.register_buffer('single_gaussian_means', single_gaussian_means)  # (D,)
        # Float mask (all ones) marking which prior dimensions are in use.
        use_ind_tch = torch.ones(
            single_gaussian_means.shape[0], dtype=torch.float32)
        self.register_buffer('use_ind_tch', use_ind_tch)  # (D,)

    def forward(self, betas_smal_orig, use_singe_gaussian=False):
        """Return the mean squared whitened deviation of the betas from the prior mean.

        Args:
            betas_smal_orig: 2-D tensor of shape (batch, n_betas). ``n_betas``
                may be smaller than the prior dimensionality D; the missing
                dimensions are excluded from the penalty. More coefficients
                than D are not supported (the padding size would be negative).
            use_singe_gaussian: unused; kept for interface compatibility.

        Returns:
            Scalar tensor: the mean over all (batch, D) entries of the squared
            projected deviation. Excluded dimensions contribute zeros to that
            mean.
        """
        n_betas_smal = betas_smal_orig.shape[1]
        n_prior = self.single_gaussian_means.shape[0]
        device = betas_smal_orig.device
        # Keep only the first n_betas_smal prior dimensions active.
        active = torch.cat((
            torch.ones_like(self.use_ind_tch[:n_betas_smal]),
            torch.zeros_like(self.use_ind_tch[n_betas_smal:]),
        ))
        use_ind_tch_corrected = self.use_ind_tch * active
        # Zero-pad the input up to the prior dimensionality, allocating the
        # padding directly on the input's device.
        padding = torch.zeros(
            (betas_smal_orig.shape[0], n_prior - n_betas_smal),
            dtype=torch.float32, device=device)
        samples = torch.cat((betas_smal_orig, padding), dim=1)
        mean_sub = samples - self.single_gaussian_means.unsqueeze(0)
        # Zero the rows and columns of the factor that belong to inactive
        # dimensions so the padded entries cannot influence the result.
        single_gaussian_precs_corr = (
            self.single_gaussian_precs
            * use_ind_tch_corrected[:, None]
            * use_ind_tch_corrected[None, :]
        )
        res = torch.tensordot(mean_sub, single_gaussian_precs_corr, dims=([1], [0]))
        res_final_mean_2 = torch.mean(res ** 2)
        return res_final_mean_2