Spaces:
Runtime error
Runtime error
Initial commit.
Browse files- .gitattributes +1 -0
- .gitignore +4 -0
- Dockerfile +18 -0
- README.md +10 -8
- app.py +218 -0
- cc3m_embeddings_urls.npy +3 -0
- gill/layers.py +54 -0
- gill/models.py +909 -0
- gill/utils.py +249 -0
- requirements.txt +36 -0
- share_btn.py +107 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
cc3m_embeddings_urls.npy filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
venv/
|
3 |
+
__pycache__
|
4 |
+
*.pyc
|
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime as base
|
2 |
+
|
3 |
+
RUN apt-get update && apt-get -y install git
|
4 |
+
|
5 |
+
|
6 |
+
ENV HOME=/exp/fromage
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
WORKDIR /exp/fromage
|
11 |
+
COPY ./requirements.txt ./requirements.txt
|
12 |
+
RUN python -m pip install -r ./requirements.txt
|
13 |
+
RUN python -m pip install gradio
|
14 |
+
|
15 |
+
COPY . .
|
16 |
+
RUN chmod -R a+rwX .
|
17 |
+
|
18 |
+
CMD ["uvicorn", "app:main", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: purple
|
6 |
sdk: docker
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
---
|
10 |
-
|
11 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: GILL
|
3 |
+
emoji: π
|
|
|
|
|
4 |
sdk: docker
|
5 |
+
app_file: app.py
|
6 |
+
colorFrom: blue
|
7 |
+
colorTo: red
|
8 |
+
pinned: true
|
9 |
+
tags:
|
10 |
+
- multimodal
|
11 |
+
- computer-vision
|
12 |
+
- nlp
|
13 |
---
|
|
|
|
app.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
from share_btn import community_icon_html, loading_icon_html, share_js, save_js
|
3 |
+
import huggingface_hub
|
4 |
+
import gradio as gr
|
5 |
+
from gill import utils
|
6 |
+
from gill import models
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
|
13 |
+
|
14 |
+
|
15 |
+
css = """
|
16 |
+
#chatbot { min-height: 300px; }
|
17 |
+
#save-btn {
|
18 |
+
background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
|
19 |
+
}
|
20 |
+
#save-btn:hover {
|
21 |
+
background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
|
22 |
+
}
|
23 |
+
#share-btn {
|
24 |
+
background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
|
25 |
+
}
|
26 |
+
#share-btn:hover {
|
27 |
+
background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
|
28 |
+
}
|
29 |
+
#gallery { z-index: 999999; }
|
30 |
+
#gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
|
31 |
+
#gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
|
32 |
+
@media (hover: none) {
|
33 |
+
#gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; 0;}
|
34 |
+
}
|
35 |
+
"""
|
36 |
+
|
37 |
+
examples = [
|
38 |
+
'examples/sparrow.png',
|
39 |
+
'examples/beaver.png',
|
40 |
+
'examples/couch.png',
|
41 |
+
'examples/guac.png',
|
42 |
+
'examples/scraped_knee.png'
|
43 |
+
]
|
44 |
+
|
45 |
+
# Download model from HF Hub.
|
46 |
+
ckpt_path = huggingface_hub.hf_hub_download(
|
47 |
+
repo_id='jykoh/gill', filename='pretrained_ckpt.pth.tar')
|
48 |
+
decision_model_path = huggingface_hub.hf_hub_download(
|
49 |
+
repo_id='jykoh/gill', filename='decision_model.pth.tar')
|
50 |
+
args_path = huggingface_hub.hf_hub_download(
|
51 |
+
repo_id='jykoh/gill', filename='model_args.json')
|
52 |
+
model = models.load_gill('./', args_path, ckpt_path, decision_model_path)
|
53 |
+
|
54 |
+
|
55 |
+
def upload_image(state, image_input):
|
56 |
+
conversation = state[0]
|
57 |
+
chat_history = state[1]
|
58 |
+
input_image = Image.open(image_input.name).resize(
|
59 |
+
(224, 224)).convert('RGB')
|
60 |
+
input_image.save(image_input.name) # Overwrite with smaller image.
|
61 |
+
conversation += [(f'<img src="/file={image_input.name}" style="display: inline-block;">', "")]
|
62 |
+
return [conversation, chat_history + [input_image, ""]], conversation
|
63 |
+
|
64 |
+
|
65 |
+
def reset():
|
66 |
+
return [[], []], []
|
67 |
+
|
68 |
+
|
69 |
+
def reset_last(state):
|
70 |
+
conversation = state[0][:-1]
|
71 |
+
chat_history = state[1][:-2]
|
72 |
+
return [conversation, chat_history], conversation
|
73 |
+
|
74 |
+
|
75 |
+
def save_image_to_local(image: Image.Image):
|
76 |
+
# TODO(jykoh): Update so the url path is used, to prevent repeat saving.
|
77 |
+
filename = next(tempfile._get_candidate_names()) + '.png'
|
78 |
+
image.save(filename)
|
79 |
+
return filename
|
80 |
+
|
81 |
+
|
82 |
+
def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
|
83 |
+
# Ignore empty inputs.
|
84 |
+
if len(input_text) == 0:
|
85 |
+
return state, state[0], gr.update(visible=True)
|
86 |
+
|
87 |
+
input_prompt = 'Q: ' + input_text + '\nA:'
|
88 |
+
conversation = state[0]
|
89 |
+
chat_history = state[1]
|
90 |
+
print('Generating for', chat_history, flush=True)
|
91 |
+
|
92 |
+
# If an image was uploaded, prepend it to the model.
|
93 |
+
model_inputs = chat_history
|
94 |
+
model_inputs.append(input_prompt)
|
95 |
+
|
96 |
+
top_p = 1.0
|
97 |
+
if temperature != 0.0:
|
98 |
+
top_p = 0.95
|
99 |
+
|
100 |
+
print('Running model.generate_for_images_and_texts with',
|
101 |
+
model_inputs, flush=True)
|
102 |
+
model_outputs = model.generate_for_images_and_texts(model_inputs,
|
103 |
+
num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
|
104 |
+
temperature=temperature, max_num_rets=1,
|
105 |
+
num_inference_steps=1)
|
106 |
+
print('model_outputs', model_outputs, ret_scale_factor, flush=True)
|
107 |
+
|
108 |
+
im_names = []
|
109 |
+
response = ''
|
110 |
+
text_outputs = []
|
111 |
+
for output_i, p in enumerate(model_outputs):
|
112 |
+
if type(p) == str:
|
113 |
+
if output_i > 0:
|
114 |
+
response += '<br/>'
|
115 |
+
# Remove the image tokens for output.
|
116 |
+
text_outputs.append(p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))
|
117 |
+
response += p
|
118 |
+
if len(model_outputs) > 1:
|
119 |
+
response += '<br/>'
|
120 |
+
elif type(p) == dict:
|
121 |
+
# Decide whether to generate or retrieve.
|
122 |
+
if p['decision'] is not None and p['decision'][0] == 'gen':
|
123 |
+
image = p['gen'][0][0].resize((512, 512))
|
124 |
+
filename = save_image_to_local(image)
|
125 |
+
response += f'<img src="/file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Generated)</p>'
|
126 |
+
else:
|
127 |
+
image = p['ret'][0][0].resize((512, 512))
|
128 |
+
filename = save_image_to_local(image)
|
129 |
+
response += f'<img src="/file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Retrieved)</p>'
|
130 |
+
|
131 |
+
|
132 |
+
chat_history = model_inputs + \
|
133 |
+
[' '.join([s for s in model_outputs if type(s) == str]) + '\n']
|
134 |
+
# Remove [RET] from outputs.
|
135 |
+
conversation.append((input_text, response.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')))
|
136 |
+
|
137 |
+
# Set input image to None.
|
138 |
+
print('state', state, flush=True)
|
139 |
+
print('updated state', [conversation, chat_history], flush=True)
|
140 |
+
return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)
|
141 |
+
|
142 |
+
|
143 |
+
with gr.Blocks(css=css) as demo:
|
144 |
+
gr.HTML("""
|
145 |
+
<h1>π§ FROMAGe</h1>
|
146 |
+
<p>This is the official Gradio demo for the FROMAGe model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.</p>
|
147 |
+
|
148 |
+
<strong>Paper:</strong> <a href="https://arxiv.org/abs/2301.13823" target="_blank">Grounding Language Models to Images for Multimodal Generation</a>
|
149 |
+
<br/>
|
150 |
+
<strong>Project Website:</strong> <a href="https://jykoh.com/fromage" target="_blank">FROMAGe Website</a>
|
151 |
+
<br/>
|
152 |
+
<strong>Code and Models:</strong> <a href="https://github.com/kohjingyu/fromage" target="_blank">GitHub</a>
|
153 |
+
<br/>
|
154 |
+
<br/>
|
155 |
+
|
156 |
+
<strong>Tips:</strong>
|
157 |
+
<ul>
|
158 |
+
<li>Start by inputting either image or text prompts (or both) and chat with FROMAGe to get image-and-text replies.</li>
|
159 |
+
<li>Tweak the level of sensitivity to images and text using the parameters on the right.</li>
|
160 |
+
<li>FROMAGe <i>retrieves</i> images from a database, and doesn't generate novel images, and will not be able to return images outside those in Conceptual Captions.</li>
|
161 |
+
<li>Check out cool conversations in the examples or community tab for inspiration and share your own!</li>
|
162 |
+
<li>For faster inference without waiting in queue, you may duplicate the space and use your own GPU: <a href="https://huggingface.co/spaces/jykoh/fromage?duplicate=true"><img style="display: inline-block; margin-top: 0em; margin-bottom: 0em" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></li>
|
163 |
+
</ul>
|
164 |
+
""")
|
165 |
+
|
166 |
+
gr_state = gr.State([[], []]) # conversation, chat_history
|
167 |
+
|
168 |
+
with gr.Row():
|
169 |
+
with gr.Column(scale=0.7, min_width=500):
|
170 |
+
with gr.Row():
|
171 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="π§ FROMAGe Chatbot")
|
172 |
+
with gr.Row():
|
173 |
+
image_btn = gr.UploadButton("πΌοΈ Upload Image", file_types=["image"])
|
174 |
+
|
175 |
+
text_input = gr.Textbox(label="Message", placeholder="Type a message")
|
176 |
+
|
177 |
+
with gr.Column():
|
178 |
+
submit_btn = gr.Button(
|
179 |
+
"Submit", interactive=True, variant="primary")
|
180 |
+
clear_last_btn = gr.Button("Undo")
|
181 |
+
clear_btn = gr.Button("Reset All")
|
182 |
+
with gr.Row(visible=False) as save_group:
|
183 |
+
save_button = gr.Button("πΎ Save Conversation as .png", elem_id="save-btn")
|
184 |
+
|
185 |
+
with gr.Row(visible=False) as share_group:
|
186 |
+
share_button = gr.Button("π€ Share to Community (opens new window)", elem_id="share-btn")
|
187 |
+
|
188 |
+
with gr.Column(scale=0.3, min_width=400):
|
189 |
+
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
|
190 |
+
label="Frequency multiplier for returning images (higher means more frequent)")
|
191 |
+
# max_ret_images = gr.Number(
|
192 |
+
# minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
|
193 |
+
gr_max_len = gr.Slider(minimum=1, maximum=64, value=32,
|
194 |
+
step=1, interactive=True, label="Max # of words")
|
195 |
+
gr_temperature = gr.Slider(
|
196 |
+
minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)")
|
197 |
+
|
198 |
+
gallery = gr.Gallery(
|
199 |
+
value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery",
|
200 |
+
).style(grid=[2], height="auto")
|
201 |
+
|
202 |
+
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
|
203 |
+
gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
|
204 |
+
text_input.submit(lambda: "", None, text_input) # Reset chatbox.
|
205 |
+
submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
|
206 |
+
gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
|
207 |
+
submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
|
208 |
+
|
209 |
+
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
210 |
+
clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
|
211 |
+
clear_btn.click(reset, [], [gr_state, chatbot])
|
212 |
+
share_button.click(None, [], [], _js=share_js)
|
213 |
+
save_button.click(None, [], [], _js=save_js)
|
214 |
+
|
215 |
+
|
216 |
+
demo.queue(concurrency_count=1, api_open=False, max_size=16)
|
217 |
+
# demo.launch(debug=True, server_name="0.0.0.0")
|
218 |
+
demo.launch(debug=True, server_name="127.0.0.1")
|
cc3m_embeddings_urls.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:797e2ab9d46f103106bbf111352c762c5969630e9a13ccdc1f56a51c63fc39a3
|
3 |
+
size 2887526287
|
gill/layers.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class TextFcLayer(nn.Module):
|
6 |
+
"""Layers used in mapping text embeddings to visual outputs."""
|
7 |
+
|
8 |
+
def __init__(self, in_dim: int, out_dim: int, num_input_tokens: int = 1, num_output_tokens: int = 1, mode: str = 'linear'):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.num_input_tokens = num_input_tokens
|
12 |
+
self.num_output_tokens = num_output_tokens
|
13 |
+
self.mode = mode
|
14 |
+
|
15 |
+
if mode == 'linear':
|
16 |
+
self.model = nn.Linear(in_dim, out_dim)
|
17 |
+
elif mode == 'gill_mapper': # TODO(jykoh): Rename to GILLMapper
|
18 |
+
hidden_dim = 512
|
19 |
+
self.fc = nn.Linear(in_dim, hidden_dim)
|
20 |
+
self.tfm = nn.Transformer(batch_first=True, norm_first=True,
|
21 |
+
d_model=hidden_dim, num_encoder_layers=4, num_decoder_layers=4,
|
22 |
+
dim_feedforward=hidden_dim * 4, dropout=0.0, nhead=4)
|
23 |
+
self.model = nn.Linear(hidden_dim, out_dim)
|
24 |
+
self.query_embs = nn.Parameter(torch.randn(1, num_output_tokens, hidden_dim))
|
25 |
+
else:
|
26 |
+
raise NotImplementedError(mode)
|
27 |
+
|
28 |
+
def forward(self, x: torch.Tensor, input_embs: torch.Tensor) -> torch.Tensor:
|
29 |
+
outputs = None
|
30 |
+
|
31 |
+
if self.mode == 'gill_mapper':
|
32 |
+
x = x + input_embs
|
33 |
+
|
34 |
+
if isinstance(self.model, nn.ModuleList):
|
35 |
+
assert len(self.model) == x.shape[1] == self.num_input_tokens, (len(self.model), x.shape, self.num_input_tokens)
|
36 |
+
outputs = []
|
37 |
+
for i in range(self.num_input_tokens):
|
38 |
+
outputs.append(self.model[i](x[:, i, :])) # (N, D)
|
39 |
+
outputs = torch.stack(outputs, dim=1) # (N, T, D)
|
40 |
+
else:
|
41 |
+
if self.mode == 'gill_mapper':
|
42 |
+
x = self.fc(x)
|
43 |
+
x = self.tfm(x, self.query_embs.repeat(x.shape[0], 1, 1))
|
44 |
+
outputs = self.model(x)
|
45 |
+
|
46 |
+
if outputs.shape[1] != self.num_output_tokens and self.mode == 'linear':
|
47 |
+
if self.mode == 'linear':
|
48 |
+
outputs = outputs[:, :self.num_output_tokens, :]
|
49 |
+
else:
|
50 |
+
raise NotImplementedError
|
51 |
+
|
52 |
+
assert outputs.shape[1] == 1 or (outputs.shape[1] * outputs.shape[2] == self.num_output_tokens * 768), (outputs.shape, self.num_output_tokens)
|
53 |
+
return outputs # (N, T, D)
|
54 |
+
|
gill/models.py
ADDED
@@ -0,0 +1,909 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
from collections import namedtuple
|
3 |
+
from diffusers import StableDiffusionPipeline
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import glob
|
8 |
+
import torch
|
9 |
+
from torch import Tensor
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import pickle as pkl
|
13 |
+
from PIL import Image, UnidentifiedImageError
|
14 |
+
from requests.exceptions import ConnectionError
|
15 |
+
|
16 |
+
from transformers import AutoTokenizer, AutoModel, CLIPVisionModel, OPTForCausalLM
|
17 |
+
from gill import utils
|
18 |
+
from gill import layers
|
19 |
+
|
20 |
+
|
21 |
+
class GILLArgs:
|
22 |
+
freeze_lm: bool = True
|
23 |
+
freeze_vm: bool = True
|
24 |
+
opt_version: str = 'facebook/opt-6.7b'
|
25 |
+
visual_encoder: str = 'openai/clip-vit-large-patch14'
|
26 |
+
n_visual_tokens: int = 1
|
27 |
+
task: str = 'captioning'
|
28 |
+
ret_emb_dim: Optional[int] = 256
|
29 |
+
gen_emb_dim: Optional[int] = 256
|
30 |
+
text_emb_layers: List[int] = [-1]
|
31 |
+
gen_token_idx: List[int] = [0]
|
32 |
+
retrieval_token_idx: List[int] = [0]
|
33 |
+
text_fc_mode: str = 'gill_mapper'
|
34 |
+
ret_text_fc_mode: str = 'linear'
|
35 |
+
num_tokens: int = 8
|
36 |
+
num_clip_tokens: int = 77
|
37 |
+
|
38 |
+
|
39 |
+
class GILLModel(nn.Module):
|
40 |
+
def __init__(self, tokenizer, args: GILLArgs = GILLArgs()):
|
41 |
+
super().__init__()
|
42 |
+
self.tokenizer = tokenizer
|
43 |
+
self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
|
44 |
+
self.image_token = self.tokenizer.cls_token_id
|
45 |
+
assert args.text_emb_layers != set(args.text_emb_layers), 'text_emb_layers not unique'
|
46 |
+
self.args = args
|
47 |
+
self.num_tokens = args.num_tokens
|
48 |
+
self.num_clip_tokens = args.num_clip_tokens
|
49 |
+
|
50 |
+
opt_version = args.opt_version
|
51 |
+
visual_encoder = args.visual_encoder
|
52 |
+
n_visual_tokens = args.n_visual_tokens
|
53 |
+
print(f"Using {opt_version} for the language model.")
|
54 |
+
print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")
|
55 |
+
|
56 |
+
if 'facebook/opt' in opt_version:
|
57 |
+
self.lm = OPTForCausalLM.from_pretrained(opt_version)
|
58 |
+
else:
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
self.opt_version = opt_version
|
62 |
+
|
63 |
+
if self.args.freeze_lm:
|
64 |
+
self.lm.eval()
|
65 |
+
print("Freezing the LM.")
|
66 |
+
for param in self.lm.parameters():
|
67 |
+
param.requires_grad = False
|
68 |
+
else:
|
69 |
+
self.lm.train()
|
70 |
+
|
71 |
+
self.retrieval_token_idx = args.retrieval_token_idx
|
72 |
+
self.gen_token_idx = args.gen_token_idx
|
73 |
+
self.lm.resize_token_embeddings(len(tokenizer))
|
74 |
+
|
75 |
+
self.input_embeddings = self.lm.get_input_embeddings()
|
76 |
+
|
77 |
+
print("Restoring pretrained weights for the visual model.")
|
78 |
+
if 'clip' in visual_encoder:
|
79 |
+
self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
|
80 |
+
else:
|
81 |
+
self.visual_model = AutoModel.from_pretrained(visual_encoder)
|
82 |
+
|
83 |
+
if 'clip' in visual_encoder:
|
84 |
+
hidden_size = self.visual_model.config.hidden_size
|
85 |
+
else:
|
86 |
+
raise NotImplementedError
|
87 |
+
|
88 |
+
if self.args.freeze_vm:
|
89 |
+
print("Freezing the VM.")
|
90 |
+
self.visual_model.eval()
|
91 |
+
for param in self.visual_model.parameters():
|
92 |
+
param.requires_grad = False
|
93 |
+
else:
|
94 |
+
self.visual_model.train()
|
95 |
+
|
96 |
+
self.visual_model_name = visual_encoder
|
97 |
+
|
98 |
+
embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
|
99 |
+
self.ret_text_hidden_fcs = nn.ModuleList([])
|
100 |
+
self.gen_text_hidden_fcs = nn.ModuleList([])
|
101 |
+
|
102 |
+
for layer_idx in self.args.text_emb_layers:
|
103 |
+
if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
|
104 |
+
if 'opt' in opt_version: # OPT models
|
105 |
+
in_dim = self.lm.config.word_embed_proj_dim
|
106 |
+
else:
|
107 |
+
raise NotImplementedError
|
108 |
+
|
109 |
+
self.ret_text_hidden_fcs.append(
|
110 |
+
layers.TextFcLayer(in_dim, self.args.ret_emb_dim, num_input_tokens=self.args.num_tokens,
|
111 |
+
num_output_tokens=1, mode=self.args.ret_text_fc_mode))
|
112 |
+
self.gen_text_hidden_fcs.append(
|
113 |
+
layers.TextFcLayer(in_dim, self.args.gen_emb_dim, num_input_tokens=self.args.num_tokens,
|
114 |
+
num_output_tokens=self.args.num_clip_tokens, mode=self.args.text_fc_mode))
|
115 |
+
|
116 |
+
elif layer_idx < self.lm.config.num_hidden_layers:
|
117 |
+
self.ret_text_hidden_fcs.append(layers.TextFcLayer(self.lm.config.hidden_size, self.args.ret_emb_dim, num_input_tokens=self.args.num_tokens, num_output_tokens=1, mode=self.args.ret_text_fc_mode))
|
118 |
+
self.gen_text_hidden_fcs.append(layers.TextFcLayer(self.lm.config.hidden_size, self.args.gen_emb_dim, num_input_tokens=self.args.num_tokens, num_output_tokens=self.args.num_clip_tokens, mode=self.args.text_fc_mode))
|
119 |
+
else:
|
120 |
+
raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')
|
121 |
+
|
122 |
+
self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
|
123 |
+
|
124 |
+
# Retrieval image FC layer.
|
125 |
+
self.visual_fc = nn.Linear(hidden_size, self.args.ret_emb_dim)
|
126 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
127 |
+
|
128 |
+
|
129 |
+
def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
|
130 |
+
if mode not in ['captioning', 'retrieval', 'generation']:
|
131 |
+
raise ValueError(f"mode should be one of ['captioning', 'retrieval', 'generation'], got {mode} instead.")
|
132 |
+
|
133 |
+
# Extract visual embeddings from the vision encoder.
|
134 |
+
if 'clip' in self.visual_model_name:
|
135 |
+
outputs = self.visual_model(pixel_values)
|
136 |
+
encoder_outputs = outputs.pooler_output
|
137 |
+
else:
|
138 |
+
raise NotImplementedError
|
139 |
+
|
140 |
+
# Use the correct fc based on function argument.
|
141 |
+
if mode == 'captioning':
|
142 |
+
visual_embs = self.visual_embeddings(encoder_outputs) # (2, D * n_visual_tokens)
|
143 |
+
visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
|
144 |
+
elif mode == 'retrieval':
|
145 |
+
visual_embs = self.visual_fc(encoder_outputs) # (2, D * n_visual_tokens)
|
146 |
+
visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
|
147 |
+
elif mode == 'generation':
|
148 |
+
visual_embs = torch.zeros((pixel_values.shape[0], 1, 768), device=pixel_values.device)
|
149 |
+
else:
|
150 |
+
raise NotImplementedError
|
151 |
+
|
152 |
+
return visual_embs
|
153 |
+
|
154 |
+
|
155 |
+
def train(self, mode=True):
|
156 |
+
super(GILLModel, self).train(mode=mode)
|
157 |
+
# Overwrite train() to ensure frozen models remain frozen.
|
158 |
+
if self.args.freeze_lm:
|
159 |
+
self.lm.eval()
|
160 |
+
if self.args.freeze_vm:
|
161 |
+
self.visual_model.eval()
|
162 |
+
|
163 |
+
|
164 |
+
def forward(
|
165 |
+
self,
|
166 |
+
pixel_values: torch.FloatTensor,
|
167 |
+
labels: Optional[torch.LongTensor] = None,
|
168 |
+
caption_len: Optional[torch.LongTensor] = None,
|
169 |
+
mode: str = 'captioning',
|
170 |
+
concat_captions: bool = False,
|
171 |
+
input_prefix: Optional[str] = None,
|
172 |
+
):
|
173 |
+
visual_embs = self.get_visual_embs(pixel_values, mode)
|
174 |
+
|
175 |
+
batch_size, vis_seq_len, _ = visual_embs.shape # vis_seq_len = n_visual_tokens
|
176 |
+
if labels is not None:
|
177 |
+
assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)
|
178 |
+
visual_embs_norm = ((visual_embs ** 2).sum(dim=-1) ** 0.5).mean()
|
179 |
+
|
180 |
+
input_embs = self.input_embeddings(labels) # (N, T, D)
|
181 |
+
input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()
|
182 |
+
|
183 |
+
last_embedding_idx = caption_len - 1 # -1 to retrieve the token before the eos token
|
184 |
+
|
185 |
+
if input_prefix is not None:
|
186 |
+
prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
|
187 |
+
prompt_ids = prompt_ids.to(visual_embs.device)
|
188 |
+
prompt_embs = self.input_embeddings(prompt_ids)
|
189 |
+
prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
|
190 |
+
assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
|
191 |
+
assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
|
192 |
+
assert len(prompt_embs.shape) == 3, prompt_embs.shape
|
193 |
+
|
194 |
+
if mode == 'captioning':
|
195 |
+
# Concat to text embeddings.
|
196 |
+
condition_seq_len = 0
|
197 |
+
if input_prefix is None:
|
198 |
+
# Just add visual embeddings.
|
199 |
+
input_embs = torch.cat([visual_embs, input_embs], axis=1)
|
200 |
+
last_embedding_idx += vis_seq_len
|
201 |
+
condition_seq_len += vis_seq_len
|
202 |
+
full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
|
203 |
+
else:
|
204 |
+
print(f'Adding prefix "{input_prefix}" to captioning.')
|
205 |
+
# Add visual and prompt embeddings.
|
206 |
+
prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
|
207 |
+
input_embs = torch.cat([prefix_embs, input_embs], axis=1)
|
208 |
+
|
209 |
+
last_embedding_idx += prefix_embs.shape[1]
|
210 |
+
condition_seq_len += prefix_embs.shape[1]
|
211 |
+
full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
|
212 |
+
|
213 |
+
# Mask out embedding tokens in the labels.
|
214 |
+
full_labels = torch.cat([full_labels, labels], axis=1)
|
215 |
+
|
216 |
+
pad_idx = []
|
217 |
+
|
218 |
+
for label in full_labels:
|
219 |
+
for k, token in enumerate(label):
|
220 |
+
# Mask out retrieval/gen tokens if they exist.
|
221 |
+
if token in [self.tokenizer.pad_token_id] + self.retrieval_token_idx + self.gen_token_idx:
|
222 |
+
label[k:] = -100
|
223 |
+
pad_idx.append(k)
|
224 |
+
break
|
225 |
+
if k == len(label) - 1: # No padding found.
|
226 |
+
pad_idx.append(k + 1)
|
227 |
+
assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
|
228 |
+
|
229 |
+
bs, seq_len, embs_dim = input_embs.shape
|
230 |
+
if concat_captions:
|
231 |
+
print('Concatenating examples for captioning!')
|
232 |
+
assert len(input_embs.shape) == 3, input_embs
|
233 |
+
assert len(full_labels.shape) == 2, full_labels
|
234 |
+
assert batch_size % 2 == 0
|
235 |
+
all_concat_input_embs = []
|
236 |
+
all_concat_labels = []
|
237 |
+
|
238 |
+
# Rearrange embeddings and labels (and their padding) to concatenate captions.
|
239 |
+
for i in range(batch_size // 2):
|
240 |
+
first_idx = i * 2
|
241 |
+
second_idx = first_idx + 1
|
242 |
+
first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
|
243 |
+
first_labels = full_labels[first_idx, :pad_idx[first_idx]]
|
244 |
+
first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
|
245 |
+
first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
|
246 |
+
|
247 |
+
second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
|
248 |
+
second_labels = full_labels[second_idx, :pad_idx[second_idx]]
|
249 |
+
second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
|
250 |
+
second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
|
251 |
+
bos_idx = visual_embs.shape[1]
|
252 |
+
|
253 |
+
assert torch.all(first_labels_padding == -100), first_labels_padding
|
254 |
+
assert torch.all(second_labels_padding == -100), second_labels_padding
|
255 |
+
assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
|
256 |
+
|
257 |
+
# Remove BOS token of the second caption.
|
258 |
+
second_labels = torch.cat([second_labels[:bos_idx], second_labels[bos_idx + 1:]], axis=0)
|
259 |
+
second_emb = torch.cat([second_emb[:bos_idx, :], second_emb[bos_idx + 1:, :]], axis=0)
|
260 |
+
|
261 |
+
concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, 768)
|
262 |
+
concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2, 768)
|
263 |
+
all_concat_input_embs.append(concat_input_embs)
|
264 |
+
all_concat_labels.append(concat_labels)
|
265 |
+
|
266 |
+
# Pad to max length.
|
267 |
+
input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, 768)
|
268 |
+
full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2, 768)
|
269 |
+
print("Concatenated full_labels:", full_labels[0, ...])
|
270 |
+
assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
|
271 |
+
assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape
|
272 |
+
|
273 |
+
output = self.lm(inputs_embeds=input_embs,
|
274 |
+
labels=full_labels,
|
275 |
+
output_hidden_states=True)
|
276 |
+
elif mode in ['retrieval', 'generation']:
|
277 |
+
full_labels = torch.clone(labels)
|
278 |
+
if input_prefix is not None:
|
279 |
+
print(f'Adding prefix "{input_prefix}" to retrieval.')
|
280 |
+
# Add prompt embeddings.
|
281 |
+
prefix_embs = prompt_embs
|
282 |
+
input_embs = torch.cat([prefix_embs, input_embs], axis=1)
|
283 |
+
last_embedding_idx += prefix_embs.shape[1]
|
284 |
+
full_labels = torch.cat([
|
285 |
+
torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
|
286 |
+
full_labels
|
287 |
+
], axis=1)
|
288 |
+
|
289 |
+
pad_idx = []
|
290 |
+
for label in full_labels:
|
291 |
+
for k, token in enumerate(label):
|
292 |
+
if (token == self.tokenizer.pad_token_id):
|
293 |
+
label[k:] = -100
|
294 |
+
pad_idx.append(k)
|
295 |
+
break
|
296 |
+
if k == len(label) - 1: # No padding found.
|
297 |
+
pad_idx.append(k + 1)
|
298 |
+
assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
|
299 |
+
|
300 |
+
bs, seq_len, embs_dim = input_embs.shape
|
301 |
+
# Concatenate examples for captioning, if specified.
|
302 |
+
if concat_captions:
|
303 |
+
print(f'Concatenating examples for {mode}!')
|
304 |
+
assert len(input_embs.shape) == 3, input_embs
|
305 |
+
assert len(full_labels.shape) == 2, full_labels
|
306 |
+
assert batch_size % 2 == 0
|
307 |
+
all_concat_input_embs = []
|
308 |
+
all_concat_labels = []
|
309 |
+
all_last_embedding_idx = []
|
310 |
+
|
311 |
+
# Rearrange embeddings and labels (and their padding) to concatenate captions.
|
312 |
+
for i in range(batch_size // 2):
|
313 |
+
first_idx = i * 2
|
314 |
+
second_idx = first_idx + 1
|
315 |
+
first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
|
316 |
+
first_labels = full_labels[first_idx, :pad_idx[first_idx]]
|
317 |
+
first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
|
318 |
+
first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
|
319 |
+
|
320 |
+
second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
|
321 |
+
second_labels = full_labels[second_idx, :pad_idx[second_idx]]
|
322 |
+
second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
|
323 |
+
second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
|
324 |
+
|
325 |
+
bos_idx = 0
|
326 |
+
assert torch.all(first_labels_padding == -100), first_labels_padding
|
327 |
+
assert torch.all(second_labels_padding == -100), second_labels_padding
|
328 |
+
assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
|
329 |
+
|
330 |
+
# Remove BOS token of second caption.
|
331 |
+
second_labels = second_labels[bos_idx + 1:]
|
332 |
+
second_emb = second_emb[bos_idx + 1:, :]
|
333 |
+
last_embedding_idx[second_idx] = last_embedding_idx[second_idx] - 1
|
334 |
+
|
335 |
+
concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, 768)
|
336 |
+
concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2, 768)
|
337 |
+
all_concat_input_embs.append(concat_input_embs)
|
338 |
+
all_concat_labels.append(concat_labels)
|
339 |
+
|
340 |
+
all_last_embedding_idx.append((last_embedding_idx[first_idx], first_emb.shape[0] + last_embedding_idx[second_idx]))
|
341 |
+
|
342 |
+
if mode == 'retrieval':
|
343 |
+
assert concat_labels[all_last_embedding_idx[-1][0]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][0])
|
344 |
+
assert concat_labels[all_last_embedding_idx[-1][1]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][1])
|
345 |
+
elif mode == 'generation':
|
346 |
+
# Check that the last n tokens are GEN tokens.
|
347 |
+
for gen_i in range(len(self.gen_token_idx)):
|
348 |
+
assert concat_labels[all_last_embedding_idx[-1][0]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][0]-gen_i, self.gen_token_idx[-gen_i-1])
|
349 |
+
assert concat_labels[all_last_embedding_idx[-1][1]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][1]-gen_i, self.gen_token_idx[-gen_i-1])
|
350 |
+
|
351 |
+
# Pad to max length.
|
352 |
+
input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, 768)
|
353 |
+
full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2, 768)
|
354 |
+
assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
|
355 |
+
assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape
|
356 |
+
|
357 |
+
# Update labels to pad non-first tokens.
|
358 |
+
for label in full_labels:
|
359 |
+
for k, token in enumerate(label):
|
360 |
+
if (token == self.tokenizer.pad_token_id) or (token in (self.retrieval_token_idx[1:] + self.gen_token_idx[1:])):
|
361 |
+
label[k:] = -100
|
362 |
+
break
|
363 |
+
output = self.lm(inputs_embeds=input_embs,
|
364 |
+
labels=full_labels,
|
365 |
+
output_hidden_states=True)
|
366 |
+
else:
|
367 |
+
raise NotImplementedError
|
368 |
+
|
369 |
+
last_embedding = None
|
370 |
+
last_output_logit = None
|
371 |
+
hidden_states = []
|
372 |
+
llm_hidden_states = []
|
373 |
+
|
374 |
+
if mode in ['retrieval', 'generation']:
|
375 |
+
num_tokens = self.num_tokens
|
376 |
+
if mode == 'retrieval':
|
377 |
+
text_hidden_fcs = self.ret_text_hidden_fcs
|
378 |
+
else:
|
379 |
+
text_hidden_fcs = self.gen_text_hidden_fcs
|
380 |
+
|
381 |
+
# Concatenate captions for retrieval / generation, if specified.
|
382 |
+
if not concat_captions:
|
383 |
+
for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
|
384 |
+
input_hidden_state = torch.stack([output.hidden_states[idx][i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
|
385 |
+
input_embedding = torch.stack([input_embs[i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
|
386 |
+
llm_hidden_states.append(input_hidden_state)
|
387 |
+
hidden_states.append(fc_layer(input_hidden_state, input_embedding)) # (N, seq_len, 2048)
|
388 |
+
else:
|
389 |
+
for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
|
390 |
+
all_last_embedding = []
|
391 |
+
all_input_embedding = []
|
392 |
+
all_last_output_logit = []
|
393 |
+
for i in range(batch_size // 2):
|
394 |
+
first_last_embedding_idx, second_last_embedding_idx = all_last_embedding_idx[i]
|
395 |
+
first_last_embedding = output.hidden_states[idx][i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :] # (N, D)
|
396 |
+
second_last_embedding = output.hidden_states[idx][i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :] # (N, D)
|
397 |
+
all_last_embedding.append(first_last_embedding)
|
398 |
+
all_last_embedding.append(second_last_embedding)
|
399 |
+
|
400 |
+
first_input_embs = input_embs[i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :] # (N, D)
|
401 |
+
second_input_embs = input_embs[i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :] # (N, D)
|
402 |
+
all_input_embedding.append(first_input_embs)
|
403 |
+
all_input_embedding.append(second_input_embs)
|
404 |
+
|
405 |
+
first_last_output_logit = output.logits[i, first_last_embedding_idx - 1, :] # (N, D)
|
406 |
+
second_last_output_logit = output.logits[i, second_last_embedding_idx - 1, :] # (N, D)
|
407 |
+
all_last_output_logit.append(first_last_output_logit)
|
408 |
+
all_last_output_logit.append(second_last_output_logit)
|
409 |
+
|
410 |
+
last_embedding = torch.stack(all_last_embedding, axis=0)
|
411 |
+
input_embedding = torch.stack(all_input_embedding, axis=0)
|
412 |
+
last_output_logit = torch.stack(all_last_output_logit, axis=0)
|
413 |
+
llm_hidden_states.append(last_embedding)
|
414 |
+
hidden_states.append(fc_layer(last_embedding, input_embedding)) # (N, seq_len, 2048)
|
415 |
+
|
416 |
+
if not concat_captions:
|
417 |
+
# Add hidden states together.
|
418 |
+
last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1) #torch.stack([last_hidden_state[i, :, :] for i in range(batch_size)], axis=0) # (N, T, D)
|
419 |
+
last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0) # (N, D)
|
420 |
+
else:
|
421 |
+
# Add hidden states together.
|
422 |
+
last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1)
|
423 |
+
|
424 |
+
# Compute retrieval loss.
|
425 |
+
if mode == 'retrieval':
|
426 |
+
assert visual_embs.shape[1] == 1, visual_embs.shape
|
427 |
+
assert last_embedding.shape[1] == 1, last_embedding.shape
|
428 |
+
visual_embs = visual_embs[:, 0, :]
|
429 |
+
visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
|
430 |
+
last_embedding = last_embedding[:, 0, :]
|
431 |
+
last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)
|
432 |
+
|
433 |
+
# cosine similarity as logits
|
434 |
+
logit_scale = self.logit_scale.exp()
|
435 |
+
visual_embs = logit_scale * visual_embs
|
436 |
+
elif mode == 'captioning':
|
437 |
+
pass
|
438 |
+
else:
|
439 |
+
raise NotImplementedError
|
440 |
+
|
441 |
+
return output, full_labels, last_embedding, last_output_logit, visual_embs, visual_embs_norm, input_embs_norm, llm_hidden_states
|
442 |
+
|
443 |
+
def generate(self, embeddings = torch.FloatTensor, max_len: int = 32,
|
444 |
+
temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
|
445 |
+
ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
|
446 |
+
filter_value: float = -float('Inf')):
|
447 |
+
"""Runs greedy decoding and returns generated captions.
|
448 |
+
|
449 |
+
Args:
|
450 |
+
min_word_tokens: Minimum number of words to generate before allowing a [IMG] output.
|
451 |
+
filter_value: Value to assign to tokens that should never be generated.
|
452 |
+
Outputs:
|
453 |
+
out: (N, T) int32 sequence of output tokens.
|
454 |
+
output_embeddings: (N, T, 256) sequence of text output embeddings.
|
455 |
+
"""
|
456 |
+
self.lm.eval()
|
457 |
+
|
458 |
+
with torch.no_grad(): # no tracking history
|
459 |
+
# init output with image tokens
|
460 |
+
out = None
|
461 |
+
output_embeddings = []
|
462 |
+
output_logits = []
|
463 |
+
|
464 |
+
for i in range(max_len):
|
465 |
+
output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
|
466 |
+
|
467 |
+
for idx in self.args.text_emb_layers:
|
468 |
+
output_embeddings.append(output.hidden_states[idx])
|
469 |
+
|
470 |
+
logits = output.logits[:, -1, :] # (N, vocab_size)
|
471 |
+
if top_p == 1.0:
|
472 |
+
logits = logits.cpu()
|
473 |
+
output_logits.append(logits)
|
474 |
+
|
475 |
+
# Prevent the model from generating the [IMG1..n] tokens.
|
476 |
+
logits[:, self.retrieval_token_idx[1:]] = filter_value
|
477 |
+
logits[:, self.gen_token_idx[1:]] = filter_value
|
478 |
+
|
479 |
+
if (self.retrieval_token_idx or self.gen_token_idx) and self.retrieval_token_idx[0] != -1 and self.gen_token_idx[0] != -1:
|
480 |
+
if i < min_word_tokens:
|
481 |
+
# Eliminate probability of generating [IMG] if this is earlier than min_word_tokens.
|
482 |
+
logits[:, self.retrieval_token_idx] = filter_value
|
483 |
+
logits[:, self.gen_token_idx] = filter_value
|
484 |
+
else:
|
485 |
+
# Multiply by scaling factor.
|
486 |
+
if ret_scale_factor > 1:
|
487 |
+
logits[:, self.retrieval_token_idx[0]] = logits[:, self.retrieval_token_idx[0]].abs() * ret_scale_factor
|
488 |
+
if gen_scale_factor > 1:
|
489 |
+
logits[:, self.gen_token_idx[0]] = logits[:, self.gen_token_idx[0]].abs() * gen_scale_factor
|
490 |
+
|
491 |
+
if temperature == 0.0:
|
492 |
+
if top_p != 1.0:
|
493 |
+
raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
|
494 |
+
next_token = torch.argmax(logits, keepdim=True, dim=-1) # (N, 1)
|
495 |
+
else:
|
496 |
+
logits = logits / temperature
|
497 |
+
|
498 |
+
# Apply top-p filtering.
|
499 |
+
if top_p < 1.0:
|
500 |
+
assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
|
501 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (N, D) and (N, D)
|
502 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (N, D)
|
503 |
+
|
504 |
+
# Remove tokens with cumulative probability above the threshold
|
505 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
506 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
507 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
508 |
+
sorted_indices_to_remove[..., 0] = 0
|
509 |
+
|
510 |
+
for j in range(sorted_indices.shape[0]):
|
511 |
+
indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
|
512 |
+
logits[j, indices_to_remove] = filter_value
|
513 |
+
|
514 |
+
token_weights = logits.exp() # (N, vocab_size)
|
515 |
+
next_token = torch.multinomial(token_weights, 1) # (N, 1)
|
516 |
+
|
517 |
+
# Force generation of the remaining [IMG] tokens if [IMG0] is generated.
|
518 |
+
if next_token.shape[0] == 1 and next_token.item() == self.retrieval_token_idx[0]:
|
519 |
+
assert self.retrieval_token_idx == self.gen_token_idx, (self.retrieval_token_idx, self.gen_token_idx)
|
520 |
+
next_token = torch.tensor(self.retrieval_token_idx)[None, :].long().to(embeddings.device) # (1, num_tokens)
|
521 |
+
else:
|
522 |
+
next_token = next_token.long().to(embeddings.device)
|
523 |
+
|
524 |
+
if out is not None:
|
525 |
+
out = torch.cat([out, next_token], dim=-1)
|
526 |
+
else:
|
527 |
+
out = next_token
|
528 |
+
|
529 |
+
next_embedding = self.input_embeddings(next_token)
|
530 |
+
embeddings = torch.cat([embeddings, next_embedding], dim=1)
|
531 |
+
|
532 |
+
return out, output_embeddings, output_logits
|
533 |
+
|
534 |
+
|
535 |
+
class GILL(nn.Module):
|
536 |
+
def __init__(self, tokenizer, model_args: Optional[GILLArgs] = None,
|
537 |
+
path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.tensor] = None,
|
538 |
+
load_sd: bool = False, num_gen_images: int = 1, decision_model_path: Optional[str] = None):
|
539 |
+
super().__init__()
|
540 |
+
self.model = GILLModel(tokenizer, model_args)
|
541 |
+
self.path_array = path_array
|
542 |
+
self.emb_matrix = emb_matrix
|
543 |
+
self.load_sd = load_sd
|
544 |
+
self.num_gen_images = num_gen_images
|
545 |
+
self.idx2dec = {0: 'gen', 1: 'ret', 2: 'same'}
|
546 |
+
self.decision_model = None
|
547 |
+
|
548 |
+
# Load the Stable Diffusion model.
|
549 |
+
if load_sd:
|
550 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
551 |
+
if torch.cuda.is_available():
|
552 |
+
self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
553 |
+
else:
|
554 |
+
self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
555 |
+
|
556 |
+
if decision_model_path is not None:
|
557 |
+
print('Loading decision model...')
|
558 |
+
self.decision_model = nn.Sequential(*[
|
559 |
+
nn.Dropout(0.5),
|
560 |
+
nn.Linear(4097, 2),
|
561 |
+
])
|
562 |
+
|
563 |
+
if torch.cuda.is_available():
|
564 |
+
mlp_checkpoint = torch.load(decision_model_path)
|
565 |
+
else:
|
566 |
+
mlp_checkpoint = torch.load(decision_model_path, map_location=torch.device('cpu'))
|
567 |
+
|
568 |
+
self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=True)
|
569 |
+
self.decision_model.eval()
|
570 |
+
|
571 |
+
def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
|
572 |
+
generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
|
573 |
+
ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
|
574 |
+
min_word_tokens: int = 0, mode: str = 'captioning', concat_captions: bool = False,
|
575 |
+
input_prefix: Optional[str] = None) -> Tensor:
|
576 |
+
if generate:
|
577 |
+
return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
|
578 |
+
min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor,
|
579 |
+
gen_scale_factor=gen_scale_factor)
|
580 |
+
else:
|
581 |
+
output = self.model(
|
582 |
+
pixel_values = images,
|
583 |
+
labels = tgt_tokens,
|
584 |
+
caption_len = caption_len,
|
585 |
+
mode = mode,
|
586 |
+
concat_captions = concat_captions,
|
587 |
+
input_prefix = input_prefix)
|
588 |
+
return output
|
589 |
+
|
590 |
+
def generate_for_images_and_texts(
|
591 |
+
self, prompts: List, num_words: int = 0, min_word_tokens: int = 0, ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
|
592 |
+
top_p: float = 1.0, temperature: float = 0.0, max_num_rets: int = 1, generator=None,
|
593 |
+
always_add_bos : bool = False, guidance_scale: float = 7.5, num_inference_steps: int = 50):
|
594 |
+
"""
|
595 |
+
Encode prompts into embeddings, and generates text and image outputs accordingly.
|
596 |
+
|
597 |
+
Args:
|
598 |
+
prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
|
599 |
+
num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs.
|
600 |
+
min_word_tokens: Minimum number of actual words before generating an image.
|
601 |
+
ret_scale_factor: Proportion to scale [IMG] token logits by. A higher value may increase the probability of the model generating [IMG] outputs.
|
602 |
+
top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
|
603 |
+
temperature: Used to modulate logit distribution.
|
604 |
+
max_num_rets: Maximum number of images to return in one generation pass.
|
605 |
+
Returns:
|
606 |
+
return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs.
|
607 |
+
"""
|
608 |
+
input_embs = []
|
609 |
+
input_ids = []
|
610 |
+
add_bos = True
|
611 |
+
|
612 |
+
with torch.no_grad():
|
613 |
+
for p in prompts:
|
614 |
+
if type(p) == Image.Image:
|
615 |
+
# Encode as image.
|
616 |
+
pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
|
617 |
+
pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
|
618 |
+
pixel_values = pixel_values[None, ...]
|
619 |
+
|
620 |
+
visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
|
621 |
+
input_embs.append(visual_embs)
|
622 |
+
elif type(p) == str:
|
623 |
+
text_ids = self.model.tokenizer(p, add_special_tokens=add_bos, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
|
624 |
+
# Only add <bos> once unless the flag is set.
|
625 |
+
if not always_add_bos:
|
626 |
+
add_bos = False
|
627 |
+
|
628 |
+
text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
|
629 |
+
input_embs.append(text_embs)
|
630 |
+
input_ids.append(text_ids)
|
631 |
+
else:
|
632 |
+
raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
|
633 |
+
input_embs = torch.cat(input_embs, dim=1)
|
634 |
+
input_ids = torch.cat(input_ids, dim=1)
|
635 |
+
|
636 |
+
if num_words == 0:
|
637 |
+
raise NotImplementedError('Generation not implemented for num_words=0.')
|
638 |
+
elif num_words > 0:
|
639 |
+
generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words, min_word_tokens=min_word_tokens,
|
640 |
+
temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor, gen_scale_factor=gen_scale_factor)
|
641 |
+
embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
|
642 |
+
|
643 |
+
# Truncate to newline.
|
644 |
+
newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
|
645 |
+
trunc_idx = 0
|
646 |
+
for j in range(generated_ids.shape[1]):
|
647 |
+
if generated_ids[0, j] == newline_token_id:
|
648 |
+
trunc_idx = j
|
649 |
+
break
|
650 |
+
if trunc_idx > 0:
|
651 |
+
generated_ids = generated_ids[:, :trunc_idx]
|
652 |
+
embeddings = embeddings[:, :trunc_idx]
|
653 |
+
else:
|
654 |
+
raise ValueError
|
655 |
+
|
656 |
+
# Save outputs as an interleaved list.
|
657 |
+
return_outputs = []
|
658 |
+
# Find up to max_num_rets [IMG] tokens, and their corresponding scores.
|
659 |
+
all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx[0]) if x][:max_num_rets]
|
660 |
+
seen_image_idx = [] # Avoid showing the same image multiple times.
|
661 |
+
|
662 |
+
last_ret_idx = 0
|
663 |
+
if len(all_ret_idx) == 0:
|
664 |
+
# No [IMG] tokens.
|
665 |
+
caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
666 |
+
return_outputs.append(utils.truncate_caption(caption))
|
667 |
+
else:
|
668 |
+
for ret_idx in all_ret_idx:
|
669 |
+
assert generated_ids[0, ret_idx:ret_idx+self.model.num_tokens].cpu().detach().numpy().tolist() == self.model.retrieval_token_idx, (generated_ids[0, ret_idx:ret_idx+self.model.num_tokens], self.model.retrieval_token_idx)
|
670 |
+
raw_emb = embeddings[:, ret_idx:ret_idx+self.model.num_tokens, :] # (1, 8, 4096)
|
671 |
+
assert len(self.model.args.text_emb_layers) == 1
|
672 |
+
|
673 |
+
image_outputs = {
|
674 |
+
'gen': [],
|
675 |
+
'ret': [],
|
676 |
+
'decision': None,
|
677 |
+
}
|
678 |
+
|
679 |
+
if self.emb_matrix is not None:
|
680 |
+
# Produce retrieval embedding.
|
681 |
+
ret_emb = self.model.ret_text_hidden_fcs[0](raw_emb, None)[:, 0, :] # (1, 256)
|
682 |
+
ret_emb = ret_emb / ret_emb.norm(dim=-1, keepdim=True)
|
683 |
+
ret_emb = ret_emb.type(self.emb_matrix.dtype) # (1, 256)
|
684 |
+
scores = self.emb_matrix @ ret_emb.T
|
685 |
+
|
686 |
+
# Downweight seen images.
|
687 |
+
for seen_idx in seen_image_idx:
|
688 |
+
scores[seen_idx, :] -= 1000
|
689 |
+
|
690 |
+
# Get the top 3 images for each image.
|
691 |
+
_, top_image_idx = scores.squeeze().topk(3)
|
692 |
+
for img_idx in top_image_idx:
|
693 |
+
# Find the first image that does not error out.
|
694 |
+
try:
|
695 |
+
seen_image_idx.append(img_idx)
|
696 |
+
img = utils.get_image_from_url(self.path_array[img_idx])
|
697 |
+
image_outputs['ret'].append((img, 'ret', scores[img_idx].item()))
|
698 |
+
if len(image_outputs) == max_num_rets:
|
699 |
+
break
|
700 |
+
except (UnidentifiedImageError, ConnectionError):
|
701 |
+
pass
|
702 |
+
|
703 |
+
# Make decision with MLP.
|
704 |
+
if self.decision_model is not None:
|
705 |
+
decision_emb = raw_emb[:, 0, :] # (1, 4096)
|
706 |
+
assert decision_emb.shape[1] == 4096, decision_emb.shape
|
707 |
+
max_ret_score = scores.max().reshape((1, 1)).clone().detach().to(device=decision_emb.device, dtype=decision_emb.dtype)
|
708 |
+
decision_logits = self.decision_model(torch.cat([decision_emb, max_ret_score], dim=-1))
|
709 |
+
probs = decision_logits.softmax(dim=-1).cpu().float().numpy().tolist()
|
710 |
+
image_outputs['decision'] = [self.idx2dec[decision_logits.argmax().item()]] + probs
|
711 |
+
else:
|
712 |
+
# If no embedding matrix is provided, generate instead.
|
713 |
+
image_outputs['decision'] = ['gen', [0, 1]]
|
714 |
+
|
715 |
+
# Produce generation embedding.
|
716 |
+
gen_prefix = ' '.join([f'[IMG{i}]' for i in range(self.model.args.num_tokens)])
|
717 |
+
gen_prefx_ids = self.model.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
|
718 |
+
gen_prefix_embs = self.model.input_embeddings(gen_prefx_ids) # (1, T, D)
|
719 |
+
gen_emb = self.model.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs) # (1, 77, 768)
|
720 |
+
|
721 |
+
if gen_emb.shape[1] != 77:
|
722 |
+
print(f"Padding {gen_emb.shape} with zeros")
|
723 |
+
bs = gen_emb.shape[0]
|
724 |
+
clip_emb = 768
|
725 |
+
gen_emb = gen_emb.reshape(bs, -1, clip_emb) # (bs, T, 768)
|
726 |
+
seq_len = gen_emb.shape[1]
|
727 |
+
gen_emb = torch.cat([gen_emb, torch.zeros((bs, 77 - seq_len, clip_emb), device=gen_emb.device, dtype=gen_emb.dtype)], dim=1)
|
728 |
+
print('Padded to', gen_emb.shape)
|
729 |
+
|
730 |
+
gen_emb = gen_emb.repeat(self.num_gen_images, 1, 1) # (self.num_gen_images, 77, 768)
|
731 |
+
|
732 |
+
# OPTIM(jykoh): Only generate if scores are low.
|
733 |
+
if self.load_sd:
|
734 |
+
# If num_gen_images > 8, split into multiple batches (for GPU memory reasons).
|
735 |
+
gen_max_bs = 8
|
736 |
+
gen_images = []
|
737 |
+
for i in range(0, self.num_gen_images, gen_max_bs):
|
738 |
+
gen_images.extend(
|
739 |
+
self.sd_pipe(prompt_embeds=gen_emb[i:i+gen_max_bs], generator=generator,
|
740 |
+
guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images)
|
741 |
+
|
742 |
+
all_gen_pixels = []
|
743 |
+
for img in gen_images:
|
744 |
+
pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, img.resize((224, 224)).convert('RGB'))
|
745 |
+
pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
|
746 |
+
all_gen_pixels.append(pixel_values)
|
747 |
+
|
748 |
+
if self.emb_matrix is not None:
|
749 |
+
all_gen_pixels = torch.stack(all_gen_pixels, dim=0)
|
750 |
+
gen_visual_embs = self.model.get_visual_embs(all_gen_pixels, mode='retrieval') # (1, D)
|
751 |
+
gen_visual_embs = gen_visual_embs / gen_visual_embs.norm(dim=-1, keepdim=True)
|
752 |
+
gen_visual_embs = gen_visual_embs.type(self.emb_matrix.dtype)
|
753 |
+
gen_rank_scores = (gen_visual_embs @ ret_emb.T).squeeze()
|
754 |
+
sorted_score_idx = torch.argsort(-gen_rank_scores)
|
755 |
+
|
756 |
+
# Rank images by retrieval score.
|
757 |
+
if self.num_gen_images > 1:
|
758 |
+
image_outputs['gen'] = [(gen_images[idx], gen_rank_scores[idx].item()) for idx in sorted_score_idx]
|
759 |
+
else:
|
760 |
+
image_outputs['gen'] = [(gen_images[0], gen_rank_scores.item())]
|
761 |
+
else:
|
762 |
+
image_outputs['gen'] = [(gen_images[0], 0)]
|
763 |
+
else:
|
764 |
+
image_outputs['gen'] = [gen_emb]
|
765 |
+
|
766 |
+
caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
|
767 |
+
last_ret_idx = ret_idx + 1
|
768 |
+
return_outputs.append(utils.truncate_caption(caption) + f' {gen_prefix}')
|
769 |
+
return_outputs.append(image_outputs)
|
770 |
+
|
771 |
+
return return_outputs
|
772 |
+
|
773 |
+
def get_log_likelihood_scores(
|
774 |
+
self, prompts: List):
|
775 |
+
"""
|
776 |
+
Output the log likelihood of the given interleaved prompts.
|
777 |
+
|
778 |
+
Args:
|
779 |
+
prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
|
780 |
+
Returns:
|
781 |
+
Log likelihood score of the prompt sequence.
|
782 |
+
"""
|
783 |
+
input_embs = []
|
784 |
+
input_ids = []
|
785 |
+
add_bos = True
|
786 |
+
|
787 |
+
for p in prompts:
|
788 |
+
if type(p) == Image.Image:
|
789 |
+
# Encode as image.
|
790 |
+
pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
|
791 |
+
pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
|
792 |
+
pixel_values = pixel_values[None, ...]
|
793 |
+
|
794 |
+
visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
|
795 |
+
input_embs.append(visual_embs)
|
796 |
+
id_ = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
|
797 |
+
input_ids.append(id_)
|
798 |
+
elif type(p) == str:
|
799 |
+
text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
|
800 |
+
if not add_bos:
|
801 |
+
# Remove <bos> tag.
|
802 |
+
text_ids = text_ids[:, 1:]
|
803 |
+
else:
|
804 |
+
# Only add <bos> once.
|
805 |
+
add_bos = False
|
806 |
+
|
807 |
+
text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
|
808 |
+
input_embs.append(text_embs)
|
809 |
+
input_ids.append(text_ids)
|
810 |
+
else:
|
811 |
+
raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
|
812 |
+
input_embs = torch.cat(input_embs, dim=1)
|
813 |
+
input_ids = torch.cat(input_ids, dim=1)
|
814 |
+
|
815 |
+
outputs = self.model.lm(inputs_embeds=input_embs, labels=input_ids, use_cache=False, output_hidden_states=True)
|
816 |
+
return -outputs.loss.item()
|
817 |
+
|
818 |
+
|
819 |
+
def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, decision_model_path: str) -> GILL:
|
820 |
+
embs_paths = [s for s in glob.glob(os.path.join(embeddings_dir, 'cc3m*.npy'))]
|
821 |
+
|
822 |
+
if not os.path.exists(model_args_path):
|
823 |
+
raise ValueError(f'model_args.json does not exist at {model_args_path}.')
|
824 |
+
if not os.path.exists(model_ckpt_path):
|
825 |
+
raise ValueError(f'pretrained_ckpt.pth.tar does not exist at {model_ckpt_path}.')
|
826 |
+
if len(embs_paths) == 0:
|
827 |
+
print(f'cc3m*.npy files do not exist in {embeddings_dir}. Running the model without retrieval.')
|
828 |
+
path_array, emb_matrix = None, None
|
829 |
+
else:
|
830 |
+
# Load embeddings.
|
831 |
+
# Construct embedding matrix for nearest neighbor lookup.
|
832 |
+
path_array = []
|
833 |
+
emb_matrix = []
|
834 |
+
|
835 |
+
# These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
|
836 |
+
for p in embs_paths:
|
837 |
+
with open(p, 'rb') as wf:
|
838 |
+
train_embs_data = pkl.load(wf)
|
839 |
+
path_array.extend(train_embs_data['paths'])
|
840 |
+
emb_matrix.extend(train_embs_data['embeddings'])
|
841 |
+
emb_matrix = np.stack(emb_matrix, axis=0)
|
842 |
+
|
843 |
+
# Number of paths should be equal to number of embeddings.
|
844 |
+
assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape)
|
845 |
+
|
846 |
+
with open(model_args_path, 'r') as f:
|
847 |
+
model_kwargs = json.load(f)
|
848 |
+
|
849 |
+
# Initialize tokenizer.
|
850 |
+
tokenizer = AutoTokenizer.from_pretrained(model_kwargs['opt_version'], use_fast=False)
|
851 |
+
if tokenizer.pad_token is None:
|
852 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
853 |
+
# Add an image token for loss masking (and visualization) purposes.
|
854 |
+
tokenizer.add_special_tokens({"cls_token": "<|image|>"}) # add special image token to tokenizer
|
855 |
+
|
856 |
+
# Add [IMG] tokens to the vocabulary.
|
857 |
+
model_kwargs['retrieval_token_idx'] = []
|
858 |
+
for i in range(model_kwargs['num_tokens']):
|
859 |
+
print(f'Adding [IMG{i}] token to vocabulary.')
|
860 |
+
print(f'Before adding new token, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
|
861 |
+
num_added_tokens = tokenizer.add_tokens(f'[IMG{i}]')
|
862 |
+
print(f'After adding {num_added_tokens} new tokens, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
|
863 |
+
ret_token_idx = tokenizer(f'[IMG{i}]', add_special_tokens=False).input_ids
|
864 |
+
assert len(ret_token_idx) == 1, ret_token_idx
|
865 |
+
model_kwargs['retrieval_token_idx'].append(ret_token_idx[0])
|
866 |
+
# Use the same RET tokens for generation.
|
867 |
+
model_kwargs['gen_token_idx'] = model_kwargs['retrieval_token_idx']
|
868 |
+
|
869 |
+
debug = False
|
870 |
+
if debug:
|
871 |
+
model_kwargs['opt_version'] = 'facebook/opt-125m'
|
872 |
+
model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
|
873 |
+
decision_model_path = None
|
874 |
+
|
875 |
+
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
876 |
+
|
877 |
+
# Initialize model for inference.
|
878 |
+
model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
|
879 |
+
load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
|
880 |
+
model = model.eval()
|
881 |
+
if not debug:
|
882 |
+
model = model.bfloat16()
|
883 |
+
model = model.cuda()
|
884 |
+
|
885 |
+
# Load pretrained linear mappings and [IMG] embeddings.
|
886 |
+
checkpoint = torch.load(model_ckpt_path)
|
887 |
+
state_dict = {}
|
888 |
+
# This is needed if we train with DDP.
|
889 |
+
for k, v in checkpoint['state_dict'].items():
|
890 |
+
state_dict[k.replace('module.', '')] = v
|
891 |
+
img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()
|
892 |
+
del state_dict['model.input_embeddings.weight']
|
893 |
+
|
894 |
+
model.load_state_dict(state_dict, strict=False)
|
895 |
+
# Copy over the embeddings of the [IMG] tokens (while loading the others from the pretrained LLM).
|
896 |
+
with torch.no_grad():
|
897 |
+
if 'share_ret_gen' in model_kwargs:
|
898 |
+
assert model_kwargs['share_ret_gen'], 'Model loading only supports share_ret_gen=True for now.'
|
899 |
+
model.model.input_embeddings.weight[-model_kwargs['num_tokens']:, :].copy_(img_token_embeddings)
|
900 |
+
|
901 |
+
if len(embs_paths) > 0:
|
902 |
+
logit_scale = model.model.logit_scale.exp()
|
903 |
+
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
|
904 |
+
emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
|
905 |
+
emb_matrix = logit_scale * emb_matrix
|
906 |
+
model.emb_matrix = emb_matrix
|
907 |
+
|
908 |
+
return model
|
909 |
+
|
gill/utils.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import subprocess
|
3 |
+
import sys
|
4 |
+
import shutil
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
from torchvision.transforms import functional as F
|
8 |
+
from torchvision import transforms as T
|
9 |
+
from transformers import AutoFeatureExtractor
|
10 |
+
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
11 |
+
import random
|
12 |
+
import requests
|
13 |
+
from io import BytesIO
|
14 |
+
|
15 |
+
|
16 |
+
def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']):
|
17 |
+
"""Logs git status to stdout."""
|
18 |
+
subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file)
|
19 |
+
subprocess.call('echo', shell=True, stdout=out_file)
|
20 |
+
exclude_string = ''
|
21 |
+
subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file)
|
22 |
+
|
23 |
+
|
24 |
+
def get_image_from_url(url: str):
|
25 |
+
response = requests.get(url)
|
26 |
+
img = Image.open(BytesIO(response.content))
|
27 |
+
img = img.resize((224, 224))
|
28 |
+
img = img.convert('RGB')
|
29 |
+
return img
|
30 |
+
|
31 |
+
|
32 |
+
def truncate_caption(caption: str) -> str:
|
33 |
+
"""Truncate captions at periods and newlines."""
|
34 |
+
caption = caption.strip('\n')
|
35 |
+
trunc_index = caption.find('\n') + 1
|
36 |
+
if trunc_index <= 0:
|
37 |
+
trunc_index = caption.find('.') + 1
|
38 |
+
if trunc_index > 0:
|
39 |
+
caption = caption[:trunc_index]
|
40 |
+
return caption
|
41 |
+
|
42 |
+
|
43 |
+
def pad_to_size(x, size=256):
|
44 |
+
delta_w = size - x.size[0]
|
45 |
+
delta_h = size - x.size[1]
|
46 |
+
padding = (
|
47 |
+
delta_w // 2,
|
48 |
+
delta_h // 2,
|
49 |
+
delta_w - (delta_w // 2),
|
50 |
+
delta_h - (delta_h // 2),
|
51 |
+
)
|
52 |
+
new_im = ImageOps.expand(x, padding)
|
53 |
+
return new_im
|
54 |
+
|
55 |
+
|
56 |
+
class RandCropResize(object):
|
57 |
+
|
58 |
+
"""
|
59 |
+
Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, target_size):
|
63 |
+
self.target_size = target_size
|
64 |
+
|
65 |
+
def __call__(self, img):
|
66 |
+
img = pad_to_size(img, self.target_size)
|
67 |
+
d_min = min(img.size)
|
68 |
+
img = T.RandomCrop(size=d_min)(img)
|
69 |
+
t_min = min(d_min, round(9 / 8 * self.target_size))
|
70 |
+
t_max = min(d_min, round(12 / 8 * self.target_size))
|
71 |
+
t = random.randint(t_min, t_max + 1)
|
72 |
+
img = T.Resize(t)(img)
|
73 |
+
if min(img.size) < 256:
|
74 |
+
img = T.Resize(256)(img)
|
75 |
+
return T.RandomCrop(size=self.target_size)(img)
|
76 |
+
|
77 |
+
|
78 |
+
class SquarePad(object):
|
79 |
+
"""Pads image to square.
|
80 |
+
From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
|
81 |
+
"""
|
82 |
+
def __call__(self, image):
|
83 |
+
max_wh = max(image.size)
|
84 |
+
p_left, p_top = [(max_wh - s) // 2 for s in image.size]
|
85 |
+
p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
|
86 |
+
padding = (p_left, p_top, p_right, p_bottom)
|
87 |
+
return F.pad(image, padding, 0, 'constant')
|
88 |
+
|
89 |
+
|
90 |
+
def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
|
91 |
+
"""Creates a (3, nrows * 14, width) image of text.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
cap_img: (3, 14 * nrows, width) image of wrapped text.
|
95 |
+
"""
|
96 |
+
height = 12
|
97 |
+
padding = 5
|
98 |
+
effective_width = width - 2 * padding
|
99 |
+
# Create a black image to draw text on.
|
100 |
+
cap_img = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0))
|
101 |
+
draw = ImageDraw.Draw(cap_img)
|
102 |
+
draw.text((0, 0), text, color, font=font or ImageFont.load_default())
|
103 |
+
cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32) # (3, height, W * nrows)
|
104 |
+
cap_img = torch.split(cap_img, effective_width, dim=-1) # List of nrow elements of shape (3, height, W)
|
105 |
+
cap_img = torch.cat(cap_img, dim=1) # (3, height * nrows, W)
|
106 |
+
# Add zero padding.
|
107 |
+
cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding])
|
108 |
+
return cap_img
|
109 |
+
|
110 |
+
|
111 |
+
def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
|
112 |
+
print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
|
113 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
114 |
+
return feature_extractor
|
115 |
+
|
116 |
+
|
117 |
+
def get_pixel_values_for_model(feature_extractor, img: Image.Image):
|
118 |
+
pixel_values = feature_extractor(img.convert('RGB'), return_tensors="pt").pixel_values[0, ...] # (3, H, W)
|
119 |
+
return pixel_values
|
120 |
+
|
121 |
+
|
122 |
+
def save_checkpoint(state, is_best, filename='checkpoint'):
|
123 |
+
torch.save(state, filename + '.pth.tar')
|
124 |
+
if is_best:
|
125 |
+
shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar')
|
126 |
+
|
127 |
+
|
128 |
+
def accuracy(output, target, padding, topk=(1,)):
|
129 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
130 |
+
with torch.no_grad():
|
131 |
+
maxk = max(topk)
|
132 |
+
if output.shape[-1] < maxk:
|
133 |
+
print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")
|
134 |
+
|
135 |
+
maxk = min(maxk, output.shape[-1])
|
136 |
+
batch_size = target.size(0)
|
137 |
+
|
138 |
+
# Take topk along the last dimension.
|
139 |
+
_, pred = output.topk(maxk, -1, True, True) # (N, T, topk)
|
140 |
+
|
141 |
+
mask = (target != padding).type(target.dtype)
|
142 |
+
target_expand = target[..., None].expand_as(pred)
|
143 |
+
correct = pred.eq(target_expand)
|
144 |
+
correct = correct * mask[..., None].expand_as(correct)
|
145 |
+
|
146 |
+
res = []
|
147 |
+
for k in topk:
|
148 |
+
correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
|
149 |
+
res.append(correct_k.mul_(100.0 / mask.sum()))
|
150 |
+
return res
|
151 |
+
|
152 |
+
|
153 |
+
def get_params_count(model, max_name_len: int = 60):
|
154 |
+
params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()]
|
155 |
+
total_trainable_params = sum([x[1] for x in params if x[-1]])
|
156 |
+
total_nontrainable_params = sum([x[1] for x in params if not x[-1]])
|
157 |
+
return params, total_trainable_params, total_nontrainable_params
|
158 |
+
|
159 |
+
|
160 |
+
def get_params_count_str(model, max_name_len: int = 60):
|
161 |
+
padding = 70 # Hardcoded depending on desired amount of padding and separators.
|
162 |
+
params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len)
|
163 |
+
param_counts_text = ''
|
164 |
+
param_counts_text += '=' * (max_name_len + padding) + '\n'
|
165 |
+
param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n'
|
166 |
+
param_counts_text += '-' * (max_name_len + padding) + '\n'
|
167 |
+
for name, param_count, shape, trainable in params:
|
168 |
+
param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
|
169 |
+
param_counts_text += '-' * (max_name_len + padding) + '\n'
|
170 |
+
param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n'
|
171 |
+
param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n'
|
172 |
+
param_counts_text += '=' * (max_name_len + padding) + '\n'
|
173 |
+
return param_counts_text
|
174 |
+
|
175 |
+
|
176 |
+
class Summary(Enum):
|
177 |
+
NONE = 0
|
178 |
+
AVERAGE = 1
|
179 |
+
SUM = 2
|
180 |
+
COUNT = 3
|
181 |
+
|
182 |
+
|
183 |
+
class ProgressMeter(object):
|
184 |
+
def __init__(self, num_batches, meters, prefix=""):
|
185 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
186 |
+
self.meters = meters
|
187 |
+
self.prefix = prefix
|
188 |
+
|
189 |
+
def display(self, batch):
|
190 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
191 |
+
entries += [str(meter) for meter in self.meters]
|
192 |
+
print('\t'.join(entries))
|
193 |
+
|
194 |
+
def display_summary(self):
|
195 |
+
entries = [" *"]
|
196 |
+
entries += [meter.summary() for meter in self.meters]
|
197 |
+
print(' '.join(entries))
|
198 |
+
|
199 |
+
def _get_batch_fmtstr(self, num_batches):
|
200 |
+
num_digits = len(str(num_batches // 1))
|
201 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
202 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
203 |
+
|
204 |
+
|
205 |
+
class AverageMeter(object):
|
206 |
+
"""Computes and stores the average and current value"""
|
207 |
+
def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
|
208 |
+
self.name = name
|
209 |
+
self.fmt = fmt
|
210 |
+
self.summary_type = summary_type
|
211 |
+
self.reset()
|
212 |
+
|
213 |
+
def reset(self):
|
214 |
+
self.val = 0
|
215 |
+
self.avg = 0
|
216 |
+
self.sum = 0
|
217 |
+
self.count = 0
|
218 |
+
|
219 |
+
def update(self, val, n=1):
|
220 |
+
self.val = val
|
221 |
+
self.sum += val * n
|
222 |
+
self.count += n
|
223 |
+
self.avg = self.sum / self.count
|
224 |
+
|
225 |
+
def all_reduce(self):
|
226 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
227 |
+
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
|
228 |
+
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
|
229 |
+
self.sum, self.count = total.tolist()
|
230 |
+
self.avg = self.sum / self.count
|
231 |
+
|
232 |
+
def __str__(self):
|
233 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
234 |
+
return fmtstr.format(**self.__dict__)
|
235 |
+
|
236 |
+
def summary(self):
|
237 |
+
fmtstr = ''
|
238 |
+
if self.summary_type is Summary.NONE:
|
239 |
+
fmtstr = ''
|
240 |
+
elif self.summary_type is Summary.AVERAGE:
|
241 |
+
fmtstr = '{name} {avg:.3f}'
|
242 |
+
elif self.summary_type is Summary.SUM:
|
243 |
+
fmtstr = '{name} {sum:.3f}'
|
244 |
+
elif self.summary_type is Summary.COUNT:
|
245 |
+
fmtstr = '{name} {count:.3f}'
|
246 |
+
else:
|
247 |
+
raise ValueError('invalid summary type %r' % self.summary_type)
|
248 |
+
|
249 |
+
return fmtstr.format(**self.__dict__)
|
requirements.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
attrs==22.2.0
|
2 |
+
certifi==2022.12.7
|
3 |
+
charset-normalizer==3.0.1
|
4 |
+
contourpy==1.0.7
|
5 |
+
cycler==0.11.0
|
6 |
+
einops==0.4.1
|
7 |
+
exceptiongroup==1.1.0
|
8 |
+
filelock==3.9.0
|
9 |
+
fonttools==4.38.0
|
10 |
+
huggingface-hub==0.12.0
|
11 |
+
idna==3.4
|
12 |
+
iniconfig==2.0.0
|
13 |
+
kiwisolver==1.4.4
|
14 |
+
matplotlib==3.6.3
|
15 |
+
numpy==1.24.2
|
16 |
+
packaging==23.0
|
17 |
+
Pillow==9.4.0
|
18 |
+
pluggy==1.0.0
|
19 |
+
pyparsing==3.0.9
|
20 |
+
pytest==7.2.1
|
21 |
+
python-dateutil==2.8.2
|
22 |
+
PyYAML==6.0
|
23 |
+
regex==2022.10.31
|
24 |
+
requests==2.28.2
|
25 |
+
six==1.16.0
|
26 |
+
tokenizers==0.12.1
|
27 |
+
tomli==2.0.1
|
28 |
+
torch==1.11.0
|
29 |
+
torchaudio==0.11.0
|
30 |
+
torchmetrics==0.9.3
|
31 |
+
torchvision==0.12.0
|
32 |
+
tqdm==4.64.1
|
33 |
+
transformers==4.21.3
|
34 |
+
typing_extensions==4.4.0
|
35 |
+
urllib3==1.26.14
|
36 |
+
warmup-scheduler==0.3.0
|
share_btn.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/79681cd8cb235160a27cdd100673346eb1784e53/share_btn.py
|
2 |
+
|
3 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
4 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
5 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
6 |
+
</svg>"""
|
7 |
+
|
8 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
|
9 |
+
style="color: #ffffff;
|
10 |
+
"
|
11 |
+
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
12 |
+
|
13 |
+
share_js = """
|
14 |
+
async () => {
|
15 |
+
const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
|
16 |
+
async function uploadFile(file) {
|
17 |
+
console.log(file.type)
|
18 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
19 |
+
const response = await fetch(UPLOAD_URL, {
|
20 |
+
method: 'POST',
|
21 |
+
headers: {
|
22 |
+
'Content-Type': file.type,
|
23 |
+
'X-Requested-With': 'XMLHttpRequest',
|
24 |
+
},
|
25 |
+
body: file, /// <- File inherits from Blob
|
26 |
+
});
|
27 |
+
const url = await response.text();
|
28 |
+
return url;
|
29 |
+
}
|
30 |
+
async function getImageFile(div) {
|
31 |
+
return new Promise((resolve, reject) =>
|
32 |
+
html2canvas(div)
|
33 |
+
.then((canvas) => {
|
34 |
+
const imageBlob = canvas.toBlob((blob) => {
|
35 |
+
const imageId = Date.now();
|
36 |
+
const fileName = "FROMAGe-" + imageId + ".jpg";
|
37 |
+
resolve(new File([blob], fileName, { type: 'image/jpeg' }));
|
38 |
+
}, 'image/jpeg', 0.95);
|
39 |
+
})
|
40 |
+
|
41 |
+
)
|
42 |
+
}
|
43 |
+
|
44 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
45 |
+
const chatbotEl = gradioEl.querySelector('#chatbot')
|
46 |
+
const imageFile = await getImageFile(chatbotEl);
|
47 |
+
console.log(imageFile);
|
48 |
+
const urlChatbotImage = await uploadFile(imageFile);
|
49 |
+
console.log(urlChatbotImage);
|
50 |
+
let titleTxt = `FROMAGe Example`;
|
51 |
+
|
52 |
+
//const shareBtnEl = gradioEl.querySelector('#share-btn');
|
53 |
+
//shareBtnEl.style.pointerEvents = 'none';
|
54 |
+
const descriptionMd = `
|
55 |
+
|
56 |
+
<img src='${urlChatbotImage}'>
|
57 |
+
`;
|
58 |
+
const params = new URLSearchParams({
|
59 |
+
title: titleTxt,
|
60 |
+
description: descriptionMd,
|
61 |
+
});
|
62 |
+
const paramsStr = params.toString();
|
63 |
+
window.open(`https://huggingface.co/spaces/jykoh/fromage/discussions/new?${paramsStr}`, '_blank');
|
64 |
+
//shareBtnEl.style.removeProperty('pointer-events');
|
65 |
+
}
|
66 |
+
"""
|
67 |
+
|
68 |
+
save_js = """
|
69 |
+
async () => {
|
70 |
+
const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
|
71 |
+
|
72 |
+
function saveAs(uri, filename) {
|
73 |
+
var link = document.createElement('a');
|
74 |
+
if (typeof link.download === 'string') {
|
75 |
+
link.href = uri;
|
76 |
+
link.download = filename;
|
77 |
+
|
78 |
+
//Firefox requires the link to be in the body
|
79 |
+
document.body.appendChild(link);
|
80 |
+
|
81 |
+
//simulate click
|
82 |
+
link.click();
|
83 |
+
|
84 |
+
//remove the link when done
|
85 |
+
document.body.removeChild(link);
|
86 |
+
} else {
|
87 |
+
window.open(uri);
|
88 |
+
}
|
89 |
+
}
|
90 |
+
|
91 |
+
async function getImageFile(div) {
|
92 |
+
return new Promise((resolve, reject) =>
|
93 |
+
html2canvas(div)
|
94 |
+
.then((canvas) => {
|
95 |
+
const imageId = Date.now();
|
96 |
+
const fileName = "FROMAGe-" + imageId + ".png";
|
97 |
+
saveAs(canvas.toDataURL(), fileName);
|
98 |
+
})
|
99 |
+
|
100 |
+
)
|
101 |
+
}
|
102 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
103 |
+
const chatbotEl = gradioEl.querySelector('#chatbot')
|
104 |
+
const imageFile = await getImageFile(chatbotEl);
|
105 |
+
console.log(imageFile);
|
106 |
+
}
|
107 |
+
"""
|