Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,350 Bytes
3040ac4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
from torchvision.transforms import ToPILImage
from .. import DEVICE, MODEL_REPO_ID
from ..utils.base_model import BaseModel
# Put the vendored COTR checkout on sys.path; this must run before the
# `COTR.*` imports below, which is why they are not grouped with the other imports.
sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
from COTR.inference.sparse_engine import SparseEngine
from COTR.models import build_model
# NOTE(review): `set_COTR_arguments`, used by the class below, is not imported by
# name, so it is expected to come from one of these wildcard imports — confirm.
from COTR.options.options import * # noqa: F403
from COTR.options.options_utils import * # noqa: F403
from COTR.utils import utils as utils_cotr
# Import-time side effect. NOTE(review): presumably seeds the random number
# generators with 0 for reproducible output — confirm in COTR.utils.utils.
utils_cotr.fix_randomness(0)
# Import-time side effect: switches autograd gradient tracking off, which affects
# torch code that runs after this module is imported, not only this wrapper.
torch.set_grad_enabled(False)
class COTR(BaseModel):
    """Matcher wrapper that runs the third-party COTR network on an image pair.

    The network is built from COTR's own argparse options, its checkpoint is
    fetched through ``self._download_model`` and correspondences are produced
    by COTR's ``SparseEngine``.
    """

    # NOTE(review): "weights" and "match_threshold" are not read anywhere in
    # this class; they may be consumed by BaseModel or kept for config
    # compatibility — confirm before removing.
    default_conf = {
        "weights": "out/default",
        "match_threshold": 0.2,
        "max_keypoints": -1,
        "model_name": "checkpoint.pth.tar",
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the COTR model, load its checkpoint and switch it to eval mode.

        Args:
            conf: Configuration passed in by the base class. This method reads
                ``self.conf`` rather than ``conf`` directly.
        """
        parser = argparse.ArgumentParser()
        set_COTR_arguments(parser)  # noqa: F405
        # COTR's options are defined as command-line arguments, so the parser
        # reads sys.argv. This model is normally constructed inside a host
        # application that has its own flags; ``parse_args()`` would abort with
        # SystemExit("unrecognized arguments") on those. ``parse_known_args()``
        # still honours any COTR flag present but ignores the rest.
        opt, _unrecognized = parser.parse_known_args()
        opt.command = " ".join(sys.argv)
        # The checkpoint lives under a folder named after this module file.
        opt.load_weights_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )

        # Value assigned to opt.dim_feedforward for each possible opt.layer.
        layer_2_channels = {
            "layer1": 256,
            "layer2": 512,
            "layer3": 1024,
            "layer4": 2048,
        }
        opt.dim_feedforward = layer_2_channels[opt.layer]

        model = build_model(opt)
        model = model.to(DEVICE)
        # Deserialize on CPU first so loading does not depend on the device
        # the checkpoint was saved from.
        weights = torch.load(opt.load_weights_path, map_location="cpu")[
            "model_state_dict"
        ]
        utils_cotr.safe_load_weights(model, weights)
        self.net = model.eval()
        self.to_pil_func = ToPILImage(mode="RGB")

    def _forward(self, data):
        """Compute correspondences between ``data["image0"]`` and ``data["image1"]``.

        Only index 0 along the leading dimension of each input is used.

        Args:
            data: Mapping holding the ``"image0"`` and ``"image1"`` tensors.

        Returns:
            Dict with ``"keypoints0"`` (the first two columns of the
            correspondence array) and ``"keypoints1"`` (the remaining columns)
            as tensors. NOTE(review): presumably the array is (N, 4) holding
            x/y in image0 followed by x/y in image1 — confirm against COTR.
        """
        # SparseEngine is fed numpy arrays, so go tensor -> PIL (RGB) -> ndarray.
        img_a = np.array(self.to_pil_func(data["image0"][0].cpu()))
        img_b = np.array(self.to_pil_func(data["image1"][0].cpu()))
        corrs = SparseEngine(
            self.net, 32, mode="tile"
        ).cotr_corr_multiscale_with_cycle_consistency(
            img_a,
            img_b,
            # Four evenly spaced values from 0.5 down to 0.0625.
            # NOTE(review): presumably the zoom-in schedule — confirm.
            np.linspace(0.5, 0.0625, 4),
            1,
            max_corrs=self.conf["max_keypoints"],
            queries_a=None,
        )
        pred = {
            "keypoints0": torch.from_numpy(corrs[:, :2]),
            "keypoints1": torch.from_numpy(corrs[:, 2:]),
        }
        return pred
|