import os
import sys
from pathlib import Path
from zipfile import ZipFile

import gdown
import sklearn
import torch

from ..utils.base_model import BaseModel

sys.path.append(str(Path(__file__).parent / "../../third_party/deep-image-retrieval"))
os.environ["DB_ROOT"] = ""  # required by dirtorch

from dirtorch.extract_features import load_model  # noqa: E402
from dirtorch.utils import common  # noqa: E402

# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
# which has been deprecated in sklearn v0.24
# and must be explicitly imported with `from sklearn.decomposition import PCA`.
# This is a hacky workaround to maintain forward compatibility.
sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca


class DIR(BaseModel):
    """Global image descriptor computed with a dirtorch (DIR) network.

    The checkpoint is cached under ``<torch hub dir>/dirtorch`` and downloaded
    on first use. ``_forward`` returns the descriptor under the
    ``"global_descriptor"`` key, optionally whitened with one of the PCA
    models shipped inside the checkpoint.
    """

    # NOTE(review): "pooling" and "gemp" are not read anywhere in this class;
    # presumably they are consumed elsewhere or kept for reference -- confirm.
    default_conf = {
        "model_name": "Resnet-101-AP-GeM",
        "whiten_name": "Landmarks_clean",
        "whiten_params": {
            "whitenp": 0.25,
            "whitenv": None,
            "whitenm": 1.0,
        },
        "pooling": "gem",
        "gemp": 3,
    }
    required_inputs = ["image"]

    # Maps a model name to the URL of its zipped checkpoint.
    dir_models = {
        "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
    }

    def _init(self, conf):
        """Fetch the checkpoint if it is not cached yet and load the network.

        Args:
            conf: configuration mapping; reads ``"model_name"`` and
                ``"whiten_name"``.

        Raises:
            KeyError: if ``conf["model_name"]`` has no entry in ``dir_models``
                and the checkpoint is not already cached.
            AssertionError: if a whitening name is requested that the loaded
                checkpoint does not provide.
        """
        # todo: download from google drive -> huggingface models
        checkpoint = Path(torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt")
        if not checkpoint.exists():
            checkpoint.parent.mkdir(exist_ok=True, parents=True)
            link = self.dir_models[conf["model_name"]]
            archive = str(checkpoint) + ".zip"
            try:
                gdown.download(str(link), archive, quiet=False)
                # The context manager closes the archive even when extraction
                # fails, so no file handle is leaked.
                with ZipFile(archive, "r") as zf:
                    zf.extractall(checkpoint.parent)
            finally:
                # Never leave the (possibly partial) archive in the cache;
                # guard the removal so a failed download is not masked by a
                # FileNotFoundError raised from the cleanup itself.
                if os.path.exists(archive):
                    os.remove(archive)

        self.net = load_model(checkpoint, False)  # first load on CPU
        if conf["whiten_name"]:
            assert conf["whiten_name"] in self.net.pca, (
                "Unknown whitening '{}' for this checkpoint.".format(
                    conf["whiten_name"]
                )
            )

    def _forward(self, data):
        """Compute the global descriptor of ``data["image"]``.

        The image is normalized with the mean/std stored in the network's
        ``preprocess`` dict before being passed through the network.

        Args:
            data: mapping with an ``"image"`` tensor whose dimension 1 has
                size 3.

        Returns:
            A dict with a single key ``"global_descriptor"``. When whitening
            is enabled the descriptor goes through numpy and is therefore
            returned as a CPU tensor.
        """
        image = data["image"]
        assert image.shape[1] == 3, "Expected an image with 3 channels at dim 1."

        # Broadcast the per-channel statistics over the trailing two dims.
        mean = self.net.preprocess["mean"]
        std = self.net.preprocess["std"]
        image = image - image.new_tensor(mean)[:, None, None]
        image = image / image.new_tensor(std)[:, None, None]

        desc = self.net(image)
        desc = desc.unsqueeze(0)  # batch dimension
        if self.conf["whiten_name"]:
            pca = self.net.pca[self.conf["whiten_name"]]
            # Whitening is implemented on numpy arrays in dirtorch.
            desc = common.whiten_features(
                desc.cpu().numpy(), pca, **self.conf["whiten_params"]
            )
            desc = torch.from_numpy(desc)
        return {
            "global_descriptor": desc,
        }