from functools import lru_cache

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class SBert:
    """Callable wrapper around a ``SentenceTransformer`` with memoized encoding.

    Calling an instance forwards the argument to ``self.model.encode`` and
    remembers the result for hashable arguments, so encoding the same input
    twice only runs the model once.

    The memo is kept per instance (not on the class): it is bounded to
    ``_CACHE_SIZE`` entries for each instance and does not keep other
    instances alive.
    """

    # Maximum number of distinct inputs memoized per instance.
    _CACHE_SIZE = 10000

    def __init__(self, path):
        """Load the model found at ``path`` onto ``DEVICE``.

        Args:
            path: Passed unchanged as the first argument to
                ``SentenceTransformer``.
        """
        print(f'Loading model from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)

    def __getstate__(self):
        """Return the instance state without the memoization wrapper.

        The wrapper is bound to this particular instance, so it must not be
        pickled or carried over to copies; it is rebuilt lazily on first use.
        """
        state = self.__dict__.copy()
        state.pop('_cached_encode', None)
        return state

    def _encode(self, x):
        """Encode ``x`` with the current ``self.model``, bypassing the memo."""
        return self.model.encode(x)

    def _get_cache(self):
        """Return this instance's memoized encoder, creating it on first use."""
        cached = self.__dict__.get('_cached_encode')
        if cached is None:
            # Wrapping the bound method (rather than decorating the method on
            # the class) scopes the cache to this one instance.
            cached = lru_cache(maxsize=self._CACHE_SIZE)(self._encode)
            self._cached_encode = cached
        return cached

    def __call__(self, x) -> np.ndarray:
        """Encode ``x`` and return the result of ``self.model.encode(x)``.

        Args:
            x: Input handed to ``self.model.encode``. Hashable inputs are
                memoized; unhashable inputs (for example a ``list``) are
                encoded directly on every call instead of raising
                ``TypeError``.

        Returns:
            Whatever ``self.model.encode(x)`` returns. For memoized inputs the
            very same object is returned on repeated calls, so callers should
            not modify it in place.
        """
        try:
            hash(x)
        except TypeError:
            # Unhashable inputs cannot be used as cache keys; skip the memo.
            return self._encode(x)
        return self._get_cache()(x)