# Uploaded by XingyiHe in commit 3040ac4 ("init commit"); file size 1.82 kB.
import copy
import warnings

import torch
from kornia.feature import LoFTR as LoFTR_
from kornia.feature.loftr.loftr import default_cfg

from .. import logger
from ..utils.base_model import BaseModel
class LoFTR(BaseModel):
    """Dense matcher wrapping kornia's LoFTR implementation.

    Configuration keys (see ``default_conf``):
        weights: name of the pretrained kornia weights, passed as
            ``kornia.feature.LoFTR(pretrained=...)``.
        match_threshold: written to kornia's ``match_coarse.thr``.
        sinkhorn_iterations: written to kornia's ``match_coarse.skh_iters``.
        max_keypoints: keep at most this many matches, ranked by confidence.
            ``None`` or any non-positive value (the default is -1) keeps all
            matches.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the kornia LoFTR network configured from ``conf``."""
        # Work on a private copy: ``default_cfg`` is kornia's module-level
        # dict, and editing it in place would leak these settings into every
        # other LoFTR instance (and any other user of kornia's defaults).
        cfg = copy.deepcopy(default_cfg)
        cfg["match_coarse"]["thr"] = conf["match_threshold"]
        cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
        self.net = LoFTR_(pretrained=conf["weights"], config=cfg)
        logger.info(f"Loaded LoFTR with weights {conf['weights']}")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: input dict. Keys listed in ``rename`` below are swapped
                between the two images before calling kornia. Any other key
                is passed through unchanged.

        Returns:
            The kornia prediction dict with image-0/1 keys swapped back. The
            ``"confidence"`` entry is moved to ``"scores"``. Matches are
            optionally truncated to the ``max_keypoints`` most confident.
        """
        # For consistency with hloc pairs, we refine kpts in image0!
        # The inputs are swapped here and the outputs are swapped back below.
        rename = {
            "keypoints0": "keypoints1",
            "keypoints1": "keypoints0",
            "image0": "image1",
            "image1": "image0",
            "mask0": "mask1",
            "mask1": "mask0",
        }
        data_ = {rename.get(k, k): v for k, v in data.items()}
        with warnings.catch_warnings():
            # Deliberately silence any warnings raised inside kornia's forward.
            warnings.simplefilter("ignore")
            pred = self.net(data_)

        scores = pred["confidence"]
        top_k = self.conf["max_keypoints"]
        # None or a non-positive value means "keep everything". Without the
        # ``top_k > 0`` guard the default of -1 would slice ``[:-1]`` and
        # silently drop the least confident match.
        if top_k is not None and 0 < top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            pred["keypoints0"] = pred["keypoints0"][keep]
            pred["keypoints1"] = pred["keypoints1"][keep]
            # Keep kornia's per-match batch index aligned with the keypoints.
            if "batch_indexes" in pred:
                pred["batch_indexes"] = pred["batch_indexes"][keep]
            scores = scores[keep]

        # Switch back indices
        pred = {rename.get(k, k): v for k, v in pred.items()}
        pred["scores"] = scores
        del pred["confidence"]
        return pred