import streamlit as st
import torch
import yaml
from easydict import EasyDict

from lib.IRRA.image import prepare_images
from lib.IRRA.model.build import build_model, IRRA
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer


@st.cache_resource
def get_model():
    """Build the model from ``model/configs.yaml``.

    The YAML configuration is loaded, wrapped in an ``EasyDict``, and its
    ``training`` entry is forced to ``False`` before being handed to
    ``build_model``. The function is decorated with ``st.cache_resource``,
    so Streamlit reuses the returned object instead of rebuilding it.

    Returns:
        Whatever ``build_model`` returns for the loaded configuration.
    """
    # Use a context manager so the config file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open('model/configs.yaml') as config_file:
        args = yaml.load(config_file, Loader=yaml.FullLoader)

    args = EasyDict(args)
    args['training'] = False

    model = build_model(args)
    return model


def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Score one text query against a collection of images.

    Args:
        text: The query text; it is tokenized with ``SimpleTokenizer``.
        images: Image references passed unchanged to ``prepare_images``
            (presumably file paths -- confirm against ``prepare_images``).
        model: Model providing ``encode_image`` and ``encode_text``.

    Returns:
        The matrix product of the text features with the transposed image
        features. Presumably one row (the single query) with one column per
        image -- confirm against the encoders' output shapes.
    """
    tokenizer = SimpleTokenizer()
    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)

    # Inference only: disabling autograd avoids building and retaining a
    # gradient graph for every query.
    with torch.no_grad():
        image_feats = model.encode_image(imgs)
        # unsqueeze(0) adds a leading dimension to the token tensor before
        # it is encoded.
        text_feats = model.encode_text(txt.unsqueeze(0))
        return text_feats @ image_feats.t()