import pickle as pkl
import numpy as np
import torchvision.models as models
from torchvision import transforms
import torch
from torch import nn
from torch.nn.parameter import Parameter
from kornia.geometry.subpix import dsnt  # kornia 0.4.0
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from stacked_hourglass.utils.evaluation import get_preds_soft
from stacked_hourglass import hg1, hg2, hg8
from lifting_to_3d.linear_model import LinearModelComplete, LinearModel
from lifting_to_3d.inn_model_for_shape import INNForShape
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
from smal_pytorch.smal_model.smal_torch_new import SMAL
from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
from bps_2d.bps_for_segmentation import SegBPS
from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS as SHAPE_PRIOR
from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL


class SmallLinear(nn.Module):
    def __init__(self, input_size=64, output_size=30, linear_size=128):
        super(SmallLinear, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.w1 = nn.Linear(input_size, linear_size)
        self.w2 = nn.Linear(linear_size, linear_size)
        self.w3 = nn.Linear(linear_size, output_size)

    def forward(self, x):
        # pre-processing
        y = self.w1(x)
        y = self.relu(y)
        y = self.w2(y)
        y = self.relu(y)
        y = self.w3(y)
        return y


class MyConv1d(nn.Module):
    def __init__(self, input_size=37, output_size=30, start=True):
        super(MyConv1d, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.start = start
        self.weight = Parameter(torch.ones((self.output_size)))
        self.bias = Parameter(torch.zeros((self.output_size)))

    def forward(self, x):
        # pre-processing
        if self.start:
            y = x[:, :self.output_size]
        else:
            y = x[:, -self.output_size:]
        y = y * self.weight[None, :] + self.bias[None, :]
        return y


class ModelShapeAndBreed(nn.Module):
    def __init__(self, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'):
        super(ModelShapeAndBreed, self).__init__()
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs   # n_betas_logscale
        self.n_breeds = n_breeds
        self.structure_z_to_betas = structure_z_to_betas
        if self.structure_z_to_betas == '1dconv':
            if not (n_z == self.n_betas + self.n_betas_limbs):
                raise ValueError
        # shape branch
        self.resnet = models.resnet34(pretrained=False)
        # replace the first layer
        n_in = 3 + 1
        self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # replace the last layer
        self.resnet.fc = nn.Linear(512, n_z)
        # softmax
        self.soft_max = torch.nn.Softmax(dim=1)
        # fc network (and other versions) to connect z with betas
        p_dropout = 0.2
        if self.structure_z_to_betas == 'default':
            self.linear_betas = LinearModel(linear_size=1024,
                                            num_stage=1,
                                            p_dropout=p_dropout,
                                            input_size=n_z,
                                            output_size=self.n_betas)
            self.linear_betas_limbs = LinearModel(linear_size=1024,
                                                  num_stage=1,
                                                  p_dropout=p_dropout,
                                                  input_size=n_z,
                                                  output_size=self.n_betas_limbs)
        elif self.structure_z_to_betas == 'lin':
            self.linear_betas = nn.Linear(n_z, self.n_betas)
            self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs)
        elif self.structure_z_to_betas == 'fc_0':
            self.linear_betas = SmallLinear(linear_size=128,    # 1024
                                            input_size=n_z,
                                            output_size=self.n_betas)
            self.linear_betas_limbs = SmallLinear(linear_size=128,    # 1024
                                                  input_size=n_z,
                                                  output_size=self.n_betas_limbs)
        elif structure_z_to_betas == 'fc_1':
            self.linear_betas = LinearModel(linear_size=64,    # 1024
                                            num_stage=1,
                                            p_dropout=0,
                                            input_size=n_z,
                                            output_size=self.n_betas)
            self.linear_betas_limbs = LinearModel(linear_size=64,    # 1024
                                                  num_stage=1,
                                                  p_dropout=0,
                                                  input_size=n_z,
                                                  output_size=self.n_betas_limbs)
        elif self.structure_z_to_betas == '1dconv':
            self.linear_betas = MyConv1d(n_z, self.n_betas, start=True)
            self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False)
        elif self.structure_z_to_betas == 'inn':
            self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs,
                                                            betas_scale=1.0, betas_limbs_scale=1.0)
        else:
            raise ValueError
        # network to connect latent shape vector z with dog breed classification
        self.linear_breeds = LinearModel(linear_size=1024,    # 1024
                                         num_stage=1,
                                         p_dropout=p_dropout,
                                         input_size=n_z,
                                         output_size=self.n_breeds)
        # shape multiplicator
        self.shape_multiplicator_np = np.ones(self.n_betas)
        with open(SHAPE_PRIOR, 'rb') as file:
            u = pkl._Unpickler(file)
            u.encoding = 'latin1'
            res = u.load()
        # shape predictions are centered around the mean dog of our dog model
        self.betas_mean_np = res['dog_cluster_mean']

    def forward(self, img, seg_raw=None, seg_prep=None):
        # img is the network input image
        # seg_raw is before softmax and subtracting 0.5
        # seg_prep would be the prepared_segmentation
        if seg_prep is None:
            seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5
        input_img_and_seg = torch.cat((img, seg_prep), axis=1)
        res_output = self.resnet(input_img_and_seg)
        dog_breed_output = self.linear_breeds(res_output)
        if self.structure_z_to_betas == 'inn':
            shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output)
        else:
            shape_output_orig = self.linear_betas(res_output) * 0.1
            betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device)
            shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas]
            shape_limbs_output_orig = self.linear_betas_limbs(res_output)
            shape_limbs_output = shape_limbs_output_orig * 0.1
        output_dict = {'z': res_output,
                       'breeds': dog_breed_output,
                       'betas': shape_output_orig,
                       'betas_limbs': shape_limbs_output_orig}
        return output_dict


class LearnableShapedirs(nn.Module):
    def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10):
        super(LearnableShapedirs, self).__init__()
        # shapedirs_init = self.smal.shapedirs.detach()
        self.n_betas = n_betas
        self.n_betas_fixed = n_betas_fixed
        self.sym_ids_dict = sym_ids_dict
        sym_left_ids = self.sym_ids_dict['left']
        sym_right_ids = self.sym_ids_dict['right']
        sym_center_ids = self.sym_ids_dict['center']
        self.n_center = sym_center_ids.shape[0]
        self.n_left = sym_left_ids.shape[0]
        self.n_sd = self.n_betas - self.n_betas_fixed   # number of learnable shapedirs
        # get indices to go from half_shapedirs to shapedirs
        inds_back = np.zeros((3889))
        for ind in range(0, sym_center_ids.shape[0]):
            ind_in_forward = sym_center_ids[ind]
            inds_back[ind_in_forward] = ind
        for ind in range(0, sym_left_ids.shape[0]):
            ind_in_forward = sym_left_ids[ind]
            inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind
        for ind in range(0, sym_right_ids.shape[0]):
            ind_in_forward = sym_right_ids[ind]
            inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind
        self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long())
        # self.smal.shapedirs: (51, 11667)
        # shapedirs: (3889, 3, n_sd)
        # shapedirs_half: (2012, 3, n_sd)
        sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas))
        self.register_buffer('sd', sd)
        sd_center = sd[sym_center_ids, :, self.n_betas_fixed:]
        sd_left = sd[sym_left_ids, :, self.n_betas_fixed:]
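        # Note: only half of the mesh is parameterized below. Center vertices get learnable
        # coordinates 0 and 2 while coordinate 1 stays zero (keeping them on the symmetry
        # plane); left vertices get all three coordinates. In forward(), the right half is
        # reconstructed by flipping the sign of coordinate 1 of the left half, so the learned
        # shape directions remain left/right symmetric. Only the betas beyond n_betas_fixed
        # are learnable; the first n_betas_fixed shape directions stay as in the SMAL model.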
        self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach()))
        self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach()))
        self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach()))
        self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach()))
        self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach()))

    def forward(self):
        device = self.learnable_half_shapedirs_c0.device
        half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0,
                                             torch.zeros((self.n_center, self.n_sd)).to(device),
                                             self.learnable_half_shapedirs_c2), axis=1)
        half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0,
                                           self.learnable_half_shapedirs_l1,
                                           self.learnable_half_shapedirs_l2), axis=1)
        half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0,
                                            - self.learnable_half_shapedirs_l1,
                                            self.learnable_half_shapedirs_l2), axis=1)
        half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right))
        shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch)
        shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2)   # (3889, 3, n_sd)
        shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0))   # (n_sd, 11667)
        return shapedirs_complete, shapedirs_complete_prepared


class ModelImageToBreed(nn.Module):
    def __init__(self, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10,
                 n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True):
        super(ModelImageToBreed, self).__init__()
        self.n_classes = n_classes
        self.n_partseg = n_partseg
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs
        self.n_keyp = n_keyp
        self.n_bones = n_bones
        self.n_breeds = n_breeds
        self.image_size = image_size
        self.upsample_seg = True
        self.threshold_scores = thr_keyp_sc
        self.n_z = n_z
        self.add_partseg = add_partseg
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        if arch == 'hg8':
            self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg,
                                         upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
        else:
            raise Exception('unrecognised model architecture: ' + arch)
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        self.breed_model = ModelShapeAndBreed(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs,
                                              n_breeds=self.n_breeds, n_z=self.n_z)

    def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
        batch_size = input_img.shape[0]
        device = input_img.device
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        hourglass_out_dict = self.stacked_hourglass(input_img)
        last_seg = hourglass_out_dict['seg_final']
        last_heatmap = hourglass_out_dict['out_list_kp'][-1]
        # - prepare keypoints (from heatmap)
        # normalize predictions -> from logits to probability distribution
        # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
        # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1   # (bs, 20, 2)
        # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True)   # (bs, 20, 2)
        keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
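        # keypoints_norm are soft-argmax keypoint locations in normalized image coordinates
        # (roughly [-1, 1], see the conversion to pixel coordinates in
        # ModelImageTo3d_withshape_withproj.forward); scores are the corresponding heatmap
        # maxima, used as per-keypoint confidences below.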
        if self.threshold_scores is not None:
            scores[scores > self.threshold_scores] = 1.0
            scores[scores <= self.threshold_scores] = 0.0
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        # breed_model takes as input the image as well as the predicted segmentation map
        #   -> we need to split up ModelImageTo3d, such that we can use the silhouette
        resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
        pred_breed = resnet_output['breeds']   # (bs, n_breeds)
        pred_betas = resnet_output['betas']
        pred_betas_limbs = resnet_output['betas_limbs']
        small_output = {'keypoints_norm': keypoints_norm,
                        'keypoints_scores': scores}
        small_output_reproj = {'betas': pred_betas,
                               'betas_limbs': pred_betas_limbs,
                               'dog_breed': pred_breed}
        return small_output, None, small_output_reproj


class ModelImageTo3d_withshape_withproj(nn.Module):
    def __init__(self, arch='hg8', num_stage_comb=2, num_stage_heads=1, num_stage_heads_pose=1, trans_sep=False,
                 n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=6,
                 n_breeds=121, image_size=256, n_z=512, n_segbps=64*2, thr_keyp_sc=None, add_z_to_3d_input=True,
                 add_segbps_to_3d_input=False, add_partseg=True, silh_no_tail=True, fix_flength=False,
                 render_partseg=False, structure_z_to_betas='default', structure_pose_net='default', nf_version=None):
        super(ModelImageTo3d_withshape_withproj, self).__init__()
        self.n_classes = n_classes
        self.n_partseg = n_partseg
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs
        self.n_keyp = n_keyp
        self.n_bones = n_bones
        self.n_breeds = n_breeds
        self.image_size = image_size
        self.threshold_scores = thr_keyp_sc
        self.upsample_seg = True
        self.silh_no_tail = silh_no_tail
        self.add_z_to_3d_input = add_z_to_3d_input
        self.add_segbps_to_3d_input = add_segbps_to_3d_input
        self.add_partseg = add_partseg
        assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input)
        self.n_z = n_z
        if add_segbps_to_3d_input:
            self.n_segbps = n_segbps   # 64
            self.segbps_model = SegBPS()
        else:
            self.n_segbps = 0
        self.fix_flength = fix_flength
        self.render_partseg = render_partseg
        self.structure_z_to_betas = structure_z_to_betas
        self.structure_pose_net = structure_pose_net
        assert self.structure_pose_net in ['default', 'vae', 'normflow']
        self.nf_version = nf_version
        self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas)))
        self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32))
        p_dropout = 0.2   # 0.5
        # ------------------------------ SMAL MODEL ------------------------------
        self.smal = SMAL(template_name='neutral')
        # New for rendering without tail
        f_np = self.smal.faces.detach().cpu().numpy()
        self.f_no_tail_np = f_np[np.isin(f_np[:, :], VERTEX_IDS_TAIL).sum(axis=1) == 0, :]
        # in theory we could optimize for improved shapedirs, but we do not do that
        #   -> would need to implement regularizations
        #   -> there are better ways than changing the shapedirs
        self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(),
                                                            self.n_betas, 10)
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        if arch == 'hg8':
            self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg,
                                         upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
        else:
            raise Exception('unrecognised model architecture: ' + arch)
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        self.breed_model = ModelShapeAndBreed(n_betas=self.n_betas,
                                              n_betas_limbs=self.n_betas_limbs,
                                              n_breeds=self.n_breeds,
                                              n_z=self.n_z,
                                              structure_z_to_betas=self.structure_z_to_betas)
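        # The 2d-to-3d lifting network below consumes, per image, the predicted keypoints as
        # (x, y, score) triplets plus the bone lengths, i.e. n_keyp*3 + n_bones values
        # (20*3 + 24 = 84 with the defaults); depending on the flags above, the BPS silhouette
        # encoding or the latent vector z is appended to this vector in forward().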
        # ------------------------------ LINEAR 3D MODEL ------------------------------
        # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
        self.soft_max = torch.nn.Softmax(dim=1)
        input_size = self.n_keyp*3 + self.n_bones
        self.model_3d = LinearModelComplete(linear_size=1024,
                                            num_stage_comb=num_stage_comb,
                                            num_stage_heads=num_stage_heads,
                                            num_stage_heads_pose=num_stage_heads_pose,
                                            trans_sep=trans_sep,
                                            p_dropout=p_dropout,   # 0.5
                                            input_size=input_size,
                                            intermediate_size=1024,
                                            output_info=None,
                                            n_joints=n_joints,
                                            n_z=self.n_z,
                                            add_z_to_3d_input=self.add_z_to_3d_input,
                                            n_segbps=self.n_segbps,
                                            add_segbps_to_3d_input=self.add_segbps_to_3d_input,
                                            structure_pose_net=self.structure_pose_net,
                                            nf_version=self.nf_version)
        # ------------------------------ RENDERING ------------------------------
        self.silh_renderer = SilhRenderer(image_size)

    def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
        batch_size = input_img.shape[0]
        device = input_img.device
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        hourglass_out_dict = self.stacked_hourglass(input_img)
        last_seg = hourglass_out_dict['seg_final']
        last_heatmap = hourglass_out_dict['out_list_kp'][-1]
        # - prepare keypoints (from heatmap)
        # normalize predictions -> from logits to probability distribution
        # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
        # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1   # (bs, 20, 2)
        # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True)   # (bs, 20, 2)
        keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
        if self.threshold_scores is not None:
            scores[scores > self.threshold_scores] = 1.0
            scores[scores <= self.threshold_scores] = 0.0
        # ------------------------------ LEARNABLE SHAPE MODEL ------------------------------
        # in our cvpr 2022 paper we do not change the shapedirs
        # learnable_sd_complete has shape (3889, 3, n_sd)
        # learnable_sd_complete_prepared has shape (n_sd, 11667)
        learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs()
        shapedirs_sel = learnable_sd_complete_prepared   # None
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        # breed_model takes as input the image as well as the predicted segmentation map
        #   -> we need to split up ModelImageTo3d, such that we can use the silhouette
        resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
        pred_breed = resnet_output['breeds']   # (bs, n_breeds)
        pred_z = resnet_output['z']
        # - prepare shape
        pred_betas = resnet_output['betas']
        pred_betas_limbs = resnet_output['betas_limbs']
        # - calculate bone lengths
        with torch.no_grad():
            use_mean_bone_lengths = False
            if use_mean_bone_lengths:
                bone_lengths_prepared = torch.cat(batch_size * [self.mean_dog_bone_lengths.reshape((1, -1))])
            else:
                assert (bone_lengths_prepared is None)
                bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs,
                                                                         shapedirs_sel=shapedirs_sel, short=True)
        # ------------------------------ LINEAR 3D MODEL ------------------------------
        # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
        # prepare input for 2d-to-3d network
        keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2)
        if bone_lengths_prepared is None:
            bone_lengths_prepared = torch.cat(batch_size * [self.mean_dog_bone_lengths.reshape((1, -1))])
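        # If enabled, the silhouette is summarized by a basis point set (BPS) encoding:
        # 64 fixed basis points with 2 values each, i.e. the n_segbps = 128 extra inputs
        # configured in __init__; the values are rescaled (x*2 - 1) before being appended
        # to the keypoint-and-bone-length vector.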
        # should we add silhouette to 3d input? should we add z?
        if self.add_segbps_to_3d_input:
            seg_raw = last_seg
            seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :]   # class 1 is the dog
            with torch.no_grad():
                seg_prep_np = seg_prep_bps.detach().cpu().numpy()
                bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np)   # (bs, 64, 2)
                bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1))
                bps_output_prep = bps_output * 2. - 1
            input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
            input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1)
        elif self.add_z_to_3d_input:
            # we do not use this in our cvpr 2022 version
            input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
            input_vec_additional = pred_z
            input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1)
        else:
            input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
        # predict 3d parameters (those are normalized, we need to correct mean and std in a next step)
        output = self.model_3d(input_vec)
        # add predicted keypoints to the output dict
        output['keypoints_norm'] = keypoints_norm
        output['keypoints_scores'] = scores
        # - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again
        pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :]   # (bs, 3)
        if self.structure_pose_net == 'default':
            pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :]
        elif self.structure_pose_net == 'normflow':
            pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
            pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
            pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
        else:
            pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
            pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
            pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
        pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6)))
        pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3))
        pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6))
        if self.fix_flength:
            output['flength'] = torch.zeros_like(output['flength'])
            pred_flength = torch.ones_like(output['flength']) * 2100   # norm_dict['flength_mean'][None, :]
        else:
            pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :]   # (bs, 1)
            pred_flength = pred_flength_orig.clone()   # torch.abs(pred_flength_orig)
            pred_flength[pred_flength_orig <= 0] = norm_dict['flength_mean'][None, :]
        # ------------------------------ RENDERING ------------------------------
        # get 3d model (SMAL)
        V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose,
                                        trans=pred_trans, get_skin=True, keyp_conf='green', shapedirs_sel=shapedirs_sel)
        keyp_3d = keyp_green_3d[:, :self.n_keyp, :]   # (bs, 20, 3)
        # render silhouette
        faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
        if not self.silh_no_tail:
            pred_silh_images, pred_keyp = self.silh_renderer(vertices=V, points=keyp_3d, faces=faces_prep,
                                                             focal_lengths=pred_flength)
        else:
            faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
            pred_silh_images, pred_keyp = self.silh_renderer(vertices=V, points=keyp_3d, faces=faces_no_tail_prep,
                                                             focal_lengths=pred_flength)
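        # Note that the silhouette above may be rendered without the tail faces (silh_no_tail),
        # while the 'Meshes' object below is always built from the full face set.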
        # get torch 'Meshes'
        torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep)
        # render body parts (not part of cvpr 2022 version)
        if self.render_partseg:
            raise NotImplementedError
        else:
            partseg_images = None
            partseg_images_hg = None
        # ------------------------------ PREPARE OUTPUT ------------------------------
        # create output dictionaries
        #   output: contains all output from model_image_to_3d
        #   output_unnorm: same as output, but normalizations are undone
        #   output_reproj: smal output and reprojected keypoints as well as silhouette
        keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1)
        output_unnorm = {'pose_rotmat': pred_pose,
                         'flength': pred_flength,
                         'trans': pred_trans,
                         'keypoints': keypoints_heatmap_256}
        output_reproj = {'vertices_smal': V,
                         'torch_meshes': torch_meshes,
                         'keyp_3d': keyp_3d,
                         'keyp_2d': pred_keyp,
                         'silh': pred_silh_images,
                         'betas': pred_betas,
                         'betas_limbs': pred_betas_limbs,
                         'pose_rot6d': pred_pose_rot6d,   # used for pose prior...
                         'dog_breed': pred_breed,
                         'shapedirs': shapedirs_sel,
                         'z': pred_z,
                         'flength_unnorm': pred_flength,
                         'flength': output['flength'],
                         'partseg_images_rend': partseg_images,
                         'partseg_images_hg_nograd': partseg_images_hg,
                         'normflow_z': output['normflow_z']}
        return output, output_unnorm, output_reproj

    def render_vis_nograd(self, vertices, focal_lengths, color=0):
        # this function is for visualization only
        # vertices: (bs, n_verts, 3)
        # focal_lengths: (bs, 1)
        # color: integer, either 0 or 1
        # returns a torch tensor of shape (bs, image_size, image_size, 3)
        with torch.no_grad():
            batch_size = vertices.shape[0]
            faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
            visualizations = self.silh_renderer.get_visualization_nograd(vertices, faces_prep,
                                                                         focal_lengths, color=color)
        return visualizations
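# -----------------------------------------------------------------------------------------
# Usage notes (a sketch, not part of the original training/evaluation pipeline):
# ModelImageTo3d_withshape_withproj.forward() expects a norm_dict holding the training-set
# normalization statistics accessed above ('trans_mean', 'trans_std', 'flength_mean',
# 'flength_std', 'pose_rot6d_mean') and returns three dictionaries: the raw normalized
# network outputs, the un-normalized 3d parameters and the reprojection-related outputs,
# roughly as in
#     output, output_unnorm, output_reproj = model(input_img, norm_dict=norm_dict)
# with input_img of shape (bs, 3, image_size, image_size). The exact norm_dict tensor shapes
# depend on the training code and are not fixed here. The block below is only a small sanity
# check of the self-contained helper modules; it assumes the project-local imports at the top
# of this file resolve, but it needs no model data.
if __name__ == "__main__":
    x = torch.randn(4, 512)
    print(SmallLinear(input_size=512, output_size=10)(x).shape)             # torch.Size([4, 10])
    print(MyConv1d(input_size=512, output_size=10, start=True)(x).shape)    # torch.Size([4, 10])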