XingyiHe's picture
init commit
3040ac4
raw
history blame
3.69 kB
import sys
from collections import OrderedDict, namedtuple
from pathlib import Path
import torch
from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel
# Put the vendored SGMNet checkout (third_party/SGMNet, resolved relative to
# this file) on sys.path. This must run before the `sgmnet` import below,
# which presumably only resolves through that path entry.
sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
sys.path.append(str(sgmnet_path))
from sgmnet import matcher as SGM_Model
# Module-wide default device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SGMNet(BaseModel):
    """Keypoint matcher wrapping the third-party SGMNet ``matcher`` network.

    ``_init`` downloads and loads the checkpoint, ``_forward`` converts a
    keypoint/descriptor dictionary into the network's input format and turns
    the returned assignment matrix into mutual-nearest-neighbour matches.
    """

    # Default settings. ``model_name`` is used to build the checkpoint path
    # and ``match_threshold`` is the minimum assignment score for a match;
    # the remaining entries are handed to the third-party network untouched.
    default_conf = {
        "name": "SGM",
        "model_name": "weights/sgm/root/model_best.pth",
        "seed_top_k": [256, 256],
        "seed_radius_coe": 0.01,
        "net_channels": 128,
        "layer_num": 9,
        "head": 4,
        "seedlayer": [0, 6],
        "use_mc_seeding": True,
        "use_score_encoding": False,
        "conf_bar": [1.11, 0.1],
        "sink_iter": [10, 100],
        "detach_iter": 1000000,
        "match_threshold": 0.2,
    }
    # NOTE(review): ``_forward`` also reads keypoints/scores/descriptors for
    # both images; only the image keys are declared here - confirm intent.
    required_inputs = [
        "image0",
        "image1",
    ]

    def _init(self, conf):
        """Build the SGMNet matcher and load its pretrained weights.

        Args:
            conf: Mapping of configuration entries; every key/value pair is
                forwarded to the third-party network.
        """
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )

        # The network receives its settings as a namedtuple built from the
        # mapping (presumably it reads them as attributes).
        config = namedtuple("config", conf.keys())(*conf.values())
        self.net = SGM_Model(config)

        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]

        # Checkpoints saved from a DistributedDataParallel wrapper carry a
        # "module." prefix on their keys; strip it, but only from keys that
        # really have it so no other key is truncated by accident.
        first_key = next(iter(state_dict), "")
        if first_key.split(".")[0] == "module":
            prefix = "module."
            state_dict = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in state_dict.items()
            )

        self.net.load_state_dict(state_dict)
        logger.info("Load SGMNet model done.")

    def _net_device(self):
        """Return the device holding the network weights.

        Falls back to the module-level ``device`` if the network exposes no
        parameters.
        """
        first_param = next(self.net.parameters(), None)
        return device if first_param is None else first_param.device

    def _forward(self, data):
        """Match the keypoints of two images.

        Args:
            data: Dictionary providing ``keypoints0/1``, ``scores0/1``,
                ``descriptors0/1`` and ``image0/1`` tensors (assumed to be a
                single, batch-size-1 image pair - TODO confirm with caller).

        Returns:
            Dictionary with ``matches0`` (1 x N index into the second image's
            keypoints, ``-1`` for unmatched) and ``matching_scores0``
            (1 x N, all zeros).
        """
        # reshape(-1, 2) rather than squeeze(): squeeze() would also drop the
        # keypoint axis when exactly one keypoint is present.
        kpts0 = data["keypoints0"].reshape(-1, 2)  # N x 2
        kpts1 = data["keypoints1"].reshape(-1, 2)
        score0 = data["scores0"].reshape(-1, 1)  # N x 1
        score1 = data["scores1"].reshape(-1, 1)
        # Swap the last two axes (presumably 1 x D x N -> 1 x N x D).
        desc0 = data["descriptors0"].permute(0, 2, 1)
        desc1 = data["descriptors1"].permute(0, 2, 1)

        # Trailing image dims reversed so they line up with (x, y) keypoints
        # (assumes images are laid out as ... x H x W).
        size0 = torch.tensor(data["image0"].shape[2:]).flip(0).to(kpts0.device)
        size1 = torch.tensor(data["image1"].shape[2:]).flip(0).to(kpts1.device)

        # Network input per keypoint: normalized (x, y) plus detection score.
        x1 = torch.cat((self.normalize_size(kpts0, size0), score0), dim=-1)  # N x 3
        x2 = torch.cat((self.normalize_size(kpts1, size1), score1), dim=-1)

        # Feed the network on the device its weights live on, not on a
        # module-global guess, so CPU-hosted models work on CUDA machines.
        net_device = self._net_device()
        net_input = {
            "x1": x1[None],
            "x2": x2[None],
            "desc1": desc0,
            "desc2": desc1,
        }
        net_input = {k: v.to(net_device).float() for k, v in net_input.items()}

        pred = self.net(net_input, test_mode=True)
        p = pred["p"]  # assignment matrix
        # Last row/column are dropped before matching (presumably the
        # dustbin/unmatched slot - TODO confirm against SGMNet).
        indices0 = self.match_p(p[0, :-1, :-1])

        # NOTE(review): scores are placeholder zeros; the real assignment
        # scores are not propagated.
        return {
            "matches0": indices0.unsqueeze(0),
            "matching_scores0": torch.zeros(indices0.size(0)).unsqueeze(0),
        }

    def match_p(self, p):
        """Extract mutual-nearest-neighbour matches from a score matrix.

        Args:
            p: 2-D tensor of scores between rows (first set) and columns
                (second set).

        Returns:
            1-D long tensor with one entry per row: the matched column index,
            or ``-1`` when the best score is not above
            ``conf["match_threshold"]`` or the match is not mutual.
        """
        # topk cannot select from an empty axis: with no candidates on either
        # side every row is simply unmatched.
        if p.numel() == 0:
            return torch.full((p.shape[0],), -1, dtype=torch.long, device=p.device)

        score, index = torch.topk(p, k=1, dim=-1)  # best column per row
        _, index2 = torch.topk(p, k=1, dim=-2)  # best row per column
        mask_th = score[:, 0] > self.conf["match_threshold"]
        index = index[:, 0]
        index2 = index2.squeeze(0)

        # Mutual check: the row's best column must pick that row back.
        # arange is created on p's own device to avoid device mismatches.
        mask_mc = index2[index] == torch.arange(len(p), device=p.device)
        mask = mask_th & mask_mc
        return torch.where(mask, index, index.new_tensor(-1))

    def normalize_size(self, x, size, scale=1):
        """Center coordinates on the image and scale by its longest side.

        Args:
            x: Tensor of coordinates whose last axis matches ``size``.
            size: Tensor with the image extent along each coordinate axis.
            scale: Extra divisor applied on top of the longest side.

        Returns:
            ``(x - size / 2 + 0.5) / (size.max() * scale)``.
        """
        norm_fac = size.max()
        return (x - size / 2 + 0.5) / (norm_fac * scale)