import re

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import CLIPProcessor, CLIPModel, FlavaModel, FlavaProcessor
from st_clickable_images import clickable_images
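# Model identifiers; load() prefixes them with "openai/clip-" or "facebook/"
# to form Hugging Face Hub ids such as "openai/clip-vit-base-patch32" and
# "facebook/flava-full".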
MODEL_NAMES = ["flava-full", "vit-base-patch32", "vit-base-patch16", "vit-large-patch14", "vit-large-patch14-336"]
@st.cache(allow_output_mutation=True)
def load():
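    """Load image metadata, models, processors, and pre-computed image embeddings."""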
df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
models = {}
processors = {}
embeddings = {}
for name in MODEL_NAMES:
if "flava" not in name:
model = CLIPModel
processor = CLIPProcessor
prefix = "openai/clip-"
else:
model = FlavaModel
processor = FlavaProcessor
prefix = "facebook/"
models[name] = model.from_pretrained(f"{prefix}{name}")
models[name].eval()
processors[name] = processor.from_pretrained(f"{prefix}{name}")
embeddings[name] = {
0: np.load(f"embeddings-{name}.npy"),
1: np.load(f"embeddings2-{name}.npy"),
}
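        # L2-normalise each embedding row so that the dot products computed in
        # image_search are cosine similarities.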
for k in [0, 1]:
embeddings[name][k] = embeddings[name][k] / np.linalg.norm(
embeddings[name][k], axis=1, keepdims=True
)
return models, processors, df, embeddings
models, processors, df, embeddings = load()
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
def compute_text_embeddings(list_of_strings, name):
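    """Embed a list of strings and return unit-norm (L2-normalised) vectors."""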
inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
with torch.no_grad():
result = models[name].get_text_features(**inputs)
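        # FLAVA's get_text_features returns per-token hidden states;
        # keep only the [CLS] token as the sentence embedding.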
if "flava" in name:
result = result[:, 0, :]
result = result.detach().numpy()
return result / np.linalg.norm(result, axis=1, keepdims=True)
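# image_search understands a small query grammar, e.g. (illustrative query):
#   "red flowers; [Unsplash:1234] at night EXCLUDING people"
# Positive sub-queries are ";"-separated; "[Unsplash:idx]" or "[Movies:idx]"
# reuses the stored embedding of image idx; everything after "EXCLUDING"
# is treated as a negative query.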
def image_search(query, corpus, name, n_results=24):
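    """Return (path, tooltip, index) tuples for the top n_results matches."""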
positive_embeddings = None
def concatenate_embeddings(e1, e2):
if e1 is None:
return e2
else:
return np.concatenate((e1, e2), axis=0)
    split_query = query.split("EXCLUDING ")
dot_product = 0
k = 0 if corpus == "Unsplash" else 1
    if len(split_query[0]) > 0:
        positive_queries = split_query[0].split(";")
for positive_query in positive_queries:
match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
if match:
corpus2, idx, remainder = match.groups()
idx, remainder = int(idx), remainder.strip()
k2 = 0 if corpus2 == "Unsplash" else 1
positive_embeddings = concatenate_embeddings(
positive_embeddings, embeddings[name][k2][idx : idx + 1, :]
)
if len(remainder) > 0:
positive_embeddings = concatenate_embeddings(
positive_embeddings, compute_text_embeddings([remainder], name)
)
else:
positive_embeddings = concatenate_embeddings(
positive_embeddings, compute_text_embeddings([positive_query], name)
)
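        # Score every image against each positive query: centre the scores on
        # their median, scale by their max, then keep the per-image minimum so
        # a result has to match all positive sub-queries.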
dot_product = embeddings[name][k] @ positive_embeddings.T
dot_product = dot_product - np.median(dot_product, axis=0)
dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
dot_product = np.min(dot_product, axis=1)
    if len(split_query) > 1:
        negative_queries = (" ".join(split_query[1:])).split(";")
negative_embeddings = compute_text_embeddings(negative_queries, name)
dot_product2 = embeddings[name][k] @ negative_embeddings.T
dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
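        # Penalise images that match any negative query (negative scores are
        # clipped at zero so unrelated negatives have no effect).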
dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
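    # Indices of the n_results best scores, highest first.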
results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
return [
(
df[k].iloc[i]["path"],
df[k].iloc[i]["tooltip"] + source[k],
i,
)
for i in results
]
description = """
# FLAVA Semantic Image-Text Search
"""
instruction = """
### **Enter your query and hit enter**
**Things to try:** compare with other models or search for "a field in the countryside EXCLUDING green"
"""
credit = """
*Built with FAIR's [FLAVA](https://arxiv.org/abs/2112.04482) models, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
*Forked from and inspired by a similar app available [here](https://huggingface.co/spaces/vivien/clip/)*
"""
options = """
## Compare
Check results for a single model or compare two models by using the dropdown below:
"""
howto = """
## Advanced Use
- Click on an image to use it as a query and find similar images
- Several queries, including one based on an image, can be combined (use "**;**" as a separator).
- Try "a person walking on a grass field; red flowers".
- If the input includes "**EXCLUDING**", the text following it is used as a negative query.
- Try "a field in the countryside which is green" and "a field in the countryside EXCLUDING green".
"""
div_style = {
"display": "flex",
"justify-content": "center",
"flex-wrap": "wrap",
}
def main():
st.sidebar.markdown(description)
st.sidebar.markdown(options)
mode = st.sidebar.selectbox(
"", ["Results for FLAVA full", "Comparison of 2 models"], index=0
)
st.sidebar.markdown(howto)
st.sidebar.markdown(credit)
_, c, _ = st.columns((1, 3, 1))
c.markdown(instruction)
if "query" in st.session_state:
query = c.text_input("", value=st.session_state["query"])
else:
query = c.text_input("", value="a field in the countryside which is green")
corpus = st.radio("", ["Unsplash", "Movies"])
models_dict = {
"FLAVA": "flava-full",
"ViT-B/32 (quickest)": "vit-base-patch32",
"ViT-B/16 (quick)": "vit-base-patch16",
"ViT-L/14 (slow)": "vit-large-patch14",
"ViT-L/14@336px (slowest)": "vit-large-patch14-336",
}
if "Comparison" in mode:
c1, c2 = st.columns((1, 1))
selection1 = c1.selectbox("", models_dict.keys(), index=0)
selection2 = c2.selectbox("", models_dict.keys(), index=3)
name1 = models_dict[selection1]
name2 = models_dict[selection2]
else:
name1 = MODEL_NAMES[0]
if len(query) > 0:
results1 = image_search(query, corpus, name1)
if "Comparison" in mode:
with c1:
clicked1 = clickable_images(
[result[0] for result in results1],
titles=[result[1] for result in results1],
div_style=div_style,
img_style={"margin": "2px", "height": "150px"},
key=query + corpus + name1 + "1",
)
results2 = image_search(query, corpus, name2)
with c2:
clicked2 = clickable_images(
[result[0] for result in results2],
titles=[result[1] for result in results2],
div_style=div_style,
img_style={"margin": "2px", "height": "150px"},
key=query + corpus + name2 + "2",
)
else:
clicked1 = clickable_images(
[result[0] for result in results1],
titles=[result[1] for result in results1],
div_style=div_style,
img_style={"margin": "2px", "height": "200px"},
key=query + corpus + name1 + "1",
)
clicked2 = -1
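        # A click selects that image as the new query: store a
        # "[Corpus:index]" reference in session state and rerun.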
if clicked2 >= 0 or clicked1 >= 0:
change_query = False
if "last_clicked" not in st.session_state:
change_query = True
else:
if max(clicked2, clicked1) != st.session_state["last_clicked"]:
change_query = True
if change_query:
if clicked1 >= 0:
st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
elif clicked2 >= 0:
st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
st.experimental_rerun()
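
# Launch with Streamlit, e.g. `streamlit run app.py` (script name illustrative).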
if __name__ == "__main__":
main()