|
import torch |
|
import copy |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import os |
|
from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud, closed_form_inverse |
|
from mast3r.catmlp_dpt_head import mast3r_head_factory |
|
from mast3r.vgg_pose_head import CameraPredictor, CameraPredictor_clean, Mlp |
|
from mast3r.shallow_cnn import FeatureNet |
|
import mast3r.utils.path_to_dust3r |
|
from dust3r.model import AsymmetricCroCo3DStereo |
|
from dust3r.utils.misc import transpose_to_landscape, freeze_all_params |
|
inf = float('inf') |
|
from dust3r.patch_embed import get_patch_embed |
|
from torch.utils.checkpoint import checkpoint |
|
from pytorch3d.transforms.rotation_conversions import matrix_to_quaternion |
|
|
|
def load_model(model_path, device, verbose=True):
    """Rebuild a network from a training checkpoint and move it to ``device``.

    The checkpoint must provide the constructor expression as
    ``ckpt['args'].model`` and the weights as ``ckpt['model']``. The
    expression is rewritten so the model is always built with
    ``landscape_only=False``.

    Args:
        model_path: Path of the checkpoint file.
        device: Target device passed to ``net.to``.
        verbose: When true, print progress and the ``load_state_dict`` result.

    Returns:
        The instantiated network with weights loaded non-strictly.
    """
    if verbose:
        print('... loading model from', model_path)

    # The checkpoint pickles an arbitrary ``args`` object next to the tensors.
    # PyTorch >= 2.6 defaults to ``weights_only=True`` and refuses such files,
    # so the full unpickler is requested explicitly.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)

    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        # Drop the closing parenthesis and append the keyword argument.
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args

    if verbose:
        print(f"instantiating : {args}")

    # SECURITY: both the unpickling above and this ``eval`` execute code taken
    # from the checkpoint. Only load checkpoints from a trusted source.
    net = eval(args)
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)
|
import torch |
|
|
|
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at position 1 so that
    they broadcast along dimension 1 of ``x``.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
|
|
|
class AsymmetricMASt3R(AsymmetricCroCo3DStereo): |
|
def __init__(self, wpose=False, wogs=True, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
    """Build the multi-stage model on top of the parent network.

    Args:
        wpose: Stored as ``self.wpose``; ``forward_pointmap`` uses the
            cameras supplied in the views instead of predicted ones when set.
        wogs: Stored as ``self.wogs``; selects ``forward_pointmap`` (true)
            or ``forward_gs`` (false) in ``forward``.
        desc_mode, two_confs, desc_conf_mode: Stored unchanged on the instance.
        **kwargs: Forwarded to the parent constructor.
    """
    # Plain attributes are set before the parent constructor runs:
    # set_downstream_head reads self.desc_conf_mode and is presumably called
    # from the parent constructor (verify against the base class).
    self.wpose = wpose
    self.wogs = wogs
    self.desc_mode = desc_mode
    self.two_confs = two_confs
    self.desc_conf_mode = desc_conf_mode
    super().__init__(**kwargs)

    clone = copy.deepcopy
    enc_dim = self.enc_embed_dim
    dec_dim = self.dec_embed_dim

    # NOTE: the creation order of freshly initialised layers is kept as in
    # the original, since it determines RNG consumption and state-dict order.

    # --- "point" branch (used by _decoder_stage3): copies of existing modules
    self.dec_blocks_point = clone(self.dec_blocks_fine)
    self.cam_cond_encoder_point = clone(self.cam_cond_encoder)
    self.decoder_embed_point = clone(self.decoder_embed)
    self.dec_norm_point = clone(self.dec_norm)
    self.pose_token_ref_point = clone(self.pose_token_ref)
    self.pose_token_source_point = clone(self.pose_token_source)
    self.cam_cond_embed_point = clone(self.cam_cond_embed)
    self.cam_cond_embed_point_pre = clone(self.cam_cond_embed)
    self.inject_stage3 = nn.ModuleList(
        [nn.Linear(enc_dim, dec_dim, bias=False) for _ in range(3)]
    )
    self.enc_inject_stage3 = nn.ModuleList([clone(self.enc_norm) for _ in range(3)])

    # --- "fine" branch (used by _decoder_stage2)
    self.cam_cond_encoder_fine = clone(self.cam_cond_encoder)
    self.adaLN_modulation = nn.ModuleList([
        nn.Sequential(
            nn.SiLU(inplace=False),
            nn.Linear(dec_dim, 3 * dec_dim, bias=True),
        )
        for _ in range(len(self.dec_blocks_fine))
    ])
    # Zero the final linear layer of every modulation block.
    for block in self.adaLN_modulation:
        last_linear = block[-1]
        nn.init.constant_(last_linear.weight, 0)
        nn.init.constant_(last_linear.bias, 0)
    self.decoder_embed_fine = clone(self.decoder_embed)
    self.dec_cam_norm_fine = clone(self.dec_cam_norm)
    self.dec_norm_fine = clone(self.dec_norm)
    self.pose_token_ref_fine = clone(self.pose_token_ref)
    self.pose_token_source_fine = clone(self.pose_token_source)
    self.cam_cond_embed_fine = clone(self.cam_cond_embed)
    self.inject_stage2 = nn.ModuleList(
        [nn.Linear(enc_dim, dec_dim, bias=False) for _ in range(3)]
    )
    self.enc_inject_stage2 = nn.ModuleList([clone(self.enc_norm) for _ in range(3)])

    # --- coarse tokens and pose embedding (7 inputs: quaternion + translation
    # as assembled in forward_pointmap)
    self.enc_norm_coarse = clone(self.enc_norm)
    self.embed_pose = Mlp(7, hidden_features=dec_dim, out_features=dec_dim)

    # --- shallow CNN feature branch (used by encode_feature)
    self.cnn_wobn = FeatureNet()
    self.cnn_proj = nn.Conv2d(64, 16, 3, 1, 1)
    self.cnn_fusion = nn.Conv2d(32*3, 64, 3, 1, 1)

    # Injection projections start at zero (both lists hold three layers).
    for inject2, inject3 in zip(self.inject_stage2, self.inject_stage3):
        nn.init.constant_(inject2.weight, 0.)
        nn.init.constant_(inject3.weight, 0.)
    # Decoder layer indices at which encoder features are injected.
    self.idx_hook = [2, 5, 8]

    self.encode_feature_landscape = transpose_to_landscape(self.encode_feature, activate=True)

    # NOTE(review): the source lost its indentation; all five statements
    # below are assumed to belong to this branch - confirm against the
    # original file.
    if self.wogs == False:
        self.decoder_embed_stage2 = clone(self.decoder_embed)
        nn.init.constant_(self.decoder_embed_stage2.weight, 0.)
        self.decoder_embed_fxfycxcy = Mlp(4, hidden_features=dec_dim, out_features=dec_dim)
        nn.init.constant_(self.decoder_embed_fxfycxcy.fc2.weight, 0.)
        nn.init.constant_(self.decoder_embed_fxfycxcy.fc2.bias, 0.)
|
|
|
def load_state_dict_stage1(self, ckpt, **kw):
    """Load ``ckpt`` through the parent ``load_state_dict``.

    A shallow copy of the mapping is handed to the parent class; keyword
    arguments are forwarded unchanged.
    """
    return super().load_state_dict(dict(ckpt), **kw)
|
|
|
def load_state_dict(self, ckpt, **kw):
    """Load ``ckpt``, padding ``dpt.head.4`` tensors when ``head_type == 'dpt_gs'``.

    For the ``'dpt_gs'`` head, every checkpoint entry whose key contains
    ``'dpt.head.4'`` is written into the leading rows of the model's current
    tensor of the same key. The remaining rows keep their current values,
    and the merged tensor is what gets loaded.

    Args:
        ckpt: Mapping from parameter names to tensors.
        **kw: Forwarded to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns.
    """
    new_ckpt = dict(ckpt)
    if self.head_type == 'dpt_gs':
        own_state = None
        for key, value in ckpt.items():
            if 'dpt.head.4' not in key:
                continue
            if own_state is None:
                # Built once, and only if needed: state_dict() walks the
                # whole module tree.
                own_state = self.state_dict()
            # Merge into a copy so the live weights stay untouched until the
            # parent load has accepted the checkpoint.
            merged = own_state[key].clone()
            merged[:value.shape[0]] = value
            new_ckpt[key] = merged
    return super().load_state_dict(new_ckpt, **kw)
|
|
|
def encode_feature(self, imgs_vgg, image_size):
    """Compute fused shallow-CNN features with a rotary positional term added.

    Args:
        imgs_vgg: Indexable whose first element is a channels-last image
            batch (it is permuted with ``(0, 3, 1, 2)``).
        image_size: ``(H, W)`` target size for upsampling the two coarser
            feature maps.

    Returns:
        Dict with ``'imgs_vgg'`` (fused features plus RoPE term) and
        ``'feat_vgg_detail'`` (output of ``cnn_proj``), both channels-last.
    """
    H, W = image_size
    images = imgs_vgg[0].permute(0, 3, 1, 2)

    # FeatureNet yields three maps; the first two are upsampled to (H, W).
    # The third is used as is - presumably already at that size.
    feat_lvl3, feat_lvl2, feat_lvl1 = self.cnn_wobn(images)
    feat_lvl2 = F.interpolate(feat_lvl2.float(), (H, W), mode='bilinear', align_corners=True)
    feat_lvl3 = F.interpolate(feat_lvl3.float(), (H, W), mode='bilinear', align_corners=True)
    fused = self.cnn_fusion(torch.cat((feat_lvl1.float(), feat_lvl2, feat_lvl3), 1))
    detail = self.cnn_proj(fused)

    # Flatten the spatial grid and split channels into groups of 64.
    N, C, h, w = fused.shape
    tokens = fused.reshape(N, C, -1).permute(0, 2, 1)
    N, P, C = tokens.shape
    tokens = tokens.reshape(N, P, -1, 64).permute(0, 2, 1, 3)

    # Integer (x, y) coordinates of every pixel, repeated for each sample.
    xs = torch.arange(w).to(tokens)
    ys = torch.arange(h).to(tokens)
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
    pos_full = torch.stack((grid_x, grid_y), -1).unsqueeze(0)
    positions = pos_full.reshape(1, -1, 2).repeat(N, 1, 1).long()

    # RoPE is applied to a tensor of ones and the result is *added* to the
    # features.
    rope_term = self.rope(torch.ones_like(tokens).to(tokens), positions).to(tokens)
    tokens = tokens + rope_term

    # Undo the grouping and go back to an (N, C, h, w) map.
    tokens = tokens.permute(0, 2, 1, 3)
    tokens = tokens.reshape(N, -1, C).permute(0, 2, 1)
    tokens = tokens.reshape(N, C, h, w)
    return {'imgs_vgg': tokens.permute(0, 2, 3, 1), 'feat_vgg_detail': detail.permute(0, 2, 3, 1)}
|
|
|
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kw):
    """Load a model from a local checkpoint file or via the parent loader.

    If the argument is an existing file it is loaded with ``load_model`` on
    the CPU (``**kw`` is not used in that case). Otherwise the call is
    delegated to the parent class's ``from_pretrained``.
    """
    if not os.path.isfile(pretrained_model_name_or_path):
        return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
    return load_model(pretrained_model_name_or_path, device='cpu')
|
|
|
def _encode_image(self, image, true_shape):
    """Run the patch embedding and the encoder blocks on ``image``.

    Returns:
        ``(x, pos, interm_features)`` where ``x`` is the normalised encoder
        output, ``pos`` the token positions from the patch embedding, and
        ``interm_features`` the *input* of every encoder block, in order.
    """
    x, pos = self.patch_embed(image, true_shape=true_shape)

    # Positions are passed to the blocks directly; an additive positional
    # embedding is not supported here.
    assert self.enc_pos_embed is None

    interm_features = []
    for blk in self.enc_blocks:
        interm_features.append(x)
        x = blk(x, pos)

    return self.enc_norm(x), pos, interm_features
|
|
|
def _encode_symmetrized(self, views):
    """Encode all views jointly and derive a coarser token grid.

    Args:
        views: Sequence of dicts with ``'img'`` and ``'true_shape'`` tensors;
            they are stacked along a new view axis (dim 1).

    Returns:
        ``(shapes_coarse, out_coarse, pos_coarse, shapes, out, pos,
        interm_features)``. All tensors are reshaped to a leading
        ``(B, n_views, ...)`` layout, except ``interm_features``, which is a
        list of such tensors.
    """
    imgs = torch.stack([view['img'] for view in views], dim=1)
    shapes = torch.stack([view['true_shape'] for view in views], dim=1)
    B, n_views, _, H, W = imgs.shape
    dtype = imgs.dtype

    # Fold the view axis into the batch axis for the encoder.
    imgs = imgs.view(-1, *imgs.shape[2:])
    shapes = shapes.view(-1, *shapes.shape[2:])

    out, pos, interm_features = self._encode_image(imgs, shapes)
    out = out.to(dtype)
    interm_features = [
        feat.to(dtype).reshape(B, n_views, *out.shape[1:])
        for feat in interm_features
    ]

    # Coarse grid: one token per 64x64 pixels. The fine grid below is taken
    # as 4x that resolution per side (i.e. 16-pixel patches are assumed).
    coarse_h = H // 64
    coarse_w = W // 64
    n_tokens = coarse_h * coarse_w
    flat = B * n_views
    coarse_embed = self.patch_embed_coarse2

    x_coarse = out.new_zeros((flat, n_tokens, coarse_embed.embed_dim)).to(dtype)
    pos_coarse = out.new_zeros((flat, n_tokens, 2), dtype=torch.int64)

    height, width = shapes.T
    is_landscape = (width >= height)
    is_portrait = ~is_landscape

    # Fine tokens as a channels-first map; portrait samples are transposed
    # before the strided projection.
    fine_token = out.view(flat, coarse_h * 4, coarse_w * 4, -1).permute(0, 3, 1, 2)
    x_coarse[is_landscape] = coarse_embed.proj(fine_token[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2)
    x_coarse[is_portrait] = coarse_embed.proj(fine_token[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2)

    position_getter = self.patch_embed_test_.position_getter
    pos_coarse[is_landscape] = position_getter(1, coarse_h, coarse_w, pos.device)
    pos_coarse[is_portrait] = position_getter(1, coarse_w, coarse_h, pos.device)

    x_coarse = self.enc_norm_coarse(x_coarse)

    # Restore the (B, n_views, ...) layout on everything that is returned.
    out_coarse = x_coarse.reshape(B, n_views, *x_coarse.shape[1:]).to(dtype)
    pos_coarse = pos_coarse.reshape(B, n_views, *pos_coarse.shape[1:])
    shapes_coarse = shapes.reshape(B, n_views, *shapes.shape[1:]) // 4
    out = out.reshape(B, n_views, *out.shape[1:])
    pos = pos.reshape(B, n_views, *pos.shape[1:])
    shapes = shapes.reshape(B, n_views, *shapes.shape[1:])
    return shapes_coarse, out_coarse, pos_coarse, shapes, out, pos, interm_features
|
|
|
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
    """Create the image patch embedding and the token-level re-embeddings.

    ``patch_embed`` works on images. ``patch_embed_coarse`` (patch 2) and
    ``patch_embed_coarse2`` (patch 4) take ``enc_embed_dim``-channel inputs.
    ``patch_embed_test_`` uses a patch of ``4 * patch_size`` and is only
    used for its ``position_getter`` in ``_encode_symmetrized``.
    """
    embed_cls = self.patch_embed_cls
    self.patch_embed = get_patch_embed(embed_cls, img_size, patch_size, enc_embed_dim)
    self.patch_embed_coarse = get_patch_embed(
        embed_cls, img_size, 2, enc_embed_dim, input_dim=enc_embed_dim)
    self.patch_embed_coarse2 = get_patch_embed(
        embed_cls, img_size, 4, enc_embed_dim, input_dim=enc_embed_dim)

    self.patch_embed_test_ = get_patch_embed(embed_cls, img_size, 4 * patch_size, enc_embed_dim)
|
|
|
|
|
def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
    """Create the prediction heads and the two camera-pose heads.

    Stores the head configuration on the instance, defaults
    ``desc_conf_mode`` to ``conf_mode`` when it is ``None``, and builds
    ``downstream_head1/2/4``, their landscape-transposing wrappers
    ``head1/2/4``, and ``pose_head`` / ``pose_head_stage2``.
    """
    assert img_size[0] % patch_size == 0 and img_size[
        1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'

    self.output_mode = output_mode
    self.head_type = head_type
    self.depth_mode = depth_mode
    self.conf_mode = conf_mode
    if self.desc_conf_mode is None:
        self.desc_conf_mode = conf_mode

    has_conf = bool(conf_mode)

    # Two heads of the configured type.
    self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=has_conf)
    self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=has_conf)

    # Wrappers that optionally transpose inputs to landscape orientation.
    self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
    self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    # Both pose heads reuse the hook indices of the second DPT head.
    hooks = self.downstream_head2.dpt.hooks
    self.pose_head = CameraPredictor_clean(hood_idx=hooks, trunk_depth=4, rope=self.rope)
    self.pose_head_stage2 = CameraPredictor_clean(hood_idx=hooks, trunk_depth=4, rope=self.rope)

    # Additional head of type 'sh'.
    self.downstream_head4 = mast3r_head_factory('sh', output_mode, self, has_conf=has_conf)
    self.head4 = transpose_to_landscape(self.downstream_head4, activate=landscape_only)
|
|
|
def _decoder_stage2(self, f1, pos1, f2, pos2, pose1, pose2, low_token=None):
    """Second-stage ("fine") decoder conditioned on camera-pose embeddings.

    Args:
        f1, f2: Reference / source encoder tokens, concatenated on dim 1
            (the view axis).
        pos1, pos2: Matching token positions.
        pose1, pose2: Camera embeddings for the reference / source views,
            fed to ``self.embed_pose``.
        low_token: Indexable of intermediate encoder features; entry
            ``2 * i`` is injected after decoder layer ``i`` for every ``i``
            in ``self.idx_hook``. Required in practice despite the default.

    Returns:
        ``(final_output, cam_tokens)``: ``final_output`` is the list
        ``[unprojected input, layer outputs...]`` with the last entry
        normalised by ``dec_norm_fine``; ``cam_tokens`` is a ``zip`` iterator
        yielding the per-layer reference tokens and then the per-layer
        source tokens.
    """
    f = torch.cat((f1, f2), 1)
    pos = torch.cat((pos1, pos2), 1)
    final_output = [f]  # tokens before projection to the decoder width

    f = self.decoder_embed_fine(f)
    B, views, _, C = f.shape
    f = f.view(B, -1, C)
    pos = pos.view(B, -1, pos.shape[-1])
    final_output.append(f)  # placeholder, removed again after the loop

    pose_embed = torch.cat((self.embed_pose(pose1), self.embed_pose(pose2)), 1)

    views = views - 1  # number of source views from here on
    pose_token_ref = self.pose_token_ref_fine.to(f1.dtype).repeat(B, 1, 1).view(B, -1, C)
    pose_token_source = self.pose_token_source_fine.to(f1.dtype).repeat(B*views, 1, 1).view(B*views, -1, C)

    # Width of the encoder features being injected. The injection LayerNorm
    # and Linear are built from this attribute, so it replaces the former
    # hard-coded 1024.
    enc_dim = self.enc_embed_dim
    cam_tokens = []
    hook_idx = 0
    layers = zip(self.dec_blocks_fine, self.cam_cond_encoder_fine,
                 self.cam_cond_embed_fine, self.adaLN_modulation)
    for i, (blk1, cam_cond, cam_cond_embed_fine, adaLN_modulation) in enumerate(layers):
        # Pose-dependent shift/scale for the camera tokens (third chunk unused).
        shift_msa, scale_msa, _ = adaLN_modulation(pose_embed).chunk(3, dim=-1)
        pose_token_ref = modulate(pose_token_ref.reshape(B, -1, C),
                                  shift_msa[:, :1].reshape(B, -1),
                                  scale_msa[:, :1].reshape(B, -1))
        pose_token_source = modulate(pose_token_source.reshape(B*views, -1, C),
                                     shift_msa[:, 1:].reshape(B*views, -1),
                                     scale_msa[:, 1:].reshape(B*views, -1))

        # Joint block over the tokens of all views.
        feat = checkpoint(blk1, f, pos)
        feat = feat.view(B, views+1, -1, C)
        f1 = feat[:, :1].view(B, -1, C)
        f2 = feat[:, 1:].reshape(B*views, -1, C)

        # Per-view block with the camera token prepended.
        f1_cam = torch.cat((pose_token_ref, f1.view(B, -1, C)), 1)
        f2_cam = torch.cat((pose_token_source, f2.view(B*views, -1, C)), 1)
        f_cam = checkpoint(cam_cond, torch.cat((f1_cam, f2_cam), 0))
        f_delta = f_cam[:, 1:]
        f_cam = f_cam[:, :1]
        f_delta1 = f_delta[:B].view(B, -1, C)
        f_delta2 = f_delta[B:].view(B*views, -1, C)

        # Residual updates of the camera tokens and of the image tokens.
        pose_token_ref = pose_token_ref.view(B, -1, C) + f_cam[:B].view(B, -1, C)
        pose_token_source = pose_token_source.view(B*views, -1, C) + f_cam[B:].view(B*views, -1, C)
        cam_tokens.append((pose_token_ref, pose_token_source))
        f1 = f1.view(B, -1, C) + cam_cond_embed_fine(f_delta1)
        f2 = f2.view(B*views, -1, C) + cam_cond_embed_fine(f_delta2)

        if i in self.idx_hook:
            norm = self.enc_inject_stage2[hook_idx]
            proj = self.inject_stage2[hook_idx]
            f1 = f1.view(B, -1, C) + proj(norm(low_token[i * 2][:, :1].view(B, -1, enc_dim)))
            f2 = f2.view(B*views, -1, C) + proj(norm(low_token[i * 2][:, 1:].reshape(B*views, -1, enc_dim)))
            hook_idx += 1

        f1 = f1.view(B, 1, -1, C)
        f2 = f2.view(B, views, -1, C)
        f = torch.cat((f1, f2), 1)
        final_output.append(f)
        f = f.view(B, -1, C)

    # Drop the projected-input placeholder and normalise the last outputs.
    del final_output[1]
    final_output[-1] = self.dec_norm_fine(final_output[-1])
    cam_tokens[-1] = tuple(map(self.dec_cam_norm_fine, cam_tokens[-1]))
    return final_output, zip(*cam_tokens)
|
|
|
def _decoder_stage3(self, feat_ref, pos1, pos2, pose1, pose2, low_token=None, feat_stage2=None, fxfycxcy1=None, fxfycxcy2=None):
    """Third-stage ("point") decoder that cross-attends to stage-2 features.

    Args:
        feat_ref: List from ``_decoder_stage2``: entry 0 holds the
            unprojected tokens, entry ``i + 1`` the output of stage-2 layer ``i``.
        pos1, pos2: Token positions of the reference / source views.
        pose1, pose2: Camera embeddings fed to ``self.embed_pose``.
        low_token: Indexable of intermediate encoder features; entry
            ``2 * i`` is injected after layer ``i`` for ``i`` in
            ``self.idx_hook``. Required in practice despite the default.
        feat_stage2: Optional extra input, embedded with
            ``decoder_embed_stage2`` and added to the projected tokens.
        fxfycxcy1, fxfycxcy2: Optional 4-vectors embedded with
            ``decoder_embed_fxfycxcy`` and added to the pose embeddings.

    Returns:
        ``[feat_ref[0], layer outputs...]`` with the last entry normalised
        by ``dec_norm_point``.
    """
    final_output = [feat_ref[0], feat_ref[1]]  # entry 1 is removed after the loop

    # NOTE(review): the source lost its indentation. The autocast-disabled
    # regions are assumed to cover only the pose / intrinsics MLPs, which
    # matches the explicit ``.to(dtype)`` casts further down - confirm.
    with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        pose1_embed = self.embed_pose(pose1)
        pose2_embed = self.embed_pose(pose2)

    B, views, _, C = feat_ref[-1].shape
    f = self.decoder_embed_point(feat_ref[0])
    if feat_stage2 is not None:
        f = f + self.decoder_embed_stage2(feat_stage2)
    views = views - 1  # number of source views from here on
    n_all = views + 1
    dtype = f.dtype

    pose_token_ref = self.pose_token_ref_point.to(dtype).repeat(B, 1, 1).view(B, -1, C)
    pose_token_source = self.pose_token_source_point.to(dtype).repeat(B*views, 1, 1).view(B*views, -1, C)
    pos = torch.cat((pos1, pos2), 1)

    if fxfycxcy1 is not None:
        with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
            fxfycxcy1 = self.decoder_embed_fxfycxcy(fxfycxcy1)
            fxfycxcy2 = self.decoder_embed_fxfycxcy(fxfycxcy2)
        pose1_embed = pose1_embed + fxfycxcy1
        pose2_embed = pose2_embed + fxfycxcy2

    # The camera tokens start from the learned tokens plus the pose embedding.
    pose_token_ref = pose_token_ref + pose1_embed.to(dtype)
    pose_token_source = pose_token_source + pose2_embed.to(dtype).view(B*views, -1, C)

    # Loop-invariant reshapes of the positions.
    pos_per_view = pos.reshape(B * n_all, -1, 2)
    pos_joint = pos.reshape(B, -1, 2)
    # Width of the injected encoder features (formerly hard-coded as 1024).
    enc_dim = self.enc_embed_dim

    hook_idx = 0
    layers = zip(self.dec_blocks_point, self.dec_blocks_point_cross, self.cam_cond_encoder_point,
                 self.cam_cond_embed_point, self.cam_cond_embed_point_pre)
    for i, (blk, blk_cross, cam_cond, cam_cond_embed_point, cam_cond_embed_point_pre) in enumerate(layers):
        # Stage-2 features of this layer, conditioned on the camera tokens.
        ref_feat = feat_ref[i + 1].reshape(B, n_all, -1, C)
        f1_pre = ref_feat[:, :1].view(B, -1, C)
        f2_pre = ref_feat[:, 1:].reshape(B*views, -1, C)
        f1_pre = f1_pre + cam_cond_embed_point_pre(pose_token_ref)
        f2_pre = f2_pre + cam_cond_embed_point_pre(pose_token_source)
        f_pre = torch.cat((f1_pre.view(B, 1, -1, C), f2_pre.view(B, views, -1, C)), 1)

        # Per-view cross block against the stage-2 features, then a joint
        # block over all views.
        feat, _ = checkpoint(blk_cross,
                             f.reshape(B * n_all, -1, C),
                             f_pre.reshape(B * n_all, -1, C),
                             pos_per_view, pos_per_view)
        feat = feat.view(B, n_all, -1, C).reshape(B, -1, C)
        feat = checkpoint(blk, feat, pos_joint)
        feat = feat.view(B, n_all, -1, C)
        f1 = feat[:, :1].view(B, -1, C)
        f2 = feat[:, 1:].reshape(B*views, -1, C)

        # Per-view block with the camera token prepended.
        f1_cam = torch.cat((pose_token_ref, f1.view(B, -1, C)), 1)
        f2_cam = torch.cat((pose_token_source, f2.view(B*views, -1, C)), 1)
        f_cam = checkpoint(cam_cond, torch.cat((f1_cam, f2_cam), 0))
        f_delta = f_cam[:, 1:]
        f_cam = f_cam[:, :1]
        f_delta1 = f_delta[:B].view(B, -1, C)
        f_delta2 = f_delta[B:].view(B*views, -1, C)

        pose_token_ref = pose_token_ref.view(B, -1, C) + f_cam[:B].view(B, -1, C)
        pose_token_source = pose_token_source.view(B*views, -1, C) + f_cam[B:].view(B*views, -1, C)
        f1 = f1.view(B, -1, C) + cam_cond_embed_point(f_delta1)
        f2 = f2.view(B*views, -1, C) + cam_cond_embed_point(f_delta2)

        if i in self.idx_hook:
            norm = self.enc_inject_stage3[hook_idx]
            proj = self.inject_stage3[hook_idx]
            f1 = f1.view(B, -1, C) + proj(norm(low_token[i * 2][:, :1].view(B, -1, enc_dim)))
            f2 = f2.view(B*views, -1, C) + proj(norm(low_token[i * 2][:, 1:].reshape(B*views, -1, enc_dim)))
            hook_idx += 1

        f1 = f1.view(B, 1, -1, C)
        f2 = f2.view(B, views, -1, C)
        f = torch.cat((f1, f2), 1)
        final_output.append(f)
        f = f.view(B, -1, C)

    del final_output[1]
    final_output[-1] = self.dec_norm_point(final_output[-1])
    return final_output
|
|
|
def _decoder(self, f1, pos1, f2, pos2):
    """First-stage decoder; only the camera tokens are returned.

    Args:
        f1, f2: Reference / source encoder tokens with a view axis on dim 1.
        pos1, pos2: Matching token positions.

    Returns:
        A ``zip`` iterator yielding the tuple of per-layer reference camera
        tokens and then the tuple of per-layer source camera tokens. The
        last layer's tokens are normalised by ``dec_cam_norm``.
    """
    f1 = self.decoder_embed(f1)
    f2 = self.decoder_embed(f2)
    B, views, _, C = f2.shape  # ``views`` = number of source views
    f1 = f1.view(B, -1, C)
    f2 = f2.view(B, -1, C)
    pos1 = pos1.view(B, -1, pos1.shape[-1])
    pos2 = pos2.view(B, -1, pos2.shape[-1])

    pose_token_ref = self.pose_token_ref.to(f1.dtype).repeat(B, 1, 1).view(B, -1, C)
    pose_token_source = self.pose_token_source.to(f1.dtype).repeat(B*views, 1, 1).view(B*views, -1, C)

    cam_tokens = []
    layers = zip(self.dec_blocks, self.dec_blocks2, self.cam_cond_encoder, self.cam_cond_embed)
    for blk1, blk2, cam_cond, cam_cond_embed in layers:
        # Only the previous layer's pair is needed. The image tokens are
        # never returned, so they are not accumulated in a list any more.
        prev_f1, prev_f2 = f1, f2
        f1, _ = checkpoint(blk1, prev_f1, prev_f2, pos1, pos2)
        f2, _ = checkpoint(blk2, prev_f2, prev_f1, pos2, pos1)

        # Per-view block with the camera token prepended.
        f1_cam = torch.cat((pose_token_ref, f1.view(B, -1, C)), 1)
        f2_cam = torch.cat((pose_token_source, f2.view(B*views, -1, C)), 1)
        f_cam = checkpoint(cam_cond, torch.cat((f1_cam, f2_cam), 0))
        f_delta = f_cam[:, 1:]
        f_cam = f_cam[:, :1]
        f_delta1 = f_delta[:B].view(B, -1, C)
        f_delta2 = f_delta[B:].view(B*views, -1, C)

        pose_token_ref = pose_token_ref.view(B, -1, C) + f_cam[:B].view(B, -1, C)
        pose_token_source = pose_token_source.view(B*views, -1, C) + f_cam[B:].view(B*views, -1, C)
        cam_tokens.append((pose_token_ref, pose_token_source))

        f1 = f1.view(B, -1, C) + cam_cond_embed(f_delta1)
        f2 = f2.view(B*views, -1, C) + cam_cond_embed(f_delta2)
        f1 = f1.view(B, -1, C)
        f2 = f2.view(B, -1, C)

    cam_tokens[-1] = tuple(map(self.dec_cam_norm, cam_tokens[-1]))
    return zip(*cam_tokens)
|
|
|
|
|
def forward_coarse_pose(self, view1, view2, enabled=True, dtype=torch.bfloat16):
    """Encode all views and predict coarse cameras from the coarse tokens.

    Args:
        view1: List of reference-view dicts (the first one supplies the batch size).
        view2: List of source-view dicts.
        enabled: Accepted for interface symmetry; not read in this method.
        dtype: Dtype the coarse features are cast to and that is passed to
            the pose head.

    Returns:
        A 12-tuple: the fine-level features, positions and shapes split into
        reference / source parts, the predicted cameras, two ``None``
        placeholders, the two camera-token sequences and the intermediate
        encoder features - in the order expected by ``forward_pointmap``.
    """
    batch_size, _, _, _ = view1[0]['img'].shape
    view_num = len(view2)

    # NOTE(review): the source lost its indentation; the autocast region is
    # assumed to span encoding, decoding and the pose head - confirm.
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
        (shapes, feat, pos,
         shape_stage2, feat_stage2, pos_stage2,
         interm_features) = self._encode_symmetrized(view1+view2)

        # Index 0 on the view axis is the reference view, the rest are sources.
        ref, src = slice(None, 1), slice(1, None)
        feat1, feat2 = feat[:, ref].to(dtype), feat[:, src].to(dtype)
        pos1, pos2 = pos[:, ref], pos[:, src]
        shape1, shape2 = shapes[:, ref], shapes[:, src]
        shape1_stage2, shape2_stage2 = shape_stage2[:, ref], shape_stage2[:, src]
        feat1_stage2, feat2_stage2 = feat_stage2[:, ref], feat_stage2[:, src]
        pos1_stage2, pos2_stage2 = pos_stage2[:, ref], pos_stage2[:, src]

        (pose_token1, pose_token2) = self._decoder(feat1, pos1, feat2, pos2)
        pred_cameras, _ = self.pose_head(batch_size, interm_feature1=pose_token1,
                                         interm_feature2=pose_token2, enabled=True, dtype=dtype)

    return (feat1_stage2, pos1_stage2, feat2_stage2, pos2_stage2, pred_cameras,
            shape1_stage2, shape2_stage2, None, None, pose_token1, pose_token2, interm_features)
|
|
|
def forward(self, view1, view2, enabled=True, dtype=torch.bfloat16):
    """Dispatch to the point-map or the Gaussian-splatting forward pass.

    ``self.wogs`` selects ``forward_pointmap`` when truthy and
    ``forward_gs`` otherwise; the three results are returned unchanged.
    """
    run = self.forward_pointmap if self.wogs else self.forward_gs
    res1, res2, pred_cameras = run(view1, view2, enabled=enabled, dtype=dtype)
    return res1, res2, pred_cameras
|
|
|
|
|
def forward_gs(self, view1, view2, enabled=True, dtype=torch.bfloat16):
    """Placeholder for the novel-view-synthesis path.

    Raises:
        NotImplementedError: Always.
    """
    message = "This feature (novel view synthesis) has not been released yet."
    raise NotImplementedError(message)
|
|
|
def load_state_dict_posehead(self, ckpt, strict=True):
    """Load checkpoint entries into the matching ``*_pose_head`` sub-modules.

    Each checkpoint key has ``'_pose_head'`` appended to its first dotted
    component (``'a.b.c'`` becomes ``'a_pose_head.b.c'``). Entries whose
    renamed key is not a parameter of this model are skipped.

    Args:
        ckpt: Mapping from parameter names to tensors.
        strict: Forwarded to ``self.load_state_dict``.

    Returns:
        Whatever ``self.load_state_dict`` returns.
    """
    # Collected once: rebuilding the parameter dict for every checkpoint key
    # made the loop quadratic.
    own_params = {name for name, _ in self.named_parameters()}

    new_ckpt = {}
    for key, value in ckpt.items():
        head, *rest = key.split('.')
        new_key = '.'.join([head + '_pose_head'] + rest)
        if new_key in own_params:
            print(f'Loading {new_key} from checkpoint')
            new_ckpt[new_key] = value
    return self.load_state_dict(new_ckpt, strict=strict)
|
|
|
|
|
def forward_pointmap(self, view1, view2, enabled=True, dtype=torch.bfloat16):
    """Three-stage forward pass: coarse pose, fine decoding, point decoding.

    Args:
        view1: List of reference-view dicts (the first one supplies the batch size).
        view2: List of source-view dicts.
        enabled, dtype: Forwarded to ``forward_coarse_pose``; ``dtype`` is
            also used for the stage-2 camera embeddings and head inputs.

    Returns:
        ``(res1, res2, pred_cameras)``: the outputs of head 1 (stage 2) and
        head 2 (stage 3) without their ``'desc'`` entry, unflattened to
        ``(batch, views, ...)`` and cast to float32, plus the coarse and
        stage-2 camera predictions joined with ``+``.
    """

    def _embed_cameras(quaternion, translation):
        # Translations are divided by their mean norm over the view axis
        # before being concatenated after the quaternion.
        scale = translation.norm(dim=-1, keepdim=True).mean(dim=-2, keepdim=True) + 1e-8
        return torch.cat((quaternion, translation / scale), -1)

    batch_size, _, _, _ = view1[0]['img'].shape
    view_num = len(view2)

    # ---- stage 1: coarse pose
    (feat1, pos1, feat2, pos2, pred_cameras_coarse, shape1, shape2,
     _, _, _, _, interm_features) = self.forward_coarse_pose(view1, view2, enabled=enabled, dtype=dtype)

    use_predicted_pose = self.wpose == False
    if use_predicted_pose:
        coarse = pred_cameras_coarse[-1]
        trans = coarse['T'].float().detach().clone().reshape(batch_size, -1, 3)
        quaternion_R_pred = coarse['quaternion_R'].reshape(batch_size, -1, 4).float().detach().clone()
    else:
        # Poses supplied with the views, expressed relative to the reference
        # camera. Views are stacked on dim 1 so the result is (B, V, ...) for
        # any batch size. The previous cat/repeat/bmm/reshape formulation
        # only lined up for batch size 1.
        # Assumes every view['camera_pose'] is (B, 4, 4) - TODO confirm.
        poses = torch.stack([view['camera_pose'] for view in view1 + view2], 1).double()
        ref_inverse = closed_form_inverse(poses[:, 0])
        relative = torch.matmul(ref_inverse.unsqueeze(1), poses)
        quaternion_R_pred = matrix_to_quaternion(relative[..., :3, :3]).float()
        trans = relative[..., :3, 3].float()

    camera_embed = _embed_cameras(quaternion_R_pred, trans)
    camera_embed1 = camera_embed[:, :1].to(dtype)
    camera_embed2 = camera_embed[:, 1:].to(dtype)

    # ---- stage 2: fine decoder conditioned on the stage-1 cameras
    dec_fine, (pose_token1_fine, pose_token2_fine) = self._decoder_stage2(
        feat1, pos1, feat2, pos2, camera_embed1, camera_embed2, interm_features)
    shape = torch.cat((shape1, shape2), 1)
    res1 = self._downstream_head(
        1, [tok.to(dtype).reshape(-1, tok.shape[-2], tok.shape[-1]) for tok in dec_fine], shape.reshape(-1, 2))
    res1.pop('desc')
    for key in res1.keys():
        res1[key] = res1[key].unflatten(0, (batch_size, view_num+1)).float()

    with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        pred_cameras, _ = self.pose_head_stage2(
            batch_size, interm_feature1=pose_token1_fine, interm_feature2=pose_token2_fine,
            enabled=True, dtype=torch.float32)

    if use_predicted_pose:
        fine = pred_cameras[-1]
        # Reshaped like the stage-1 translation for consistency.
        trans = fine['T'].float().detach().clone().reshape(batch_size, -1, 3)
        quaternion_R_pred = fine['quaternion_R'].reshape(batch_size, -1, 4).float().detach().clone()
    # Otherwise the supplied poses computed above are reused unchanged.

    camera_embed = _embed_cameras(quaternion_R_pred, trans)
    camera_embed1 = camera_embed[:, :1]
    camera_embed2 = camera_embed[:, 1:]
    pred_cameras = pred_cameras_coarse + pred_cameras

    # ---- stage 3: point decoder conditioned on the stage-2 cameras
    dec_fine_stage2 = self._decoder_stage3(dec_fine, pos1, pos2, camera_embed1, camera_embed2, interm_features)
    with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        res2 = self._downstream_head(
            2, [tok.float().reshape(-1, tok.shape[-2], tok.shape[-1]) for tok in dec_fine_stage2],
            shape.reshape(-1, 2))
    res2.pop('desc')
    for key in res2.keys():
        res2[key] = res2[key].unflatten(0, (batch_size, view_num+1)).float()
    return res1, res2, pred_cameras
|
|