Resolves dependency issues from latest streamlit
#1 by mattupson · opened
- Makefile +35 -0
- app.py +86 -47
- requirements.txt +77 -8
- unpinned_requirements.txt +10 -0
Makefile
ADDED
@@ -0,0 +1,35 @@
+#################################################################################
+# GLOBALS                                                                       #
+#################################################################################
+
+PYTHON_VERSION = python3.10
+VIRTUALENV := .venv
+
+#################################################################################
+# COMMANDS                                                                      #
+#################################################################################
+
+# Set the default location for the virtualenv to be stored
+# Create the virtualenv by installing the requirements and test requirements
+
+
+.PHONY: virtualenv
+virtualenv: requirements.txt
+	@if [ -d $(VIRTUALENV) ]; then rm -rf $(VIRTUALENV); fi
+	@mkdir -p $(VIRTUALENV)
+	$(PYTHON_VERSION) -m venv $(VIRTUALENV)
+	$(VIRTUALENV)/bin/pip install --upgrade pip
+	$(VIRTUALENV)/bin/pip install --upgrade -r requirements.txt
+	touch $@
+
+.PHONY: update-requirements-txt
+update-requirements-txt: unpinned_requirements.txt
+update-requirements-txt: VIRTUALENV := /tmp/update-requirements-virtualenv
+update-requirements-txt:
+	@if [ -d $(VIRTUALENV) ]; then rm -rf $(VIRTUALENV); fi
+	@mkdir -p $(VIRTUALENV)
+	virtualenv --python $(PYTHON_VERSION) $(VIRTUALENV)
+	$(VIRTUALENV)/bin/pip install --upgrade pip
+	$(VIRTUALENV)/bin/pip install --upgrade -r unpinned_requirements.txt
+	echo "# Created by 'make update-requirements-txt'. DO NOT EDIT!" > requirements.txt
+	$(VIRTUALENV)/bin/pip freeze | grep -v pkg_resources==0.0.0 >> requirements.txt
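Taken together, the two targets define the dependency workflow this PR relies on, as the targets themselves suggest: top-level dependencies are edited in unpinned_requirements.txt, make update-requirements-txt resolves them in a throwaway virtualenv under /tmp and freezes the result into requirements.txt, and make virtualenv rebuilds the local .venv from that pinned file.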
app.py
CHANGED
@@ -1,12 +1,14 @@
+import os
+
+import jax
+import jax.numpy as jnp
 import nmslib
 import numpy as np
 import streamlit as st
+from PIL import Image
 from transformers import AutoTokenizer, CLIPProcessor
+
 from model import FlaxHybridCLIP
-from PIL import Image
-import jax.numpy as jnp
-import os
-import jax
 
 # st.header('Under construction')
 
@@ -14,17 +16,17 @@ import jax
 st.sidebar.title("CLIP React Demo")
 
 st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
-sc= st.sidebar.columns(2)
+sc = st.sidebar.columns(2)
 
-sc[0].image("./huggingface_explode3.png",width=150)
+sc[0].image("./huggingface_explode3.png", width=150)
 sc[1].write(" ")
 sc[1].write(" ")
 sc[1].markdown("## Researching fun")
 
-with st.sidebar.expander("Motivation",expanded=True):
+with st.sidebar.expander("Motivation", expanded=True):
     st.markdown(
         """
-    Reaction GIFs became an integral part of communication.
+        Reaction GIFs became an integral part of communication.
     They convey complex emotions with many levels, in a short compact format.
 
     If a picture is worth a thousand words then a GIF is worth more.
@@ -32,37 +34,44 @@ with st.sidebar.expander("Motivation",expanded=True):
     This is just a first step in the more ambitious goal of GIF/Image generation.
     """
     )
-top_k=st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
-col_count=4
-file_names=os.listdir("./jpg")
+top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
+col_count = 4
+file_names = os.listdir("./jpg")
 file_names.sort()
 
-show_val=st.sidebar.button("show all validation set images")
+show_val = st.sidebar.button("show all validation set images")
+
 if show_val:
-    cols=st.sidebar.columns(col_count)
+    cols = st.sidebar.columns(col_count)
+
+    for i, im in enumerate(file_names):
+        j = i % col_count
+        cols[j].image("./jpg/" + im)
 
 st.write("# Search Reaction GIFs with CLIP ")
 st.write(" ")
 st.write(" ")
+
+
+@st.cache_resource()
 def load_model():
     model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    processor.tokenizer = AutoTokenizer.from_pretrained(
+        "cardiffnlp/twitter-roberta-base"
+    )
+
     return model, processor
 
+
+@st.cache_resource()
 def load_image_index():
+    index = nmslib.init(method="hnsw", space="cosinesimil")
     index.loadIndex("./features/image_embeddings", load_data=True)
 
     return index
 
 
 image_index = load_image_index()
 model, processor = load_model()
 
@@ -72,22 +81,26 @@ def add_image_emb(image):
     image = Image.open(image).convert("RGB")
 
     inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
+
     inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
     features = model(**inputs).image_embeds
+
     image_index.addDataPoint(features)
 
 
-def query_with_images(query_images,query_text):
-    images=[]
+def query_with_images(query_images, query_text):
+    images = []
+
+    for im in query_images:
+        img = Image.open(im).convert("RGB")
+
         if im.name.endswith(".gif"):
             img.seek(0)
         images.append(img)
 
+    inputs = processor(
+        text=[query_text], images=images, return_tensors="jax", padding=True
+    )
     inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
     outputs = model(**inputs)
     logits_per_image = outputs.logits_per_image.reshape(-1)
@@ -95,53 +108,79 @@ def query_with_images(query_images,query_text):
     probs = jax.nn.softmax(logits_per_image)
     # st.write(probs)
     # st.write(list(zip(images,probs)))
-    results = sorted(list(zip(images,probs)),key=lambda x: x[1], reverse=True)
+    results = sorted(list(zip(images, probs)), key=lambda x: x[1], reverse=True)
     # st.write(results)
+
     return zip(*results)
 
 
-q_cols=st.columns([5,2,5])
+q_cols = st.columns([5, 2, 5])
+
+examples = [
+    "OMG that is disgusting",
+    "I'm so scared right now",
+    " I got the job 🎉",
+    "Congratulations to all the flax-community week teams",
+    "You're awesome",
+    "I love you ❤️",
+]
+example_input = q_cols[0].radio(
+    "Example Queries :",
+    examples,
+    index=4,
+    help="These are examples I wrote off the top of my head. They don't occur in the dataset",
+)
 q_cols[2].markdown(
-"""
+    """
     Searches among the validation set images if not specified
+
     (There may be non-exact duplicates)
 
     """
 )
 
+query_text = q_cols[0].text_input(
+    "Write text you want to get reaction for", value=example_input
+)
+query_images = q_cols[2].file_uploader(
+    "(optional) Upload images to rank them",
+    type=["jpg", "jpeg", "gif"],
+    accept_multiple_files=True,
+)
 
 if query_images:
     st.write("Ranking your uploaded images with respect to input text:")
     with st.spinner("Calculating..."):
-        ids, dists = query_with_images(query_images,query_text)
+        ids, dists = query_with_images(query_images, query_text)
 else:
     st.write("Found these images within validation set:")
     with st.spinner("Calculating..."):
+        proc = processor(
+            text=[query_text], images=None, return_tensors="jax", padding=True
+        )
         vec = np.asarray(model.get_text_features(**proc))
         ids, dists = image_index.knnQuery(vec, k=top_k)
 
+show_gif = st.checkbox(
+    "Play GIFs",
+    value=True,
+    help="Will play the original animation. Only first frame is used in training!",
+)
 ext = "jpg" if not show_gif else "gif"
-res_cols=st.columns(col_count)
+res_cols = st.columns(col_count)
 
 
-for i,(id_, dist) in enumerate(zip(ids, dists)):
-    j=i%col_count
+for i, (id_, dist) in enumerate(zip(ids, dists)):
+    j = i % col_count
     with res_cols[j]:
         if isinstance(id_, np.int32):
             st.image(f"./{ext}/{file_names[id_][:-4]}.{ext}")
             # st.write(file_names[id_])
+            st.write(1.0 - dist)
         else:
            st.image(id_)
+            st.write(dist)
-# Credits
+
+
+# Credits
 st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")
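Beyond the mechanical reformatting, the key functional change for newer Streamlit is the @st.cache_resource() decorator on the two loaders: recent Streamlit releases deprecate the old st.cache API, and st.cache_resource is its replacement for caching unserializable globals such as models and ANN indexes. A minimal sketch of the pattern (the loader below is a hypothetical stand-in, not the app's code):

    import streamlit as st

    @st.cache_resource()
    def load_heavy_object():
        # Hypothetical stand-in for load_model()/load_image_index():
        # the body runs once; later script reruns reuse the cached object.
        return {"weights": [0.1, 0.2, 0.3]}

    obj = load_heavy_object()  # cached across reruns and sessions
    st.write(obj)

Without such caching, Streamlit re-executes the whole script on every widget interaction, which would reload the CLIP model and the nmslib index each time.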
requirements.txt
CHANGED
@@ -1,8 +1,77 @@
+# Created by 'make update-requirements-txt'. DO NOT EDIT!
+absl-py==1.4.0
+altair==4.2.2
+attrs==23.1.0
+blinker==1.6.2
+cachetools==5.3.1
+certifi==2023.5.7
+charset-normalizer==3.2.0
+chex==0.1.82
+click==8.1.6
+decorator==5.1.1
+entrypoints==0.4
+etils==1.3.0
+filelock==3.12.2
+flax==0.7.0
+fsspec==2023.6.0
+gitdb==4.0.10
+GitPython==3.1.32
+huggingface-hub==0.16.4
+idna==3.4
+importlib-metadata==6.8.0
+importlib-resources==6.0.0
+jax==0.4.13
+jaxlib==0.4.13
+Jinja2==3.1.2
+jsonschema==4.18.4
+jsonschema-specifications==2023.7.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+mdurl==0.1.2
+ml-dtypes==0.2.0
+msgpack==1.0.5
+nest-asyncio==1.5.6
+nmslib==2.1.1
+numpy==1.25.1
+opt-einsum==3.3.0
+optax==0.1.5
+orbax-checkpoint==0.2.7
+packaging==23.1
+pandas==2.0.3
+Pillow==9.5.0
+protobuf==4.23.4
+psutil==5.9.5
+pyarrow==12.0.1
+pybind11==2.6.1
+pydeck==0.8.0
+Pygments==2.15.1
+Pympler==1.0.1
+python-dateutil==2.8.2
+pytz==2023.3
+pytz-deprecation-shim==0.1.0.post0
+PyYAML==6.0.1
+referencing==0.30.0
+regex==2023.6.3
+requests==2.31.0
+rich==13.4.2
+rpds-py==0.9.2
+safetensors==0.3.1
+scipy==1.11.1
+six==1.16.0
+smmap==5.0.0
+streamlit==1.25.0
+tenacity==8.2.2
+tensorstore==0.1.40
+tokenizers==0.13.3
+toml==0.10.2
+toolz==0.12.0
+tornado==6.3.2
+tqdm==4.65.0
+transformers==4.31.0
+typing_extensions==4.7.1
+tzdata==2023.3
+tzlocal==4.3.1
+urllib3==2.0.4
+validators==0.20.0
+watchdog==3.0.0
+zipp==3.16.2
unpinned_requirements.txt
ADDED
@@ -0,0 +1,10 @@
+jax
+jaxlib
+flax
+tqdm
+requests
+nmslib
+numpy
+transformers
+altair<5
+streamlit
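The one constraint kept in the unpinned list, altair<5, caps Altair below its 5.x API changes; this cap plus the regenerated pins is presumably what resolves the dependency issue the PR title refers to, while every other dependency floats here and is pinned only in the generated requirements.txt.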