"""Matcher wrapper around the vendored ASpanFormer network."""

import sys
from pathlib import Path

import torch

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# The vendored ASpanFormer package is imported by its top-level name, so the
# third_party directory has to be on sys.path before the imports below run.
sys.path.append(str(Path(__file__).parent / "../../third_party"))

from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config

aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"


class ASpanFormer(BaseModel):
    """Run the ASpanFormer network on an image pair and return its matches.

    Configuration keys (see ``default_conf``):
        model_name: checkpoint file name, fetched from ``MODEL_REPO_ID``
            under a sub-folder named after this module.
        match_threshold: written to the coarse-matching ``thr`` setting.
        sinkhorn_iterations: written to the coarse-matching ``skh_iters``
            setting.
        max_keypoints: upper bound on the number of returned matches;
            ``None`` disables the limit.
        config_path: ASpanFormer config file merged over the defaults.
    """

    default_conf = {
        "model_name": "outdoor.ckpt",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        "max_keypoints": 2048,
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the network from ``conf`` and load its checkpoint on CPU."""
        cfg = get_cfg_defaults()
        cfg.merge_from_file(conf["config_path"])
        aspan_cfg = lower_config(cfg)["aspan"]

        # Override the coarse-matching settings with the values from conf.
        coarse_cfg = aspan_cfg["match_coarse"]
        coarse_cfg["thr"] = conf["match_threshold"]
        coarse_cfg["skh_iters"] = conf["sinkhorn_iterations"]

        self.net = _ASpanFormer(config=aspan_cfg)

        ckpt_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )
        checkpoint = torch.load(str(ckpt_path), map_location="cpu")
        # strict=False: keys that do not line up between the checkpoint and
        # the module are tolerated instead of raising.
        self.net.load_state_dict(checkpoint["state_dict"], strict=False)
        logger.info("Loaded Aspanformer model")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns a dict with ``keypoints0``, ``keypoints1`` and ``mconf``.
        When ``max_keypoints`` is set and more matches are found, only the
        ``max_keypoints`` highest-confidence matches are kept, ordered by
        descending confidence.
        """
        # Hand the network its own dict: it stores its outputs in the dict
        # it receives, and the caller's ``data`` should stay untouched.
        batch = {key: data[key] for key in ("image0", "image1")}
        self.net(batch, online_resize=True)

        kpts0 = batch["mkpts0_f"]
        kpts1 = batch["mkpts1_f"]
        confidence = batch["mconf"]

        limit = self.conf["max_keypoints"]
        if limit is not None and len(confidence) > limit:
            best = torch.argsort(confidence, descending=True)[:limit]
            kpts0, kpts1, confidence = kpts0[best], kpts1[best], confidence[best]

        return {
            "keypoints0": kpts0,
            "keypoints1": kpts1,
            "mconf": confidence,
        }