# Scraped page header (not code): author XingyiHe; commit 3040ac4 ("init commit"); file size 3.19 kB.
import sys
import warnings
from copy import deepcopy
from pathlib import Path
import torch
from .. import MODEL_REPO_ID, logger
# Expose the vendored checkouts in <this dir>/../../third_party as top-level
# packages; this has to happen before the EfficientLoFTR imports that follow.
tp_path = Path(__file__).parent.joinpath("..", "..", "third_party")
sys.path.append(str(tp_path))
from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_
from EfficientLoFTR.src.loftr import (
full_default_cfg,
opt_default_cfg,
reparameter,
)
from ..utils.base_model import BaseModel
class ELoFTR(BaseModel):
    """Efficient LoFTR matcher wrapped as a ``BaseModel``.

    Configuration keys (see ``default_conf``):
        model_name: checkpoint file name, downloaded from ``MODEL_REPO_ID``
            under ``<this module's stem>/``.
        match_threshold: written to ``cfg["match_coarse"]["thr"]``.
        max_keypoints: keep at most this many matches (highest ``mconf``
            first); a negative value (default ``-1``) or ``None`` keeps all.
        model_type: ``"full"`` (best quality) or ``"opt"`` (best efficiency).
        precision: ``"fp32"``, ``"mp"`` or ``"fp16"``; any other value leaves
            the base config untouched.
    """

    default_conf = {
        "model_name": "eloftr_outdoor.ckpt",
        "match_threshold": 0.2,
        # "sinkhorn_iterations": 20,
        "max_keypoints": -1,
        # You can choose model type in ['full', 'opt']
        "model_type": "full",  # 'full' for best quality, 'opt' for best efficiency
        # You can choose numerical precision in ['fp32', 'mp', 'fp16']. 'fp16' for best efficiency
        "precision": "fp32",
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the network from the configuration and load its weights.

        Raises:
            ValueError: if ``model_type`` is neither ``"full"`` nor ``"opt"``.
        """
        model_type = self.conf["model_type"]
        if model_type == "full":
            cfg = deepcopy(full_default_cfg)
        elif model_type == "opt":
            cfg = deepcopy(opt_default_cfg)
        else:
            # Without this branch the config variable would be unbound and
            # fail later with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected 'full' or 'opt'."
            )

        precision = self.conf["precision"]
        if precision == "mp":
            cfg["mp"] = True
        elif precision == "fp16":
            cfg["half"] = True
        cfg["match_coarse"]["thr"] = conf["match_threshold"]
        # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]

        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )
        # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
        state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
        matcher = ELoFTR_(config=cfg)
        matcher.load_state_dict(state_dict)
        self.net = reparameter(matcher)
        if precision == "fp16":
            # NOTE(review): the weights become half precision here; confirm
            # callers pass half-precision images in this mode.
            self.net = self.net.half()
        logger.info(f"Loaded Efficient LoFTR with weights {conf['model_name']}")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Inputs are swapped before calling the network and the outputs are
        swapped back, so ``keypoints0`` always refers to ``data["image0"]``.

        Returns:
            dict with ``keypoints0``, ``keypoints1`` and ``scores`` (the
            network's ``mconf``), truncated to the ``max_keypoints``
            highest-scoring matches when that limit is non-negative.
        """
        # For consistency with hloc pairs, we refine kpts in image0!
        rename = {
            "keypoints0": "keypoints1",
            "keypoints1": "keypoints0",
            "image0": "image1",
            "image1": "image0",
            "mask0": "mask1",
            "mask1": "mask0",
        }
        # Keys outside the swap table are passed through unchanged.
        data_ = {rename.get(k, k): v for k, v in data.items()}
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # The outputs are read back from data_ below; the call's return
            # value is not used.
            self.net(data_)
        kpts0 = data_["mkpts0_f"]
        kpts1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        # A negative limit (default -1) means "keep all matches". Previously
        # -1 still reached [:-1] and silently dropped the weakest match.
        if top_k is not None and 0 <= top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            kpts0, kpts1, scores = kpts0[keep], kpts1[keep], scores[keep]

        # Switch back indices: the network's image0 is the caller's image1.
        return {"keypoints0": kpts1, "keypoints1": kpts0, "scores": scores}