# Spaces: Running on Zero
import sys | |
from pathlib import Path | |
import torch | |
from .. import DEVICE, MODEL_REPO_ID, logger | |
from ..utils.base_model import BaseModel | |
# Make the vendored packages under ../../third_party (e.g. ``pram``)
# importable by appending that directory to the module search path.
tp_path = Path(__file__).parent.joinpath("../../third_party")
sys.path.append(f"{tp_path}")
from pram.nets.gml import GML | |
class IMP(BaseModel):
    """IMP matcher built on the ``GML`` network from the vendored ``pram`` package.

    Matches the local features (keypoints, scores, descriptors) of
    ``image0`` against those of ``image1``. Pretrained weights are
    downloaded from ``MODEL_REPO_ID`` under ``pram/<model_name>``.
    """

    default_conf = {
        # Match-score threshold, forwarded to GML.produce_matches as ``p``.
        "match_threshold": 0.2,
        # NOTE(review): presumably names the local-feature extractor the
        # weights were trained with; not read by the code in this class.
        "features": "sfd2",
        # Checkpoint file name inside the ``pram`` folder of MODEL_REPO_ID.
        "model_name": "imp_gml.920.pth",
        # Handed to GML via the config dict; its use is not visible here.
        "sinkhorn_iterations": 20,
    }
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Build the GML network and load its pretrained weights.

        Args:
            conf: configuration overrides, merged on top of ``default_conf``.
        """
        self.conf = {**self.default_conf, **conf}
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename=f"pram/{self.conf['model_name']}",
        )
        self.net = GML(self.conf).eval().to(DEVICE)
        # Deserialize onto the CPU so loading does not depend on the device
        # the checkpoint was saved from; load_state_dict then copies the
        # tensors into the parameters that already live on DEVICE.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.net.load_state_dict(checkpoint["model"], strict=True)
        logger.info("Load IMP model done.")

    def _forward(self, data):
        """Match the features of image0 against those of image1.

        Args:
            data: dict holding every key listed in ``required_inputs``.

        Returns:
            The result of ``GML.produce_matches`` for these inputs.
        """
        # Work on a shallow copy so the caller's dict keeps its original
        # descriptor tensors; mutating it in place would transpose them a
        # second time if the same dict were passed in again.
        inputs = dict(data)
        # Swap dims 1 and 2 and cast to float32. NOTE(review): presumably
        # converts (B, N, D) descriptors into the (B, D, N) layout GML
        # expects — confirm against pram.nets.gml.
        for key in ("descriptors0", "descriptors1"):
            inputs[key] = data[key].transpose(2, 1).float()
        # Honor the configured threshold instead of a hard-coded 0.2
        # (the default value is the same, so default behavior is unchanged).
        return self.net.produce_matches(inputs, p=self.conf["match_threshold"])