Spaces:
Running
on
Zero
Running
on
Zero
import sys | |
from pathlib import Path | |
import torchvision.transforms as tvf | |
from .. import MODEL_REPO_ID, logger | |
from ..utils.base_model import BaseModel | |
# Locations of the vendored third-party checkouts, resolved relative to the
# directory two levels above the one containing this file.
r2d2_path = Path(__file__).parents[2] / "third_party/r2d2"
gim_path = Path(__file__).parents[2] / "third_party/gim"

# Put the R2D2 checkout on the module search path; the ``extract`` import
# that follows presumably resolves against it (TODO confirm).
sys.path.append(str(r2d2_path))

# Take the GIM checkout off the search path when it is present.
# NOTE(review): presumably this stops GIM modules from shadowing R2D2's
# same-named modules - confirm against the GIM wrapper.
try:
    sys.path.remove(str(gim_path))
except ValueError:
    # The GIM path was never registered: nothing to remove.
    pass
from extract import NonMaxSuppression, extract_multiscale, load_network | |
class R2D2(BaseModel):
    """R2D2 keypoint extractor exposed through the ``BaseModel`` interface.

    The network weights are obtained with ``_download_model`` and loaded with
    ``load_network``; features are computed by ``extract_multiscale`` and the
    best-scoring ``max_keypoints`` detections are returned.
    """

    default_conf = {
        "model_name": "r2d2_WASF_N16.pt",
        # Upper bound on returned keypoints; ``None`` or a value <= 0 keeps
        # every detection.
        "max_keypoints": 5000,
        "scale_factor": 2**0.25,
        "min_size": 256,
        "max_size": 1024,
        "min_scale": 0,
        "max_scale": 1,
        "reliability_threshold": 0.7,
        # NOTE: the key is misspelled ("repetability") but it is part of the
        # configuration interface, so it must stay as it is.
        "repetability_threshold": 0.7,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        """Fetch the weights and build the network, normaliser and detector.

        Args:
            conf: Configuration mapping; the reliability and repeatability
                thresholds are read from it. The model file name is read
                from ``self.conf``.
        """
        # The remote file name is "<this module's stem>/<model_name>".
        weights_file = f"{Path(__file__).stem}/{self.conf['model_name']}"
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename=weights_file,
        )

        # Per-channel normalisation applied to the image before extraction.
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.net = load_network(model_path)
        self.detector = NonMaxSuppression(
            rel_thr=conf["reliability_threshold"],
            rep_thr=conf["repetability_threshold"],
        )
        logger.info("Load R2D2 model done.")

    def _forward(self, data):
        """Extract keypoints, descriptors and scores from ``data["image"]``.

        Args:
            data: Mapping holding the input under the ``"image"`` key.
                Assumes a tensor accepted by ``tvf.Normalize`` and
                ``extract_multiscale`` - TODO confirm the expected shape.

        Returns:
            dict with ``"keypoints"``, ``"descriptors"`` and ``"scores"``,
            each with a new leading axis prepended. Entries are ordered by
            ascending score. Descriptors are transposed relative to what
            ``extract_multiscale`` returns.
        """
        img = self.norm_rgb(data["image"])

        xys, desc, scores = extract_multiscale(
            self.net,
            img,
            self.detector,
            scale_f=self.conf["scale_factor"],
            min_size=self.conf["min_size"],
            max_size=self.conf["max_size"],
            min_scale=self.conf["min_scale"],
            max_scale=self.conf["max_scale"],
        )

        # argsort is ascending, so the tail of ``order`` holds the indices of
        # the highest-scoring detections.
        order = scores.argsort()
        limit = self.conf["max_keypoints"]
        # Only truncate for a positive limit. ``None`` used to raise a
        # TypeError (unary minus) and a negative value produced a bogus
        # ``[k:]`` slice that dropped the k weakest detections; both now mean
        # "no limit", exactly like 0 already did.
        if limit is not None and limit > 0:
            order = order[-limit:]

        # Only the first two columns of ``xys`` are kept as coordinates.
        keypoints = xys[order, :2]
        descriptors = desc[order].t()
        kept_scores = scores[order]

        return {
            "keypoints": keypoints[None],
            "descriptors": descriptors[None],
            "scores": kept_scores[None],
        }