Resolves dependency issues from latest streamlit

#1
by mattupson - opened
Files changed (4) hide show
  1. Makefile +35 -0
  2. app.py +86 -47
  3. requirements.txt +77 -8
  4. unpinned_requirements.txt +10 -0
Makefile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################
2
+ # GLOBALS #
3
+ #################################################################################
4
+
5
+ PYTHON_VERSION = python3.10
6
+ VIRTUALENV := .venv
7
+
8
+ #################################################################################
9
+ # COMMANDS #
10
+ #################################################################################
11
+
12
+ # Set the default location for the virtualenv to be stored
13
+ # Create the virtualenv by installing the requirements and test requirements
14
+
15
+
16
+ .PHONY: virtualenv
17
+ virtualenv: requirements.txt
18
+ @if [ -d $(VIRTUALENV) ]; then rm -rf $(VIRTUALENV); fi
19
+ @mkdir -p $(VIRTUALENV)
20
+ $(PYTHON_VERSION) -m venv $(VIRTUALENV)
21
+ $(VIRTUALENV)/bin/pip install --upgrade pip
22
+ $(VIRTUALENV)/bin/pip install --upgrade -r requirements.txt
23
+ touch $@
24
+
25
+ .PHONY: update-requirements-txt
26
+ update-requirements-txt: unpinned_requirements.txt
27
+ update-requirements-txt: VIRTUALENV := /tmp/update-requirements-virtualenv
28
+ update-requirements-txt:
29
+ @if [ -d $(VIRTUALENV) ]; then rm -rf $(VIRTUALENV); fi
30
+ @mkdir -p $(VIRTUALENV)
31
+ virtualenv --python $(PYTHON_VERSION) $(VIRTUALENV)
32
+ $(VIRTUALENV)/bin/pip install --upgrade pip
33
+ $(VIRTUALENV)/bin/pip install --upgrade -r unpinned_requirements.txt
34
+ echo "# Created by 'make update-requirements-txt'. DO NOT EDIT!" > requirements.txt
35
+ $(VIRTUALENV)/bin/pip freeze | grep -v pkg_resources==0.0.0 >> requirements.txt
app.py CHANGED
@@ -1,12 +1,14 @@
 
 
 
 
1
  import nmslib
2
  import numpy as np
3
  import streamlit as st
 
4
  from transformers import AutoTokenizer, CLIPProcessor
 
5
  from model import FlaxHybridCLIP
6
- from PIL import Image
7
- import jax.numpy as jnp
8
- import os
9
- import jax
10
 
11
  # st.header('Under construction')
12
 
@@ -14,17 +16,17 @@ import jax
14
  st.sidebar.title("CLIP React Demo")
15
 
16
  st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
17
- sc= st.sidebar.columns(2)
18
 
19
- sc[0].image("./huggingface_explode3.png",width=150)
20
  sc[1].write(" ")
21
  sc[1].write(" ")
22
  sc[1].markdown("## Researching fun")
23
 
24
- with st.sidebar.expander("Motivation",expanded=True):
25
  st.markdown(
26
  """
27
- Reaction GIFs became an integral part of communication.
28
  They convey complex emotions with many levels, in a short compact format.
29
 
30
  If a picture is worth a thousand words then a GIF is worth more.
@@ -32,37 +34,44 @@ with st.sidebar.expander("Motivation",expanded=True):
32
  This is just a first step in the more ambitious goal of GIF/Image generation.
33
  """
34
  )
35
- top_k=st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
36
- col_count=4
37
- file_names=os.listdir("./jpg")
38
  file_names.sort()
39
 
40
- show_val=st.sidebar.button("show all validation set images")
 
41
  if show_val:
42
- cols=st.sidebar.columns(col_count)
43
- for i,im in enumerate(file_names):
44
- j=i%col_count
45
- cols[j].image("./jpg/"+im)
 
46
 
47
  st.write("# Search Reaction GIFs with CLIP ")
48
  st.write(" ")
49
  st.write(" ")
50
- @st.cache(allow_output_mutation=True)
 
 
51
  def load_model():
52
  model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
53
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
54
- processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
 
 
 
55
  return model, processor
56
 
57
- @st.cache(allow_output_mutation=True)
 
58
  def load_image_index():
59
- index = nmslib.init(method='hnsw', space='cosinesimil')
60
  index.loadIndex("./features/image_embeddings", load_data=True)
61
 
62
  return index
63
 
64
 
65
-
66
  image_index = load_image_index()
67
  model, processor = load_model()
68
 
@@ -72,22 +81,26 @@ def add_image_emb(image):
72
  image = Image.open(image).convert("RGB")
73
 
74
  inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
75
-
76
  inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
77
  features = model(**inputs).image_embeds
78
-
79
  image_index.addDataPoint(features)
80
 
81
 
82
- def query_with_images(query_images,query_text):
83
- images=[]
84
- for im in query_images:
85
- img=Image.open(im).convert("RGB")
 
 
86
  if im.name.endswith(".gif"):
87
  img.seek(0)
88
  images.append(img)
89
 
90
- inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
 
 
91
  inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
92
  outputs = model(**inputs)
93
  logits_per_image = outputs.logits_per_image.reshape(-1)
@@ -95,53 +108,79 @@ def query_with_images(query_images,query_text):
95
  probs = jax.nn.softmax(logits_per_image)
96
  # st.write(probs)
97
  # st.write(list(zip(images,probs)))
98
- results = sorted(list(zip(images,probs)),key=lambda x: x[1], reverse=True)
99
  # st.write(results)
 
100
  return zip(*results)
101
 
102
- q_cols=st.columns([5,2,5])
103
 
104
- examples = ["OMG that is disgusting","I'm so scared right now"," I got the job 🎉","Congratulations to all the flax-community week teams","You're awesome","I love you ❤️"]
105
- example_input = q_cols[0].radio("Example Queries :",examples,index=4,help="These are examples I wrote off the top of my head. They don't occur in the dataset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  q_cols[2].markdown(
107
- """
108
  Searches among the validation set images if not specified
109
-
110
  (There may be non-exact duplicates)
111
 
112
  """
113
  )
114
 
115
- query_text = q_cols[0].text_input("Write text you want to get reaction for", value=example_input)
116
- query_images = q_cols[2].file_uploader("(optional) Upload images to rank them",type=['jpg','jpeg','gif'], accept_multiple_files=True)
 
 
 
 
 
 
117
 
118
  if query_images:
119
  st.write("Ranking your uploaded images with respect to input text:")
120
  with st.spinner("Calculating..."):
121
- ids, dists = query_with_images(query_images,query_text)
122
  else:
123
  st.write("Found these images within validation set:")
124
  with st.spinner("Calculating..."):
125
- proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
 
 
126
  vec = np.asarray(model.get_text_features(**proc))
127
  ids, dists = image_index.knnQuery(vec, k=top_k)
128
 
129
- show_gif=st.checkbox("Play GIFs",value=True,help="Will play the original animation. Only first frame is used in training!")
 
 
 
 
130
  ext = "jpg" if not show_gif else "gif"
131
- res_cols=st.columns(col_count)
132
 
133
 
134
- for i,(id_, dist) in enumerate(zip(ids, dists)):
135
- j=i%col_count
136
  with res_cols[j]:
137
  if isinstance(id_, np.int32):
138
  st.image(f"./{ext}/{file_names[id_][:-4]}.{ext}")
139
  # st.write(file_names[id_])
140
- st.write(1.0 - dist, help="score")
141
  else:
142
  st.image(id_)
143
- st.write(dist, help="score")
144
-
145
-
146
- # Credits
147
  st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")
 
1
+ import os
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
  import nmslib
6
  import numpy as np
7
  import streamlit as st
8
+ from PIL import Image
9
  from transformers import AutoTokenizer, CLIPProcessor
10
+
11
  from model import FlaxHybridCLIP
 
 
 
 
12
 
13
  # st.header('Under construction')
14
 
 
16
  st.sidebar.title("CLIP React Demo")
17
 
18
  st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
19
+ sc = st.sidebar.columns(2)
20
 
21
+ sc[0].image("./huggingface_explode3.png", width=150)
22
  sc[1].write(" ")
23
  sc[1].write(" ")
24
  sc[1].markdown("## Researching fun")
25
 
26
+ with st.sidebar.expander("Motivation", expanded=True):
27
  st.markdown(
28
  """
29
+ Reaction GIFs became an integral part of communication.
30
  They convey complex emotions with many levels, in a short compact format.
31
 
32
  If a picture is worth a thousand words then a GIF is worth more.
 
34
  This is just a first step in the more ambitious goal of GIF/Image generation.
35
  """
36
  )
37
+ top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
38
+ col_count = 4
39
+ file_names = os.listdir("./jpg")
40
  file_names.sort()
41
 
42
+ show_val = st.sidebar.button("show all validation set images")
43
+
44
  if show_val:
45
+ cols = st.sidebar.columns(col_count)
46
+
47
+ for i, im in enumerate(file_names):
48
+ j = i % col_count
49
+ cols[j].image("./jpg/" + im)
50
 
51
  st.write("# Search Reaction GIFs with CLIP ")
52
  st.write(" ")
53
  st.write(" ")
54
+
55
+
56
+ @st.cache_resource()
57
  def load_model():
58
  model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
59
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
60
+ processor.tokenizer = AutoTokenizer.from_pretrained(
61
+ "cardiffnlp/twitter-roberta-base"
62
+ )
63
+
64
  return model, processor
65
 
66
+
67
+ @st.cache_resource()
68
  def load_image_index():
69
+ index = nmslib.init(method="hnsw", space="cosinesimil")
70
  index.loadIndex("./features/image_embeddings", load_data=True)
71
 
72
  return index
73
 
74
 
 
75
  image_index = load_image_index()
76
  model, processor = load_model()
77
 
 
81
  image = Image.open(image).convert("RGB")
82
 
83
  inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
84
+
85
  inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
86
  features = model(**inputs).image_embeds
87
+
88
  image_index.addDataPoint(features)
89
 
90
 
91
+ def query_with_images(query_images, query_text):
92
+ images = []
93
+
94
+ for im in query_images:
95
+ img = Image.open(im).convert("RGB")
96
+
97
  if im.name.endswith(".gif"):
98
  img.seek(0)
99
  images.append(img)
100
 
101
+ inputs = processor(
102
+ text=[query_text], images=images, return_tensors="jax", padding=True
103
+ )
104
  inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
105
  outputs = model(**inputs)
106
  logits_per_image = outputs.logits_per_image.reshape(-1)
 
108
  probs = jax.nn.softmax(logits_per_image)
109
  # st.write(probs)
110
  # st.write(list(zip(images,probs)))
111
+ results = sorted(list(zip(images, probs)), key=lambda x: x[1], reverse=True)
112
  # st.write(results)
113
+
114
  return zip(*results)
115
 
 
116
 
117
+ q_cols = st.columns([5, 2, 5])
118
+
119
+ examples = [
120
+ "OMG that is disgusting",
121
+ "I'm so scared right now",
122
+ " I got the job 🎉",
123
+ "Congratulations to all the flax-community week teams",
124
+ "You're awesome",
125
+ "I love you ❤️",
126
+ ]
127
+ example_input = q_cols[0].radio(
128
+ "Example Queries :",
129
+ examples,
130
+ index=4,
131
+ help="These are examples I wrote off the top of my head. They don't occur in the dataset",
132
+ )
133
  q_cols[2].markdown(
134
+ """
135
  Searches among the validation set images if not specified
136
+
137
  (There may be non-exact duplicates)
138
 
139
  """
140
  )
141
 
142
+ query_text = q_cols[0].text_input(
143
+ "Write text you want to get reaction for", value=example_input
144
+ )
145
+ query_images = q_cols[2].file_uploader(
146
+ "(optional) Upload images to rank them",
147
+ type=["jpg", "jpeg", "gif"],
148
+ accept_multiple_files=True,
149
+ )
150
 
151
  if query_images:
152
  st.write("Ranking your uploaded images with respect to input text:")
153
  with st.spinner("Calculating..."):
154
+ ids, dists = query_with_images(query_images, query_text)
155
  else:
156
  st.write("Found these images within validation set:")
157
  with st.spinner("Calculating..."):
158
+ proc = processor(
159
+ text=[query_text], images=None, return_tensors="jax", padding=True
160
+ )
161
  vec = np.asarray(model.get_text_features(**proc))
162
  ids, dists = image_index.knnQuery(vec, k=top_k)
163
 
164
+ show_gif = st.checkbox(
165
+ "Play GIFs",
166
+ value=True,
167
+ help="Will play the original animation. Only first frame is used in training!",
168
+ )
169
  ext = "jpg" if not show_gif else "gif"
170
+ res_cols = st.columns(col_count)
171
 
172
 
173
+ for i, (id_, dist) in enumerate(zip(ids, dists)):
174
+ j = i % col_count
175
  with res_cols[j]:
176
  if isinstance(id_, np.int32):
177
  st.image(f"./{ext}/{file_names[id_][:-4]}.{ext}")
178
  # st.write(file_names[id_])
179
+ st.write(1.0 - dist)
180
  else:
181
  st.image(id_)
182
+ st.write(dist)
183
+
184
+
185
+ # Credits
186
  st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")
requirements.txt CHANGED
@@ -1,8 +1,77 @@
1
- jax
2
- jaxlib
3
- flax
4
- tqdm
5
- requests
6
- nmslib
7
- numpy
8
- transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by 'make update-requirements-txt'. DO NOT EDIT!
2
+ absl-py==1.4.0
3
+ altair==4.2.2
4
+ attrs==23.1.0
5
+ blinker==1.6.2
6
+ cachetools==5.3.1
7
+ certifi==2023.5.7
8
+ charset-normalizer==3.2.0
9
+ chex==0.1.82
10
+ click==8.1.6
11
+ decorator==5.1.1
12
+ entrypoints==0.4
13
+ etils==1.3.0
14
+ filelock==3.12.2
15
+ flax==0.7.0
16
+ fsspec==2023.6.0
17
+ gitdb==4.0.10
18
+ GitPython==3.1.32
19
+ huggingface-hub==0.16.4
20
+ idna==3.4
21
+ importlib-metadata==6.8.0
22
+ importlib-resources==6.0.0
23
+ jax==0.4.13
24
+ jaxlib==0.4.13
25
+ Jinja2==3.1.2
26
+ jsonschema==4.18.4
27
+ jsonschema-specifications==2023.7.1
28
+ markdown-it-py==3.0.0
29
+ MarkupSafe==2.1.3
30
+ mdurl==0.1.2
31
+ ml-dtypes==0.2.0
32
+ msgpack==1.0.5
33
+ nest-asyncio==1.5.6
34
+ nmslib==2.1.1
35
+ numpy==1.25.1
36
+ opt-einsum==3.3.0
37
+ optax==0.1.5
38
+ orbax-checkpoint==0.2.7
39
+ packaging==23.1
40
+ pandas==2.0.3
41
+ Pillow==9.5.0
42
+ protobuf==4.23.4
43
+ psutil==5.9.5
44
+ pyarrow==12.0.1
45
+ pybind11==2.6.1
46
+ pydeck==0.8.0
47
+ Pygments==2.15.1
48
+ Pympler==1.0.1
49
+ python-dateutil==2.8.2
50
+ pytz==2023.3
51
+ pytz-deprecation-shim==0.1.0.post0
52
+ PyYAML==6.0.1
53
+ referencing==0.30.0
54
+ regex==2023.6.3
55
+ requests==2.31.0
56
+ rich==13.4.2
57
+ rpds-py==0.9.2
58
+ safetensors==0.3.1
59
+ scipy==1.11.1
60
+ six==1.16.0
61
+ smmap==5.0.0
62
+ streamlit==1.25.0
63
+ tenacity==8.2.2
64
+ tensorstore==0.1.40
65
+ tokenizers==0.13.3
66
+ toml==0.10.2
67
+ toolz==0.12.0
68
+ tornado==6.3.2
69
+ tqdm==4.65.0
70
+ transformers==4.31.0
71
+ typing_extensions==4.7.1
72
+ tzdata==2023.3
73
+ tzlocal==4.3.1
74
+ urllib3==2.0.4
75
+ validators==0.20.0
76
+ watchdog==3.0.0
77
+ zipp==3.16.2
unpinned_requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ jax
2
+ jaxlib
3
+ flax
4
+ tqdm
5
+ requests
6
+ nmslib
7
+ numpy
8
+ transformers
9
+ altair<5
10
+ streamlit