import sys
from pathlib import Path

import torch

from .. import logger
from ..utils.base_model import BaseModel

sys.path.append(str(Path(__file__).parent / "../../third_party"))
from SuperGluePretrainedNetwork.models import superpoint  # noqa E402

# Handle on the upstream descriptor sampler, captured at import time so it
# can be reinstated after the temporary swap done in ``SuperPoint._forward``.
_upstream_sample_descriptors = superpoint.sample_descriptors


# The original keypoint sampling is incorrect. We patch it here but
# we don't fix it upstream to not impact existing evaluations.
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations.

    Args:
        keypoints: keypoint coordinates. They are divided by ``(w * s, h * s)``,
            so the last dimension is (x, y) in the image frame that is ``s``
            times larger than the descriptor map.
        descriptors: dense descriptor map of shape ``(b, c, h, w)``.
        s: scale factor between the image and the descriptor map.

    Returns:
        Tensor of shape ``(b, c, num_keypoints)``, L2-normalized along the
        channel dimension.
    """
    b, c, h, w = descriptors.shape
    # Map pixel centers to [0, 1]; this matches align_corners=False below,
    # where -1/1 denote the outer edges of the corner pixels.
    keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    descriptors = torch.nn.functional.grid_sample(
        descriptors,
        keypoints.view(b, 1, -1, 2),
        mode="bilinear",
        align_corners=False,
    )
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1
    )
    return descriptors


class SuperPoint(BaseModel):
    """SuperPoint extractor backed by the SuperGluePretrainedNetwork model.

    With ``fix_sampling`` enabled, descriptors are interpolated with
    ``sample_descriptors_fix_sampling`` instead of the upstream sampler.
    The choice applies only to this instance's forward passes.
    """

    default_conf = {
        "nms_radius": 4,
        "keypoint_threshold": 0.005,
        "max_keypoints": -1,
        "remove_borders": 4,
        "fix_sampling": False,
    }
    required_inputs = ["image"]
    detection_noise = 2.0

    def _init(self, conf):
        """Build the underlying network and select the descriptor sampler."""
        # Remember the sampler per instance instead of patching the shared
        # ``superpoint`` module here: a permanent module-level patch would
        # make one instance's ``fix_sampling`` leak into every other one.
        self._sample_descriptors = (
            sample_descriptors_fix_sampling
            if conf["fix_sampling"]
            else _upstream_sample_descriptors
        )
        self.net = superpoint.SuperPoint(conf)
        logger.info("Load SuperPoint model done.")

    def _forward(self, data):
        """Run the network with this instance's descriptor sampler installed."""
        # The upstream network is expected to look up ``sample_descriptors``
        # as a module global at call time (the original monkeypatch relied on
        # this). Install ours only for this call, then restore the previous one.
        previous = superpoint.sample_descriptors
        superpoint.sample_descriptors = self._sample_descriptors
        try:
            return self.net(data, self.conf)
        finally:
            superpoint.sample_descriptors = previous