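"""Two-view matcher built on DUSt3R.

Runs pairwise DUSt3R inference on an image pair, aligns the predictions
with PairViewer global alignment, and turns reciprocal 3D nearest
neighbours into 2D keypoint correspondences.
"""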
import sys
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as tfm

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

duster_path = Path(__file__).parent / "../../third_party/dust3r"
sys.path.append(str(duster_path))

from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Duster(BaseModel):
    default_conf = {
        "name": "Duster3r",
        "model_name": "duster_vit_large.pth",
        "max_keypoints": 3000,
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )
        self.net = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(device)
        logger.info("Loaded Dust3r model")

    def preprocess(self, img):
        # The super-class already makes sure that img0 and img1 have the
        # same resolution and that h == w; here we only snap both spatial
        # dimensions to multiples of the ViT patch size,
        # e.g. h = 500 with patch size 16 -> 16 * round(500 / 16) = 496.
        patch_size = self.conf["vit_patch_size"]
        _, h, _ = img.shape
        if h % patch_size != 0:
            imsize = int(patch_size * round(h / patch_size))
            img = tfm.functional.resize(img, imsize, antialias=True)

        _, new_h, new_w = img.shape
        if new_w % patch_size != 0:
            safe_w = int(patch_size * round(new_w / patch_size))
            img = tfm.functional.resize(img, (new_h, safe_w), antialias=True)

        img = self.normalize(img).unsqueeze(0)

        return img

    def _forward(self, data):
        img0, img1 = data["image0"], data["image1"]
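        # Normalize to [-1, 1] (matches self.normalize, assuming batched
        # inputs in [0, 1]).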
        mean = torch.tensor([0.5, 0.5, 0.5]).to(device)
        std = torch.tensor([0.5, 0.5, 0.5]).to(device)

        img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
        img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)
        scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PairViewer)
        # retrieve useful values from scene:
        imgs = scene.imgs
        confidence_masks = scene.get_masks()
        pts3d = scene.get_pts3d()
        pts2d_list, pts3d_list = [], []
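        # Keep, per view, the 2D pixel coordinates and 3D points that pass
        # the confidence masks.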
        for i in range(2):
            conf_i = confidence_masks[i].cpu().numpy()
            pts2d_list.append(
                xy_grid(*imgs[i].shape[:2][::-1])[conf_i]
            )  # imgs[i].shape[:2] = (H, W)
            pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

        if len(pts3d_list[1]) == 0:
            pred = {
                "keypoints0": torch.zeros([0, 2]),
                "keypoints1": torch.zeros([0, 2]),
            }
            logger.warning(f"Matched {0} points")
        else:
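            # Matches are mutual 3D nearest neighbours: reciprocal_in_P2 masks
            # the view-1 points whose nearest neighbour in view 0 points back
            # at them, and nn2_in_P1 maps each view-1 point to that view-0
            # index.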
            reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
                *pts3d_list
            )
            logger.info(f"Found {num_matches} matches")
            mkpts1 = pts2d_list[1][reciprocal_in_P2]
            mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
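            # Cap the number of matches by evenly subsampling along the list.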
            top_k = self.conf["max_keypoints"]
            if top_k is not None and len(mkpts0) > top_k:
                keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
                mkpts0 = mkpts0[keep]
                mkpts1 = mkpts1[keep]
            pred = {
                "keypoints0": torch.from_numpy(mkpts0),
                "keypoints1": torch.from_numpy(mkpts1),
            }
        return pred
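

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module). It
    # assumes BaseModel exposes the hloc-style nn.Module interface, i.e.
    # that calling the matcher on a dict routes it through _forward, and
    # it uses random tensors in place of real images.
    matcher = Duster({})  # empty conf: fall back to default_conf
    image0 = torch.rand(1, 3, 512, 512, device=device)  # batched RGB in [0, 1]
    image1 = torch.rand(1, 3, 512, 512, device=device)
    pred = matcher({"image0": image0, "image1": image1})
    print(pred["keypoints0"].shape, pred["keypoints1"].shape)  # (N, 2) each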