import re

import numpy as np
import pandas as pd
import streamlit as st
import torch
from st_clickable_images import clickable_images
from transformers import CLIPModel, CLIPProcessor, FlavaModel, FlavaProcessor

MODEL_NAMES = [
    "flava-full",
    "vit-base-patch32",
    "vit-base-patch16",
    "vit-large-patch14",
    "vit-large-patch14-336",
]


@st.cache(allow_output_mutation=True)
def load():
    """Load image metadata, models, processors, and precomputed image
    embeddings, L2-normalized once so dot products are cosine similarities."""
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models = {}
    processors = {}
    embeddings = {}
    for name in MODEL_NAMES:
        if "flava" in name:
            model_cls, processor_cls, prefix = FlavaModel, FlavaProcessor, "facebook/"
        else:
            model_cls, processor_cls, prefix = CLIPModel, CLIPProcessor, "openai/clip-"
        models[name] = model_cls.from_pretrained(f"{prefix}{name}")
        models[name].eval()
        processors[name] = processor_cls.from_pretrained(f"{prefix}{name}")
        embeddings[name] = {
            0: np.load(f"embeddings-{name}.npy"),
            1: np.load(f"embeddings2-{name}.npy"),
        }
        for k in [0, 1]:
            embeddings[name][k] = embeddings[name][k] / np.linalg.norm(
                embeddings[name][k], axis=1, keepdims=True
            )
    return models, processors, df, embeddings


models, processors, df, embeddings = load()
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}


def compute_text_embeddings(list_of_strings, name):
    """Return L2-normalized text embeddings for a list of strings."""
    inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
    with torch.no_grad():
        result = models[name].get_text_features(**inputs)
    if "flava" in name:
        # FLAVA returns per-token hidden states; keep the [CLS] embedding.
        result = result[:, 0, :]
    result = result.detach().numpy()
    return result / np.linalg.norm(result, axis=1, keepdims=True)


def image_search(query, corpus, name, n_results=24):
    positive_embeddings = None

    def concatenate_embeddings(e1, e2):
        if e1 is None:
            return e2
        return np.concatenate((e1, e2), axis=0)

    # Everything after "EXCLUDING " is treated as negative queries.
    split_query = query.split("EXCLUDING ")
    dot_product = 0
    k = 0 if corpus == "Unsplash" else 1
    if len(split_query[0]) > 0:
        # Positive queries are separated by ";" and may reference an image
        # from either corpus via the "[Corpus:index]" syntax.
        positive_queries = split_query[0].split(";")
        for positive_query in positive_queries:
            # Strip so image references still match after a ";" split.
            positive_query = positive_query.strip()
            match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
            if match:
                corpus2, idx, remainder = match.groups()
                idx, remainder = int(idx), remainder.strip()
                k2 = 0 if corpus2 == "Unsplash" else 1
                positive_embeddings = concatenate_embeddings(
                    positive_embeddings, embeddings[name][k2][idx : idx + 1, :]
                )
                if len(remainder) > 0:
                    positive_embeddings = concatenate_embeddings(
                        positive_embeddings,
                        compute_text_embeddings([remainder], name),
                    )
            else:
                positive_embeddings = concatenate_embeddings(
                    positive_embeddings,
                    compute_text_embeddings([positive_query], name),
                )
        # Center each query's scores on their median, rescale by their
        # maximum, then rank images by their worst score across all
        # positive queries.
        dot_product = embeddings[name][k] @ positive_embeddings.T
        dot_product = dot_product - np.median(dot_product, axis=0)
        dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
        dot_product = np.min(dot_product, axis=1)

    if len(split_query) > 1:
        negative_queries = (" ".join(split_query[1:])).split(";")
        negative_embeddings = compute_text_embeddings(negative_queries, name)
        dot_product2 = embeddings[name][k] @ negative_embeddings.T
        dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
        dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
        # Penalize images that match any of the negative queries.
        dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)

    results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
    return [
        (df[k].iloc[i]["path"], df[k].iloc[i]["tooltip"] + source[k], i)
        for i in results
    ]


description = """
# FLAVA Semantic Image-Text Search
"""

instruction = """
### **Enter your query and hit enter**

**Things to try:** compare with other models, or search for "a field in the countryside EXCLUDING green"
"""

credit = """
*Built with FAIR's [FLAVA](https://arxiv.org/abs/2112.04482) models, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Forked from, and inspired by, a similar app available [here](https://huggingface.co/spaces/vivien/clip/)*
"""

options = """
## Compare

Check results for a single model, or compare two models using the dropdown below:
"""

howto = """
## Advanced Use

- Click on an image to use it as a query and find similar images.
- Several queries, including one based on an image, can be combined (use "**;**" as a separator):
  - Try "a person walking on a grass field; red flowers".
- If the input includes "**EXCLUDING**", the text following it is used as a negative query:
  - Try "a field in the countryside which is green" and "a field in the countryside EXCLUDING green".
"""

div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}


def main():
    # Page-level CSS/HTML hook (the injected content is currently empty).
    st.markdown(" ", unsafe_allow_html=True)
    st.sidebar.markdown(description)
    st.sidebar.markdown(options)
    mode = st.sidebar.selectbox(
        "", ["Results for FLAVA full", "Comparison of 2 models"], index=0
    )
    st.sidebar.markdown(howto)
    st.sidebar.markdown(credit)

    _, c, _ = st.columns((1, 3, 1))
    c.markdown(instruction)
    if "query" in st.session_state:
        query = c.text_input("", value=st.session_state["query"])
    else:
        query = c.text_input("", value="a field in the countryside which is green")
    corpus = st.radio("", ["Unsplash", "Movies"])

    models_dict = {
        "FLAVA": "flava-full",
        "ViT-B/32 (quickest)": "vit-base-patch32",
        "ViT-B/16 (quick)": "vit-base-patch16",
        "ViT-L/14 (slow)": "vit-large-patch14",
        "ViT-L/14@336px (slowest)": "vit-large-patch14-336",
    }
    if "Comparison" in mode:
        c1, c2 = st.columns((1, 1))
        selection1 = c1.selectbox("", models_dict.keys(), index=0)
        selection2 = c2.selectbox("", models_dict.keys(), index=3)
        name1 = models_dict[selection1]
        name2 = models_dict[selection2]
    else:
        name1 = MODEL_NAMES[0]

    if len(query) > 0:
        results1 = image_search(query, corpus, name1)
        if "Comparison" in mode:
            with c1:
                clicked1 = clickable_images(
                    [result[0] for result in results1],
                    titles=[result[1] for result in results1],
                    div_style=div_style,
                    img_style={"margin": "2px", "height": "150px"},
                    key=query + corpus + name1 + "1",
                )
            results2 = image_search(query, corpus, name2)
            with c2:
                clicked2 = clickable_images(
                    [result[0] for result in results2],
                    titles=[result[1] for result in results2],
                    div_style=div_style,
                    img_style={"margin": "2px", "height": "150px"},
                    key=query + corpus + name2 + "2",
                )
        else:
            clicked1 = clickable_images(
                [result[0] for result in results1],
                titles=[result[1] for result in results1],
                div_style=div_style,
                img_style={"margin": "2px", "height": "200px"},
                key=query + corpus + name1 + "1",
            )
            clicked2 = -1

        if clicked2 >= 0 or clicked1 >= 0:
            change_query = False
            if "last_clicked" not in st.session_state:
                change_query = True
            elif max(clicked2, clicked1) != st.session_state["last_clicked"]:
                change_query = True
            if change_query:
                # Remember the click so the same click is not handled again
                # after the rerun.
                st.session_state["last_clicked"] = max(clicked2, clicked1)
                if clicked1 >= 0:
                    st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
                elif clicked2 >= 0:
                    st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
                st.experimental_rerun()


if __name__ == "__main__":
    main()
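

# ---------------------------------------------------------------------------
# Regenerating the precomputed embeddings.
#
# The "embeddings-{name}.npy" / "embeddings2-{name}.npy" files loaded in
# load() are assumed to hold one row per image of the matching CSV. The
# helper below is a minimal sketch of how such a file could be produced; it
# is never called by the app, and the PIL/requests loading path and the batch
# size are assumptions, not part of the original code.
def precompute_image_embeddings(name, paths, batch_size=16):
    """Embed the images in `paths` (local files or URLs) with model `name`.

    Returns an array of shape (len(paths), embedding_dim). Rows are left
    unnormalized because load() normalizes them after np.load().
    """
    import requests
    from PIL import Image

    chunks = []
    for start in range(0, len(paths), batch_size):
        batch = []
        for path in paths[start : start + batch_size]:
            src = requests.get(path, stream=True).raw if path.startswith("http") else path
            batch.append(Image.open(src).convert("RGB"))
        inputs = processors[name](images=batch, return_tensors="pt")
        with torch.no_grad():
            features = models[name].get_image_features(**inputs)
        if "flava" in name:
            # As in compute_text_embeddings: FLAVA returns per-patch hidden
            # states, so keep the [CLS] embedding.
            features = features[:, 0, :]
        chunks.append(features.detach().numpy())
    return np.concatenate(chunks, axis=0)


# Example (hypothetical): rebuild the Unsplash embeddings for one model with
#   np.save(f"embeddings-{name}.npy",
#           precompute_image_embeddings(name, df[0]["path"].tolist()))
#
# Usage note: the app expects data.csv / data2.csv (with "path" and "tooltip"
# columns) and the *.npy embedding files next to this script, and is launched
# with `streamlit run <script>`.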