File size: 3,690 Bytes
3040ac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import sys
from collections import OrderedDict, namedtuple
from pathlib import Path

import torch

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# Add the vendored SGMNet checkout (resolved relative to this file) to the
# module search path so that the `sgmnet` package below can be imported.
sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
sys.path.append(str(sgmnet_path))

# Deliberately placed after the sys.path tweak above; presumably `sgmnet` is
# only importable once third_party/SGMNet is on the path.
from sgmnet import matcher as SGM_Model

# Module-wide default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SGMNet(BaseModel):
    """Wrapper exposing the third-party SGMNet matcher through ``BaseModel``.

    The wrapper loads a pretrained checkpoint, feeds keypoints, detection
    scores and descriptors of two images to the network and converts the
    returned assignment matrix into per-keypoint match indices.
    """

    # Options forwarded (as namedtuple fields) to the third-party network,
    # plus the checkpoint location and the matching threshold used here.
    default_conf = {
        "name": "SGM",
        "model_name": "weights/sgm/root/model_best.pth",
        "seed_top_k": [256, 256],
        "seed_radius_coe": 0.01,
        "net_channels": 128,
        "layer_num": 9,
        "head": 4,
        "seedlayer": [0, 6],
        "use_mc_seeding": True,
        "use_score_encoding": False,
        "conf_bar": [1.11, 0.1],
        "sink_iter": [10, 100],
        "detach_iter": 1000000,
        "match_threshold": 0.2,
    }
    # NOTE(review): `_forward` also reads "keypoints{0,1}", "scores{0,1}" and
    # "descriptors{0,1}" from `data`; only the images are declared here.
    required_inputs = [
        "image0",
        "image1",
    ]

    # Initialize the SGMNet keypoint matcher
    def _init(self, conf):
        """Build the network and load its pretrained weights.

        Args:
            conf: Mapping of configuration options. Its keys and values are
                turned into a namedtuple that is handed to the network
                constructor.
        """
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )

        # The third-party model reads its options as attributes, so wrap the
        # mapping in a namedtuple.
        config = namedtuple("config", conf.keys())(*conf.values())
        self.net = SGM_Model(config)

        # NOTE(review): torch.load unpickles the checkpoint; only load files
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]

        # Checkpoints saved from a DDP-wrapped model prefix every key with
        # "module."; strip that prefix so the keys match the bare network.
        ddp_prefix = "module."
        first_key = next(iter(state_dict), "")
        if first_key.split(".")[0] == "module":
            state_dict = OrderedDict(
                (
                    key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key,
                    value,
                )
                for key, value in state_dict.items()
            )
        self.net.load_state_dict(state_dict)
        logger.info("Load SGMNet model done.")

    def _forward(self, data):
        """Match the keypoints of two images.

        Args:
            data: Dict holding "image0"/"image1", "keypoints0"/"keypoints1",
                "scores0"/"scores1" and "descriptors0"/"descriptors1"
                tensors (assumed to describe a single image pair, i.e.
                batch size 1 - TODO confirm against callers).

        Returns:
            Dict with "matches0" (1 x N indices into the second keypoint set,
            -1 for unmatched) and "matching_scores0" (1 x N zeros).
        """
        # reshape(-1, 2) instead of squeeze(): squeeze() would also drop the
        # keypoint axis when there is exactly one keypoint.
        kpts0 = data["keypoints0"].reshape(-1, 2)  # N x 2
        kpts1 = data["keypoints1"].reshape(-1, 2)
        score0 = data["scores0"].reshape(-1, 1)  # N x 1
        score1 = data["scores1"].reshape(-1, 1)
        desc0 = data["descriptors0"].permute(0, 2, 1)  # presumably 1 x N x D
        desc1 = data["descriptors1"].permute(0, 2, 1)

        # Trailing image dims flipped; presumably (H, W) -> (W, H) so that the
        # order lines up with the (x, y) keypoint coordinates.
        size0 = torch.tensor(data["image0"].shape[2:]).flip(0).to(kpts0.device)
        size1 = torch.tensor(data["image1"].shape[2:]).flip(0).to(kpts1.device)
        norm_kpts0 = self.normalize_size(kpts0, size0)
        norm_kpts1 = self.normalize_size(kpts1, size1)

        feat0 = torch.cat((norm_kpts0, score0), dim=-1)  # N x 3
        feat1 = torch.cat((norm_kpts1, score1), dim=-1)

        # Run on whatever device the weights live on; fall back to the
        # module-level default only if the network exposes no parameters.
        try:
            target_device = next(self.net.parameters()).device
        except StopIteration:
            target_device = device

        net_input = {
            "x1": feat0[None],
            "x2": feat1[None],
            "desc1": desc0,
            "desc2": desc1,
        }
        net_input = {
            key: value.to(target_device).float()
            if isinstance(value, torch.Tensor)
            else value
            for key, value in net_input.items()
        }
        pred = self.net(net_input, test_mode=True)

        # The last row and column of the assignment matrix are dropped
        # (presumably the dustbin entries - TODO confirm).
        p = pred["p"]
        indices0 = self.match_p(p[0, :-1, :-1])
        return {
            "matches0": indices0.unsqueeze(0),
            # Placeholder: zeros are returned instead of match confidences.
            "matching_scores0": torch.zeros(indices0.size(0)).unsqueeze(0),
        }

    def match_p(self, p):
        """Extract mutual nearest-neighbour matches from a score matrix.

        Args:
            p: 2-D tensor of matching scores (rows: first set, columns:
                second set).

        Returns:
            1-D long tensor with one entry per row: the matched column index,
            or -1 when the best score does not exceed
            ``conf["match_threshold"]`` or the match is not mutual.
        """
        # Nothing to match against: topk(k=1) would raise on an empty axis.
        if p.shape[0] == 0 or p.shape[1] == 0:
            return torch.full((p.shape[0],), -1, dtype=torch.long, device=p.device)

        best_score, best_col = torch.topk(p, k=1, dim=-1)
        _, best_row = torch.topk(p, k=1, dim=-2)
        best_col = best_col[:, 0]
        best_row = best_row.squeeze(0)

        above_threshold = best_score[:, 0] > self.conf["match_threshold"]
        # Mutual check: the best row of each row's best column must be itself.
        # The index range is built on p's device to avoid device mismatches.
        is_mutual = best_row[best_col] == torch.arange(len(p), device=p.device)
        keep = above_threshold & is_mutual
        return torch.where(keep, best_col, best_col.new_tensor(-1))

    def normalize_size(self, x, size, scale=1):
        """Centre coordinates on the image and scale by its longest side.

        Args:
            x: Tensor of coordinates whose last dimension matches ``size``.
            size: Tensor with the image extent along each coordinate axis.
            scale: Extra multiplier applied to the normalisation factor.

        Returns:
            ``(x - size / 2 + 0.5) / (size.max() * scale)``.
        """
        norm_fac = size.max()
        return (x - size / 2 + 0.5) / (norm_fac * scale)