"""Wrapper exposing the third-party SGMNet matcher through the BaseModel interface."""
import sys
from collections import OrderedDict, namedtuple
from pathlib import Path

import torch

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# The SGMNet sources live in a vendored directory that is not an installed
# package, so it has to be put on sys.path before it can be imported.
sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
sys.path.append(str(sgmnet_path))

from sgmnet import matcher as SGM_Model  # noqa: E402

# Default compute device; only used as a fallback when the device of the
# network weights cannot be determined.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SGMNet(BaseModel):
    """Match two sets of keypoints/descriptors with the SGMNet network."""

    # All entries are forwarded verbatim to the third-party network as its
    # configuration (see `_init`); "match_threshold" is additionally read by
    # `match_p`.
    default_conf = {
        "name": "SGM",
        "model_name": "weights/sgm/root/model_best.pth",
        "seed_top_k": [256, 256],
        "seed_radius_coe": 0.01,
        "net_channels": 128,
        "layer_num": 9,
        "head": 4,
        "seedlayer": [0, 6],
        "use_mc_seeding": True,
        "use_score_encoding": False,
        "conf_bar": [1.11, 0.1],
        "sink_iter": [10, 100],
        "detach_iter": 1000000,
        "match_threshold": 0.2,
    }
    # NOTE(review): `_forward` also reads "keypoints{0,1}", "scores{0,1}" and
    # "descriptors{0,1}" from `data`; they are not listed here to keep the
    # declared interface unchanged.
    required_inputs = [
        "image0",
        "image1",
    ]

    def _init(self, conf):
        """Download the checkpoint, build the network and load its weights.

        Args:
            conf: mapping of configuration values; every key/value pair is
                turned into a field of the namedtuple handed to the network.
        """
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )

        # The third-party network expects attribute access on its config.
        config = namedtuple("config", conf.keys())(*conf.values())
        self.net = SGM_Model(config)

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]

        # Checkpoints saved from a DistributedDataParallel model prefix every
        # key with "module."; strip it so the keys match the bare network.
        ddp_prefix = "module."
        first_key = next(iter(state_dict), "")
        if first_key.startswith(ddp_prefix):
            state_dict = OrderedDict(
                (
                    key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key,
                    value,
                )
                for key, value in state_dict.items()
            )

        self.net.load_state_dict(state_dict)
        logger.info("Load SGMNet model done.")

    def _net_device(self):
        """Return the device of the network weights, or the module default."""
        # presumably self.net is a torch.nn.Module (load_state_dict is called
        # on it in `_init`); fall back to the global device if it has no
        # parameters.
        first_param = next(self.net.parameters(), None)
        return device if first_param is None else first_param.device

    def _forward(self, data):
        """Run the matcher on one image pair.

        Args:
            data: dict holding "keypoints{0,1}", "scores{0,1}",
                "descriptors{0,1}" and "image{0,1}" tensors for a single
                pair (batch size 1 is assumed by the reshapes below).

        Returns:
            dict with "matches0" (1 x N, index of the matched keypoint in
            image 1 or -1) and "matching_scores0" (1 x N, all zeros).
        """
        # reshape instead of squeeze(): squeeze() would also drop the N
        # dimension when there is exactly one keypoint.
        x1 = data["keypoints0"].reshape(-1, 2)  # N x 2
        x2 = data["keypoints1"].reshape(-1, 2)
        score1 = data["scores0"].reshape(-1, 1)  # N x 1
        score2 = data["scores1"].reshape(-1, 1)
        desc1 = data["descriptors0"].permute(0, 2, 1)  # 1 x N x 128
        desc2 = data["descriptors1"].permute(0, 2, 1)

        # Spatial dims of the image tensor, flipped so they line up with the
        # (x, y) order of the keypoints: W x H.
        size1 = torch.tensor(data["image0"].shape[2:]).flip(0).to(x1.device)
        size2 = torch.tensor(data["image1"].shape[2:]).flip(0).to(x2.device)
        norm_x1 = self.normalize_size(x1, size1)
        norm_x2 = self.normalize_size(x2, size2)

        x1 = torch.cat((norm_x1, score1), dim=-1)  # N x 3
        x2 = torch.cat((norm_x2, score2), dim=-1)

        net_device = self._net_device()
        net_input = {"x1": x1[None], "x2": x2[None], "desc1": desc1, "desc2": desc2}
        net_input = {
            k: v.to(net_device).float() if isinstance(v, torch.Tensor) else v
            for k, v in net_input.items()
        }

        pred = self.net(net_input, test_mode=True)
        p = pred["p"]  # shape: N * M
        # The last row and column of the assignment matrix are dropped before
        # matching.
        indices0 = self.match_p(p[0, :-1, :-1])

        return {
            "matches0": indices0.unsqueeze(0),
            # No per-match confidence is exported; zeros are returned on the
            # same device as the matches.
            "matching_scores0": torch.zeros(
                indices0.size(0), device=indices0.device
            ).unsqueeze(0),
        }

    def match_p(self, p):
        """Extract mutual-nearest matches above the threshold from `p`.

        Args:
            p: 2-D score matrix (rows: keypoints of image 0, columns:
                keypoints of image 1).

        Returns:
            1-D long tensor with one entry per row of `p`: the matched column
            index, or -1 when the row has no accepted match.
        """
        # topk(k=1) fails on an empty dimension; nothing can match then.
        if p.shape[0] == 0 or p.shape[1] == 0:
            return torch.full((p.shape[0],), -1, dtype=torch.long, device=p.device)

        score, index = torch.topk(p, k=1, dim=-1)  # best column per row
        _, index2 = torch.topk(p, k=1, dim=-2)  # best row per column
        mask_th = score[:, 0] > self.conf["match_threshold"]
        index = index[:, 0]
        index2 = index2.squeeze(0)

        # Mutual check: row i picks column j and column j picks row i back.
        mask_mc = index2[index] == torch.arange(len(p), device=p.device)
        mask = mask_th & mask_mc
        return torch.where(mask, index, index.new_tensor(-1))

    def normalize_size(self, x, size, scale=1):
        """Centre `x` on the middle of `size` and divide by the largest side.

        Args:
            x: coordinates, last dimension matching `size`.
            size: tensor with the extent along each coordinate axis.
            scale: extra divisor applied together with the largest side.

        Returns:
            `(x - size / 2 + 0.5) / (size.max() * scale)`.
        """
        norm_fac = size.max()
        return (x - size / 2 + 0.5) / (norm_fac * scale)