import torch import copy import torch.nn as nn import torch.nn.functional as F import os from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud, closed_form_inverse from mast3r.catmlp_dpt_head import mast3r_head_factory from mast3r.vgg_pose_head import CameraPredictor, CameraPredictor_clean, Mlp from mast3r.shallow_cnn import FeatureNet import mast3r.utils.path_to_dust3r # noqa from dust3r.model import AsymmetricCroCo3DStereo # noqa from dust3r.utils.misc import transpose_to_landscape, freeze_all_params # noqa inf = float('inf') from dust3r.patch_embed import get_patch_embed from torch.utils.checkpoint import checkpoint from pytorch3d.transforms.rotation_conversions import matrix_to_quaternion def load_model(model_path, device, verbose=True): if verbose: print('... loading model from', model_path) ckpt = torch.load(model_path, map_location='cpu') args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") if 'landscape_only' not in args: args = args[:-1] + ', landscape_only=False)' else: args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') assert "landscape_only=False" in args if verbose: print(f"instantiating : {args}") net = eval(args) s = net.load_state_dict(ckpt['model'], strict=False) if verbose: print(s) return net.to(device) import torch def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) class AsymmetricMASt3R(AsymmetricCroCo3DStereo): def __init__(self, wpose=False, wogs=True, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs): self.desc_mode = desc_mode self.two_confs = two_confs self.desc_conf_mode = desc_conf_mode self.wogs = wogs self.wpose = wpose super().__init__(**kwargs) # Global Geometry Projector self.dec_blocks_point = copy.deepcopy(self.dec_blocks_fine) self.cam_cond_encoder_point = copy.deepcopy(self.cam_cond_encoder) self.decoder_embed_point = copy.deepcopy(self.decoder_embed) self.dec_norm_point = copy.deepcopy(self.dec_norm) self.pose_token_ref_point = copy.deepcopy(self.pose_token_ref) self.pose_token_source_point = copy.deepcopy(self.pose_token_source) self.cam_cond_embed_point = copy.deepcopy(self.cam_cond_embed) self.cam_cond_embed_point_pre = copy.deepcopy(self.cam_cond_embed) self.inject_stage3 = nn.ModuleList([nn.Linear(self.enc_embed_dim, self.dec_embed_dim, bias=False) for i in range(3)]) self.enc_inject_stage3 = nn.ModuleList([copy.deepcopy(self.enc_norm) for i in range(3)]) # Camera-centric Geometry Estimator self.cam_cond_encoder_fine = copy.deepcopy(self.cam_cond_encoder) self.adaLN_modulation = nn.ModuleList([nn.Sequential( nn.SiLU(inplace=False), nn.Linear(self.dec_embed_dim, 3 * self.dec_embed_dim, bias=True) ) for _ in range(len(self.dec_blocks_fine))]) for block in self.adaLN_modulation: nn.init.constant_(block[-1].weight, 0) nn.init.constant_(block[-1].bias, 0) self.decoder_embed_fine = copy.deepcopy(self.decoder_embed) self.dec_cam_norm_fine = copy.deepcopy(self.dec_cam_norm) self.dec_norm_fine = copy.deepcopy(self.dec_norm) self.pose_token_ref_fine = copy.deepcopy(self.pose_token_ref) self.pose_token_source_fine = copy.deepcopy(self.pose_token_source) self.cam_cond_embed_fine = copy.deepcopy(self.cam_cond_embed) self.inject_stage2 = nn.ModuleList([nn.Linear(self.enc_embed_dim, self.dec_embed_dim, bias=False) for i in range(3)]) self.enc_inject_stage2 = nn.ModuleList([copy.deepcopy(self.enc_norm) for i in range(3)]) # Encoder self.enc_norm_coarse = copy.deepcopy(self.enc_norm) self.embed_pose = Mlp(7, hidden_features=self.dec_embed_dim, out_features=self.dec_embed_dim) # Shallow CNNs self.cnn_wobn = FeatureNet() self.cnn_proj = nn.Conv2d(64, 16, 3, 1, 1) self.cnn_fusion = nn.Conv2d(32*3, 64, 3, 1, 1) for i in range(3): nn.init.constant_(self.inject_stage2[i].weight, 0.) nn.init.constant_(self.inject_stage3[i].weight, 0.) self.idx_hook = [2, 5, 8] self.encode_feature_landscape = transpose_to_landscape(self.encode_feature, activate=True) if self.wogs == False: self.decoder_embed_stage2 = copy.deepcopy(self.decoder_embed) nn.init.constant_(self.decoder_embed_stage2.weight, 0.) self.decoder_embed_fxfycxcy = Mlp(4, hidden_features=self.dec_embed_dim, out_features=self.dec_embed_dim) nn.init.constant_(self.decoder_embed_fxfycxcy.fc2.weight, 0.) nn.init.constant_(self.decoder_embed_fxfycxcy.fc2.bias, 0.) def load_state_dict_stage1(self, ckpt, **kw): # duplicate all weights for the second decoder if not present new_ckpt = dict(ckpt) return super().load_state_dict(new_ckpt, **kw) def load_state_dict(self, ckpt, **kw): # duplicate all weights for the second decoder if not present new_ckpt = dict(ckpt) if self.head_type == 'dpt_gs': for key, value in ckpt.items(): if 'dpt.head.4' in key: state_dict = self.state_dict() state_dict[key][:value.shape[0]] = value new_ckpt[key] = state_dict[key] return super().load_state_dict(new_ckpt, **kw) def encode_feature(self, imgs_vgg, image_size): H, W = image_size imgs_vgg = imgs_vgg[0].permute(0,3,1,2) feat_vgg3, feat_vgg2, feat_vgg1 = self.cnn_wobn(imgs_vgg) feat_vgg2 = F.interpolate(feat_vgg2.float(), (H, W), mode='bilinear', align_corners=True) feat_vgg3 = F.interpolate(feat_vgg3.float(), (H, W), mode='bilinear', align_corners=True) feat_vgg = self.cnn_fusion(torch.cat((feat_vgg1.float(), feat_vgg2, feat_vgg3), 1)) feat_vgg_detail = self.cnn_proj(feat_vgg) N, C, h, w = feat_vgg.shape imgs_vgg = feat_vgg.reshape(N, C, -1).permute(0,2,1) N, P, C = imgs_vgg.shape imgs_vgg = imgs_vgg.reshape(N, P, -1, 64) imgs_vgg = imgs_vgg.permute(0, 2, 1, 3) x = torch.arange(w).to(imgs_vgg) y = torch.arange(h).to(imgs_vgg) xy = torch.meshgrid(x, y, indexing='xy') pos_full = torch.cat((xy[0].unsqueeze(-1), xy[1].unsqueeze(-1)), -1).unsqueeze(0) imgs_vgg = imgs_vgg + self.rope(torch.ones_like(imgs_vgg).to(imgs_vgg), pos_full.reshape(1,-1,2).repeat(N, 1, 1).long()).to(imgs_vgg) imgs_vgg = imgs_vgg.permute(0, 2, 1, 3) imgs_vgg = imgs_vgg.reshape(N, -1, C).permute(0, 2, 1) imgs_vgg = imgs_vgg.reshape(N, C, h, w) return {'imgs_vgg': imgs_vgg.permute(0, 2, 3, 1), 'feat_vgg_detail': feat_vgg_detail.permute(0, 2, 3, 1)} @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kw): if os.path.isfile(pretrained_model_name_or_path): return load_model(pretrained_model_name_or_path, device='cpu') else: return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw) def _encode_image(self, image, true_shape): # embed the image into patches (x has size B x Npatches x C) interm_features = [] x, pos = self.patch_embed(image, true_shape=true_shape) # add positional embedding without cls token assert self.enc_pos_embed is None # now apply the transformer encoder and normalization for blk in self.enc_blocks: interm_features.append(x) x = blk(x, pos) x = self.enc_norm(x) return x, pos, interm_features def _encode_symmetrized(self, views): imgs = [view['img'] for view in views] shapes = [view['true_shape'] for view in views] imgs = torch.stack((imgs), dim=1) B, views, _, H, W = imgs.shape dtype = imgs.dtype imgs = imgs.view(-1, *imgs.shape[2:]) shapes = torch.stack((shapes), dim=1) shapes = shapes.view(-1, *shapes.shape[2:]) out, pos, interm_features = self._encode_image(imgs, shapes) out = out.to(dtype) for i in range(len(interm_features)): interm_features[i] = interm_features[i].to(dtype) interm_features[i] = interm_features[i].reshape(B, views, *out.shape[1:]) true_shape = shapes W //= 64 H //= 64 n_tokens = H * W x_coarse = out.new_zeros((B*views, n_tokens, self.patch_embed_coarse2.embed_dim)).to(dtype) pos_coarse = out.new_zeros((B*views, n_tokens, 2), dtype=torch.int64) height, width = true_shape.T is_landscape = (width >= height) is_portrait = ~is_landscape fine_token = out.view(B*views, H * 4, W * 4, -1).permute(0, 3, 1, 2) x_coarse[is_landscape] = self.patch_embed_coarse2.proj(fine_token[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2) x_coarse[is_portrait] = self.patch_embed_coarse2.proj(fine_token[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2) pos_coarse[is_landscape] = self.patch_embed_test_.position_getter(1, H, W, pos.device) pos_coarse[is_portrait] = self.patch_embed_test_.position_getter(1, W, H, pos.device) x_coarse = self.enc_norm_coarse(x_coarse) out_coarse = x_coarse.reshape(B, views, *x_coarse.shape[1:]).to(dtype) pos_coarse = pos_coarse.reshape(B, views, *pos_coarse.shape[1:]) shapes_coarse = shapes.reshape(B, views, *shapes.shape[1:]) // 4 out = out.reshape(B, views, *out.shape[1:]) pos = pos.reshape(B, views, *pos.shape[1:]) shapes = shapes.reshape(B, views, *shapes.shape[1:]) return shapes_coarse, out_coarse, pos_coarse, shapes, out, pos, interm_features def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim) self.patch_embed_coarse = get_patch_embed(self.patch_embed_cls, img_size, 2, enc_embed_dim, input_dim=enc_embed_dim) self.patch_embed_coarse2 = get_patch_embed(self.patch_embed_cls, img_size, 4, enc_embed_dim, input_dim=enc_embed_dim) # self.patch_embed_test = get_patch_embed(self.patch_embed_cls, img_size, 2 * patch_size, enc_embed_dim) self.patch_embed_test_ = get_patch_embed(self.patch_embed_cls, img_size, 4 * patch_size, enc_embed_dim) # self.patch_embed_fine = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim, input_dim=64) def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw): assert img_size[0] % patch_size == 0 and img_size[ 1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}' self.output_mode = output_mode self.head_type = head_type self.depth_mode = depth_mode self.conf_mode = conf_mode if self.desc_conf_mode is None: self.desc_conf_mode = conf_mode # allocate heads self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) # magic wrapper self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only) self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only) self.pose_head = CameraPredictor_clean(hood_idx=self.downstream_head2.dpt.hooks, trunk_depth=4, rope=self.rope) self.pose_head_stage2 = CameraPredictor_clean(hood_idx=self.downstream_head2.dpt.hooks, trunk_depth=4, rope=self.rope) self.downstream_head4 = mast3r_head_factory('sh', output_mode, self, has_conf=bool(conf_mode)) self.head4 = transpose_to_landscape(self.downstream_head4, activate=landscape_only) def _decoder_stage2(self, f1, pos1, f2, pos2, pose1, pose2, low_token=None): f = torch.cat((f1, f2), 1) pos = torch.cat((pos1, pos2), 1) final_output = [f] # before projection f = self.decoder_embed_fine(f) B, views, P, C = f.shape f = f.view(B, -1 ,C) pos = pos.view(B, -1, pos.shape[-1]) cam_tokens = [] final_output.append(f) pose1_embed = self.embed_pose(pose1) pose2_embed = self.embed_pose(pose2) pose_embed = torch.cat((pose1_embed, pose2_embed), 1) views = views - 1 pose_token_ref, pose_token_source = self.pose_token_ref_fine.to(f1.dtype).repeat(B,1,1).view(B, -1, C), self.pose_token_source_fine.to(f1.dtype).repeat(B*views,1,1).view(B*views, -1, C) dtype = f.dtype hook_idx = 0 for i, (blk1, cam_cond, cam_cond_embed_fine, adaLN_modulation) in enumerate(zip(self.dec_blocks_fine, self.cam_cond_encoder_fine, self.cam_cond_embed_fine, self.adaLN_modulation)): shift_msa, scale_msa, gate_msa = adaLN_modulation(pose_embed).chunk(3, dim=-1) pose_token_ref = modulate(pose_token_ref.reshape(B, -1, C), shift_msa[:,:1].reshape(B,-1), scale_msa[:,:1].reshape(B,-1)) pose_token_source = modulate(pose_token_source.reshape(B*views, -1, C), shift_msa[:,1:].reshape(B*views,-1), scale_msa[:,1:].reshape(B*views,-1)) feat = checkpoint(blk1, f, pos) feat = feat.view(B, views+1, -1, C) f1 = feat[:,:1].view(B, -1, C) f2 = feat[:,1:].reshape(B*views, -1, C) f1_cam = torch.cat((pose_token_ref, f1.view(B, -1, C)), 1) f2_cam = torch.cat((pose_token_source, f2.view(B*views, -1, C)), 1) f_cam = torch.cat((f1_cam, f2_cam), 0) f_cam = checkpoint(cam_cond, f_cam) # torch.Size([64, 769, 768]) f_delta = f_cam[:,1:] f_cam = f_cam[:,:1] f_delta1 = f_delta[:B].view(B, -1, C) f_delta2 = f_delta[B:].view(B*views, -1, C) pose_token_ref = pose_token_ref.view(B, -1, C) + f_cam[:B].view(B, -1, C) pose_token_source = pose_token_source.view(B*views, -1, C) + f_cam[B:].view(B*views, -1, C) cam_tokens.append((pose_token_ref, pose_token_source)) f1 = f1.view(B, -1, C) + cam_cond_embed_fine(f_delta1) f2 = f2.view(B*views, -1, C) + cam_cond_embed_fine(f_delta2) if i in self.idx_hook: f1 = f1.view(B, -1, C) + self.inject_stage2[hook_idx](self.enc_inject_stage2[hook_idx](low_token[i * 2][:,:1].view(B, -1, 1024))) f2 = f2.view(B*views, -1, C) + self.inject_stage2[hook_idx](self.enc_inject_stage2[hook_idx](low_token[i * 2][:,1:].reshape(B*views, -1, 1024))) hook_idx += 1 f1 = f1.view(B, 1, -1 ,C) f2 = f2.view(B, views, -1 ,C) f = torch.cat((f1, f2), 1) final_output.append(f) f = f.view(B, -1 ,C) # normalize last output del final_output[1] # duplicate with final_output[0] final_output[-1] = self.dec_norm_fine(final_output[-1]) cam_tokens[-1] = tuple(map(self.dec_cam_norm_fine, cam_tokens[-1])) return final_output, zip(*cam_tokens) def _decoder_stage3(self, feat_ref, pos1, pos2, pose1, pose2, low_token=None, feat_stage2=None, fxfycxcy1=None, fxfycxcy2=None): final_output = [feat_ref[0]] # before projection # project to decoder dim final_output.append(feat_ref[1]) with torch.cuda.amp.autocast(enabled=False,dtype=torch.float32): pose1_embed = self.embed_pose(pose1) pose2_embed = self.embed_pose(pose2) pose_embed = torch.cat((pose1_embed, pose2_embed), 1) B, views, P, C = feat_ref[-1].shape if feat_stage2 is None: f = self.decoder_embed_point(feat_ref[0]) else: f = self.decoder_embed_point(feat_ref[0]) + self.decoder_embed_stage2(feat_stage2) views = views - 1 dtype = f.dtype pose_token_ref, pose_token_source = self.pose_token_ref_point.to(dtype).repeat(B,1,1).view(B, -1, C), self.pose_token_source_point.to(dtype).repeat(B*views,1,1).view(B*views, -1, C) pos = torch.cat((pos1, pos2), 1) if fxfycxcy1 is not None: with torch.cuda.amp.autocast(enabled=False,dtype=torch.float32): fxfycxcy1 = self.decoder_embed_fxfycxcy(fxfycxcy1) fxfycxcy2 = self.decoder_embed_fxfycxcy(fxfycxcy2) pose1_embed = pose1_embed + fxfycxcy1 pose2_embed = pose2_embed + fxfycxcy2 pose1_embed = pose1_embed.to(dtype) pose2_embed = pose2_embed.to(dtype) pose_token_ref = pose_token_ref + pose1_embed pose_token_source = pose_token_source + pose2_embed.view(B*views, -1, C) hook_idx = 0 for i, (blk, blk_cross, cam_cond, cam_cond_embed_point, cam_cond_embed_point_pre) in enumerate(zip(self.dec_blocks_point, self.dec_blocks_point_cross, self.cam_cond_encoder_point, self.cam_cond_embed_point, self.cam_cond_embed_point_pre)): f1_pre = feat_ref[i+1].reshape(B, (views+1), -1, C)[:,:1].view(B, -1, C) f2_pre = feat_ref[i+1].reshape(B, (views+1), -1, C)[:,1:].reshape(B*views, -1, C) f1_pre = f1_pre + cam_cond_embed_point_pre(pose_token_ref) f2_pre = f2_pre + cam_cond_embed_point_pre(pose_token_source) f_pre = torch.cat((f1_pre.view(B, 1, -1, C), f2_pre.view(B, views, -1, C)), 1) feat, _ = checkpoint(blk_cross, f.reshape(B*(views+1), -1, C), f_pre.reshape(B*(views+1), -1, C), pos.reshape(B*(views+1), -1, 2), pos.reshape(B*(views+1), -1, 2)) feat = feat.view(B, views+1, -1, C).reshape(B, -1, C) feat = checkpoint(blk, feat, pos.reshape(B, -1, 2)) feat = feat.view(B, views+1, -1, C) f1 = feat[:,:1].view(B, -1, C) f2 = feat[:,1:].reshape(B*views, -1, C) f1_cam = torch.cat((pose_token_ref, f1.view(B, -1, C)), 1) f2_cam = torch.cat((pose_token_source, f2.view(B*views, -1, C)), 1) f_cam = torch.cat((f1_cam, f2_cam), 0) f_cam = checkpoint(cam_cond, f_cam) # torch.Size([64, 769, 768]) f_delta = f_cam[:,1:] f_cam = f_cam[:,:1] f_delta1 = f_delta[:B].view(B, -1, C) f_delta2 = f_delta[B:].view(B*views, -1, C) pose_token_ref = pose_token_ref.view(B, -1, C) + f_cam[:B].view(B, -1, C) pose_token_source = pose_token_source.view(B*views, -1, C) + f_cam[B:].view(B*views, -1, C) f1 = f1.view(B, -1, C) + cam_cond_embed_point(f_delta1) f2 = f2.view(B*views, -1, C) + cam_cond_embed_point(f_delta2) if i in self.idx_hook: f1 = f1.view(B, -1, C) + self.inject_stage3[hook_idx](self.enc_inject_stage3[hook_idx](low_token[i * 2][:,:1].view(B, -1, 1024))) f2 = f2.view(B*views, -1, C) + self.inject_stage3[hook_idx](self.enc_inject_stage3[hook_idx](low_token[i * 2][:,1:].reshape(B*views, -1, 1024))) hook_idx += 1 f1 = f1.view(B, 1, -1 ,C) f2 = f2.view(B, views, -1 ,C) f = torch.cat((f1, f2), 1) final_output.append(f) f = f.view(B, -1 ,C) # normalize last output del final_output[1] # duplicate with final_output[0] final_output[-1] = self.dec_norm_point(final_output[-1]) return final_output def _decoder(self, f1, pos1, f2, pos2): final_output = [(f1, f2)] # before projection # project to decoder dim f1 = self.decoder_embed(f1) f2 = self.decoder_embed(f2) B, views, P, C = f2.shape f1 = f1.view(B, -1 ,C) f2 = f2.view(B, -1 ,C) pos1 = pos1.view(B, -1, pos1.shape[-1]) pos2 = pos2.view(B, -1, pos2.shape[-1]) cam_tokens = [] final_output.append((f1, f2)) pose_token_ref, pose_token_source = self.pose_token_ref.to(f1.dtype).repeat(B,1,1).view(B, -1, C), self.pose_token_source.to(f1.dtype).repeat(B*views,1,1).view(B*views, -1, C) for i, (blk1, blk2, cam_cond, cam_cond_embed) in enumerate(zip(self.dec_blocks, self.dec_blocks2, self.cam_cond_encoder, self.cam_cond_embed)): f1, _ = checkpoint(blk1, *final_output[-1][::+1], pos1, pos2) f2, _ = checkpoint(blk2, *final_output[-1][::-1], pos2, pos1) f1_cam = torch.cat((pose_token_ref, f1.view(B, -1, C)), 1) f2_cam = torch.cat((pose_token_source, f2.view(B*views, -1, C)), 1) f_cam = torch.cat((f1_cam, f2_cam), 0) f_cam = checkpoint(cam_cond, f_cam) f_delta = f_cam[:,1:] f_cam = f_cam[:,:1] f_delta1 = f_delta[:B].view(B, -1, C) f_delta2 = f_delta[B:].view(B*views, -1, C) pose_token_ref = pose_token_ref.view(B, -1, C) + f_cam[:B].view(B, -1, C) pose_token_source = pose_token_source.view(B*views, -1, C) + f_cam[B:].view(B*views, -1, C) cam_tokens.append((pose_token_ref, pose_token_source)) f1 = f1.view(B, -1, C) + cam_cond_embed(f_delta1) f2 = f2.view(B*views, -1, C) + cam_cond_embed(f_delta2) f1 = f1.view(B, -1 ,C) f2 = f2.view(B, -1 ,C) # store the result final_output.append((f1, f2)) # normalize last output del final_output[1] # duplicate with final_output[0] cam_tokens[-1] = tuple(map(self.dec_cam_norm, cam_tokens[-1])) return zip(*cam_tokens) def forward_coarse_pose(self, view1, view2, enabled=True, dtype=torch.bfloat16): # encode the two images --> B,S,D batch_size, _, _, _ = view1[0]['img'].shape view_num = len(view2) with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): shapes, feat, pos, shape_stage2, feat_stage2, pos_stage2, interm_features = self._encode_symmetrized(view1+view2) # shapes_coarse, out_coarse, pos_coarse, shapes, out, pos, interm_features feat1 = feat[:, :1].to(dtype) feat2 = feat[:, 1:].to(dtype) pos1 = pos[:, :1] pos2 = pos[:, 1:] shape1 = shapes[:, :1] shape2 = shapes[:, 1:] shape1_stage2 = shape_stage2[:, :1] shape2_stage2 = shape_stage2[:, 1:] feat1_stage2 = feat_stage2[:, :1] feat2_stage2 = feat_stage2[:, 1:] pos1_stage2 = pos_stage2[:, :1] pos2_stage2 = pos_stage2[:, 1:] (pose_token1, pose_token2) = self._decoder(feat1, pos1, feat2, pos2) pred_cameras, _ = self.pose_head(batch_size, interm_feature1=pose_token1, interm_feature2=pose_token2, enabled=True, dtype=dtype) return feat1_stage2, pos1_stage2, feat2_stage2, pos2_stage2, pred_cameras, shape1_stage2, shape2_stage2, None, None, pose_token1, pose_token2, interm_features def forward(self, view1, view2, enabled=True, dtype=torch.bfloat16): if self.wogs: res1, res2, pred_cameras = self.forward_pointmap(view1, view2, enabled=enabled, dtype=dtype) else: res1, res2, pred_cameras = self.forward_gs(view1, view2, enabled=enabled, dtype=dtype) return res1, res2, pred_cameras def forward_gs(self, view1, view2, enabled=True, dtype=torch.bfloat16): raise NotImplementedError("This feature (novel view synthesis) has not been released yet.") def load_state_dict_posehead(self, ckpt, strict=True): # duplicate all weights for the second decoder if not present new_ckpt = {} for key, value in ckpt.items(): new_key = '.'.join([key.split('.')[0] + '_pose_head']+ key.split('.')[1:]) if new_key in dict(self.named_parameters()).keys(): print(f'Loading {new_key} from checkpoint') new_ckpt[new_key] = ckpt[key] return self.load_state_dict(new_ckpt, strict=strict) def forward_pointmap(self, view1, view2, enabled=True, dtype=torch.bfloat16): # encode the two images --> B,S,D batch_size, _, _, _ = view1[0]['img'].shape view_num = len(view2) # coarse camera pose estimation feat1, pos1, feat2, pos2, pred_cameras_coarse, shape1, shape2, res1_stage1, res2_stage1, pose_token1, pose_token2, interm_features = self.forward_coarse_pose(view1, view2, enabled=enabled, dtype=dtype) if self.wpose == False: trans = pred_cameras_coarse[-1]['T'].float().detach().clone() trans = trans.reshape(batch_size, -1, 3) quaternion_R_pred = pred_cameras_coarse[-1]['quaternion_R'].reshape(batch_size, -1, 4).float().detach().clone() else: ref_camera_pose = torch.cat([view['camera_pose'] for view in view1], 0).double() trajectory = torch.cat([view['camera_pose'] for view in view1 + view2], 0).double() in_camera1 = closed_form_inverse(ref_camera_pose) trajectory = torch.bmm(in_camera1.repeat(trajectory.shape[0],1,1), trajectory) quaternion_R_pred = matrix_to_quaternion(trajectory[:, :3, :3]).float().reshape(batch_size, -1, 4) trans = trajectory[:, :3, 3].float().reshape(batch_size, -1, 3) gt_quaternion_R = quaternion_R_pred gt_trans = trans size = (trans.norm(dim=-1, keepdim=True).mean(dim=-2, keepdim=True) + 1e-8) trans_pred = trans / size camera_embed = torch.cat((quaternion_R_pred, trans_pred), -1) camera_embed1 = camera_embed[:, :1].to(dtype) camera_embed2 = camera_embed[:, 1:].to(dtype) # fine camera pose estimation + camera-centric geometry estimation dec_fine, (pose_token1_fine, pose_token2_fine) = self._decoder_stage2(feat1, pos1, feat2, pos2, camera_embed1, camera_embed2, interm_features) shape = torch.cat((shape1, shape2), 1) res1 = self._downstream_head(1, [tok.to(dtype).reshape(-1, tok.shape[-2], tok.shape[-1]) for tok in dec_fine], shape.reshape(-1, 2)) res1.pop('desc') for key in res1.keys(): res1[key] = res1[key].unflatten(0, (batch_size, view_num+1)).float() with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32): pred_cameras, _ = self.pose_head_stage2(batch_size, interm_feature1=pose_token1_fine, interm_feature2=pose_token2_fine, enabled=True, dtype=torch.float32) if self.wpose == False: trans = pred_cameras[-1]['T'].float().detach().clone() quaternion_R_pred = pred_cameras[-1]['quaternion_R'].reshape(batch_size, -1, 4).float().detach().clone() else: quaternion_R_pred = gt_quaternion_R trans = gt_trans size = (trans.norm(dim=-1, keepdim=True).mean(dim=-2, keepdim=True) + 1e-8) trans_pred = trans / size quaternion_R_noise = quaternion_R_pred trans_noise = trans_pred camera_embed = torch.cat((quaternion_R_noise, trans_noise), -1) camera_embed1 = camera_embed[:, :1] camera_embed2 = camera_embed[:, 1:] pred_cameras = pred_cameras_coarse + pred_cameras # global geometry estimation dec_fine_stage2 = self._decoder_stage3(dec_fine, pos1, pos2, camera_embed1, camera_embed2, interm_features) with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32): res2 = self._downstream_head(2, [tok.float().reshape(-1, tok.shape[-2], tok.shape[-1]) for tok in dec_fine_stage2], shape.reshape(-1, 2)) res2.pop('desc') for key in res2.keys(): res2[key] = res2[key].unflatten(0, (batch_size, view_num+1)).float() return res1, res2, pred_cameras