Spaces:
Running
on
Zero
Running
on
Zero
import sys | |
from pathlib import Path | |
import numpy as np | |
import PIL | |
from PIL import Image | |
import cv2 | |
import torch | |
import torch.nn.functional as F | |
import os | |
from .. import DEVICE, MODEL_REPO_ID, logger | |
from ..utils.base_model import BaseModel | |
sys.path.append(str(Path(__file__).parent / "../../third_party")) | |
sys.path.append(str(Path(__file__).parent / "../../third_party/MatchAnything")) | |
from MatchAnything.src.lightning.lightning_loftr import PL_LoFTR | |
from MatchAnything.src.config.default import get_cfg_defaults | |
class MatchAnything(BaseModel): | |
required_inputs = [ | |
"image0", | |
"image1", | |
] | |
def _init(self, conf): | |
self.conf = conf | |
config = get_cfg_defaults() | |
if conf['model_name'] == 'matchanything_eloftr': | |
config_path = str(Path(__file__).parent / "../../third_party" / 'MatchAnything' / 'configs/models/eloftr_model.py') | |
config.merge_from_file(config_path) | |
# Config overwrite: | |
if config.LOFTR.COARSE.ROPE: | |
assert config.DATASET.NPE_NAME is not None | |
if config.DATASET.NPE_NAME is not None: | |
if config.DATASET.NPE_NAME == 'megadepth': | |
config.LOFTR.COARSE.NPE = [832, 832, conf['img_resize'], conf['img_resize']] | |
elif conf['model_name'] == 'matchanything_roma': | |
config_path = str(Path(__file__).parent / "../../third_party" / 'MatchAnything' / 'configs/models/roma_model.py') | |
config.merge_from_file(config_path) | |
print(f"*****************{DEVICE}, {str(DEVICE) == 'cpu'}**************************") | |
if str(DEVICE) == 'cpu': | |
config.LOFTR.FP16 = False | |
config.ROMA.MODEL.AMP = False | |
else: | |
raise NotImplementedError | |
config.METHOD = conf['model_name'] | |
config.LOFTR.MATCH_COARSE.THR = conf["match_threshold"] | |
model_path = Path(__file__).parent / "../../third_party" / 'MatchAnything'/ 'weights' / "{}.ckpt".format(conf["model_name"]) | |
self.net = PL_LoFTR(config, pretrained_ckpt=model_path, test_mode=True).matcher | |
self.net.eval().to(DEVICE) | |
logger.info(f"Loading {conf['model_name']} model done") | |
def _forward(self, data): | |
img0 = data["image0"].cpu().numpy().squeeze() * 255 | |
img1 = data["image1"].cpu().numpy().squeeze() * 255 | |
img0 = img0.transpose(1, 2, 0) | |
img1 = img1.transpose(1, 2, 0) | |
# Get original images: | |
img0, img1 = img0.astype("uint8"), img1.astype("uint8") | |
img0_size, img1_size = np.array(img0.shape[:2]), np.array(img1.shape[:2]) | |
img0_gray, img1_gray = np.array(Image.fromarray(img0).convert("L")), np.array(Image.fromarray(img1).convert("L")) | |
(img0_gray, hw0_new, mask0), (img1_gray, hw1_new, mask1)= map(lambda x: resize(x, df=32), [img0_gray, img1_gray]) | |
img0 = torch.from_numpy(img0_gray)[None][None] / 255. | |
img1 = torch.from_numpy(img1_gray)[None][None] / 255. | |
batch = {'image0': img0, 'image1': img1} | |
batch.update({'image0_rgb_origin': data['image0'], 'image1_rgb_origin': data['image1'], 'origin_img_size0': torch.from_numpy(img0_size)[None], 'origin_img_size1': torch.from_numpy(img1_size)[None]}) | |
if mask0 is not None: | |
mask0 = torch.from_numpy(mask0).to(DEVICE) | |
mask1 = torch.from_numpy(mask1).to(DEVICE) | |
[ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), | |
scale_factor=0.125, | |
mode='nearest', | |
recompute_scale_factor=False)[0].bool() | |
batch.update({"mask0": ts_mask_0[None], "mask1": ts_mask_1[None]}) | |
batch = dict_to_cuda(batch, device=DEVICE) | |
self.net(batch) | |
mkpts0 = batch['mkpts0_f'].cpu() | |
mkpts1 = batch['mkpts1_f'].cpu() | |
mconf = batch['mconf'].cpu() | |
if self.conf['model_name'] == 'matchanything_eloftr': | |
mkpts0 *= torch.tensor(hw0_new)[[1,0]] | |
mkpts1 *= torch.tensor(hw1_new)[[1,0]] | |
pred = { | |
"keypoints0": mkpts0, | |
"keypoints1": mkpts1, | |
"mconf": mconf, | |
} | |
return pred | |
def resize(img, resize=None, df=8, padding=True): | |
w, h = img.shape[1], img.shape[0] | |
w_new, h_new = process_resize(w, h, resize=resize, df=df, resize_no_larger_than=False) | |
img_new = resize_image(img, (w_new, h_new), interp="pil_LANCZOS").astype('float32') | |
h_scale, w_scale = img.shape[0] / img_new.shape[0], img.shape[1] / img_new.shape[1] | |
mask = None | |
if padding: | |
img_new, mask = pad_bottom_right(img_new, max(h_new, w_new), ret_mask=True) | |
return img_new, [h_scale, w_scale], mask | |
def process_resize(w, h, resize=None, df=None, resize_no_larger_than=False): | |
if resize is not None: | |
assert(len(resize) > 0 and len(resize) <= 2) | |
if resize_no_larger_than and (max(h, w) <= max(resize)): | |
w_new, h_new = w, h | |
else: | |
if len(resize) == 1 and resize[0] > -1: # resize the larger side | |
scale = resize[0] / max(h, w) | |
w_new, h_new = int(round(w*scale)), int(round(h*scale)) | |
elif len(resize) == 1 and resize[0] == -1: | |
w_new, h_new = w, h | |
else: # len(resize) == 2: | |
w_new, h_new = resize[0], resize[1] | |
else: | |
w_new, h_new = w, h | |
if df is not None: | |
w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new]) | |
return w_new, h_new | |
def resize_image(image, size, interp): | |
if interp.startswith('cv2_'): | |
interp = getattr(cv2, 'INTER_'+interp[len('cv2_'):].upper()) | |
h, w = image.shape[:2] | |
if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]): | |
interp = cv2.INTER_LINEAR | |
resized = cv2.resize(image, size, interpolation=interp) | |
elif interp.startswith('pil_'): | |
interp = getattr(PIL.Image, interp[len('pil_'):].upper()) | |
resized = PIL.Image.fromarray(image.astype(np.uint8)) | |
resized = resized.resize(size, resample=interp) | |
resized = np.asarray(resized, dtype=image.dtype) | |
else: | |
raise ValueError( | |
f'Unknown interpolation {interp}.') | |
return resized | |
def pad_bottom_right(inp, pad_size, ret_mask=False): | |
assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" | |
mask = None | |
if inp.ndim == 2: | |
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) | |
padded[:inp.shape[0], :inp.shape[1]] = inp | |
if ret_mask: | |
mask = np.zeros((pad_size, pad_size), dtype=bool) | |
mask[:inp.shape[0], :inp.shape[1]] = True | |
elif inp.ndim == 3: | |
padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) | |
padded[:, :inp.shape[1], :inp.shape[2]] = inp | |
if ret_mask: | |
mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) | |
mask[:, :inp.shape[1], :inp.shape[2]] = True | |
mask = mask[0] | |
else: | |
raise NotImplementedError() | |
return padded, mask | |
def dict_to_cuda(data_dict, device='cuda'): | |
data_dict_cuda = {} | |
for k, v in data_dict.items(): | |
if isinstance(v, torch.Tensor): | |
data_dict_cuda[k] = v.to(device) | |
elif isinstance(v, dict): | |
data_dict_cuda[k] = dict_to_cuda(v, device=device) | |
elif isinstance(v, list): | |
data_dict_cuda[k] = list_to_cuda(v, device=device) | |
else: | |
data_dict_cuda[k] = v | |
return data_dict_cuda | |
def list_to_cuda(data_list, device='cuda'): | |
data_list_cuda = [] | |
for obj in data_list: | |
if isinstance(obj, torch.Tensor): | |
data_list_cuda.append(obj.cuda()) | |
elif isinstance(obj, dict): | |
data_list_cuda.append(dict_to_cuda(obj, device=device)) | |
elif isinstance(obj, list): | |
data_list_cuda.append(list_to_cuda(obj, device=device)) | |
else: | |
data_list_cuda.append(obj) | |
return data_list_cuda |