File size: 3,188 Bytes
3040ac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import sys
import warnings
from copy import deepcopy
from pathlib import Path

import torch

from .. import MODEL_REPO_ID, logger

# Make the EfficientLoFTR sources shipped under third_party/ importable by
# appending that directory to the module search path (needed by the
# `EfficientLoFTR.src.loftr` imports that follow).
tp_path = Path(__file__).parent.joinpath("../../third_party")
sys.path.append(f"{tp_path}")

from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_
from EfficientLoFTR.src.loftr import (
    full_default_cfg,
    opt_default_cfg,
    reparameter,
)


from ..utils.base_model import BaseModel


class ELoFTR(BaseModel):
    """Matcher wrapper around Efficient LoFTR.

    The network is built from EfficientLoFTR's ``full`` or ``opt`` default
    config, loaded with a checkpoint downloaded from ``MODEL_REPO_ID`` and
    passed through ``reparameter`` before inference.
    """

    default_conf = {
        "model_name": "eloftr_outdoor.ckpt",
        # Written to the coarse-matching threshold cfg["match_coarse"]["thr"].
        "match_threshold": 0.2,
        # Keep at most this many matches, ranked by confidence ("mconf").
        # A negative value (the default -1) or None keeps every match.
        "max_keypoints": -1,
        # You can choose model type in ['full', 'opt']
        "model_type": "full",  # 'full' for best quality, 'opt' for best efficiency
        # You can choose numerical precision in ['fp32', 'mp', 'fp16']. 'fp16' for best efficiency
        "precision": "fp32",
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the matcher, load its weights and apply the precision setting.

        Args:
            conf: the model configuration; the keys of ``default_conf`` are read.

        Raises:
            ValueError: if ``model_type`` is neither ``"full"`` nor ``"opt"``.
        """
        model_type = self.conf["model_type"]
        if model_type == "full":
            cfg = deepcopy(full_default_cfg)
        elif model_type == "opt":
            cfg = deepcopy(opt_default_cfg)
        else:
            # Without this branch the config was left unbound and the call
            # failed later with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected 'full' or 'opt'."
            )

        precision = self.conf["precision"]
        if precision == "mp":
            cfg["mp"] = True
        elif precision == "fp16":
            cfg["half"] = True
        elif precision != "fp32":
            # Unknown values have always behaved like fp32; keep that, but
            # make the silent fallback visible.
            logger.warning(
                f"Unknown precision {precision!r} for Efficient LoFTR; "
                "falling back to 'fp32'."
            )

        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )

        cfg["match_coarse"]["thr"] = conf["match_threshold"]
        # NOTE(review): torch.load unpickles the checkpoint file; only load
        # checkpoints from a trusted repository.
        state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
        matcher = ELoFTR_(config=cfg)
        matcher.load_state_dict(state_dict)
        self.net = reparameter(matcher)

        if precision == "fp16":
            # NOTE(review): only the weights are cast to half here; the input
            # images are passed through unchanged in _forward — confirm the
            # EfficientLoFTR forward accepts float32 inputs in this mode.
            self.net = self.net.half()
        logger.info(f"Loaded Efficient LoFTR with weights {conf['model_name']}")

    def _forward(self, data):
        """Match ``image0`` against ``image1``.

        Args:
            data: dict with at least ``image0`` and ``image1``. The optional
                ``mask0``/``mask1`` and ``keypoints0``/``keypoints1`` entries
                are swapped together with the images; any other key is passed
                to the network unchanged.

        Returns:
            dict with ``keypoints0`` / ``keypoints1`` (row-aligned matched
            points in image0 / image1) and ``scores`` (the matcher's
            ``mconf``), optionally truncated to the ``max_keypoints`` most
            confident matches.
        """
        # For consistency with hloc pairs, we refine kpts in image0!
        # The pair is therefore fed to the network swapped, and the outputs
        # are swapped back before returning.
        rename = {
            "keypoints0": "keypoints1",
            "keypoints1": "keypoints0",
            "image0": "image1",
            "image1": "image0",
            "mask0": "mask1",
            "mask1": "mask0",
        }
        data_ = {rename.get(k, k): v for k, v in data.items()}
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # The network writes its results (mkpts0_f, mkpts1_f, mconf) into
            # data_ in place; they are read back from there below.
            self.net(data_)

        # "net" side 0 is the caller's image1 and vice versa.
        net_kpts0 = data_["mkpts0_f"]
        net_kpts1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        # Negative (default -1) or None means "keep all". The former check
        # `len(scores) > top_k` was always true for -1 and sliced [:-1],
        # silently dropping the least confident match.
        if top_k is not None and 0 <= top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            net_kpts0 = net_kpts0[keep]
            net_kpts1 = net_kpts1[keep]
            scores = scores[keep]

        # Switch back indices to the caller's image order.
        return {
            "keypoints0": net_kpts1,
            "keypoints1": net_kpts0,
            "scores": scores,
        }