Update app.py
app.py CHANGED
@@ -15,7 +15,7 @@ from transformers import (
     CLIPTextModelWithProjection,
     CLIPVisionModelWithProjection,
     CLIPImageProcessor,
-    CLIPTokenizer
+    CLIPTokenizer,
 )
 
 from transformers import CLIPTokenizer
@@ -33,10 +33,11 @@ if torch.cuda.is_available():
     __device__ = "cuda"
     __dtype__ = torch.float16
 
+
 class Model:
     def __init__(self):
         self.device = __device__
-
+
         self.text_encoder = (
             CLIPTextModelWithProjection.from_pretrained(
                 "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
@@ -65,102 +66,48 @@ class Model:
         self.pipe = DiffusionPipeline.from_pretrained(
             "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=__dtype__
         ).to(self.device)
-
-    def inference(self, raw_data):
+
+    def inference(self, raw_data, seed):
+        generator = torch.Generator(device="cuda").manual_seed(seed)
         image_emb, negative_image_emb = self.pipe_prior(
             raw_data=raw_data,
+            generator=generator,
         ).to_tuple()
         image = self.pipe(
             image_embeds=image_emb,
             negative_image_embeds=negative_image_emb,
             num_inference_steps=50,
-            guidance_scale=
+            guidance_scale=7.5,
+            generator=generator,
         ).images[0]
         return image
-
-    def
-
-
-
-
-
-
-
-
-        data: dict[str, Any] = {}
-        data['text'] = text
-
-        txt = self.tokenizer(
-            text,
-            padding='max_length',
-            truncation=True,
-            return_tensors='pt',
-        )
-        txt_items = {k: v.to(device) for k, v in txt.items()}
-        new_feats = self.text_encoder(**txt_items)
-        new_last_hidden_states = new_feats.last_hidden_state[0].cpu().numpy()
-
-        plt.imshow(image)
-        plt.title('image')
-        plt.savefig('image_testt2.png')
-        plt.show()
-
-        mask_img = self.image_processor(image, return_tensors="pt").to(__device__)
-        vision_feats = self.vision_encoder(
-            **mask_img
-        ).image_embeds
-
-        entity_tokens = self.tokenizer(keyword)["input_ids"][1:-1]
-        for tid in entity_tokens:
-            indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
-            new_last_hidden_states[indices] = vision_feats[0].cpu().numpy()
-            print(indices)
-
-        if image2 is not None:
-            mask_img2 = self.image_processor(image2, return_tensors="pt").to(__device__)
-            vision_feats2 = self.vision_encoder(
-                **mask_img2
-            ).image_embeds
-            if keyword2 is not None:
-                entity_tokens = self.tokenizer(keyword2)["input_ids"][1:-1]
-                for tid in entity_tokens:
-                    indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
-                    new_last_hidden_states[indices] = vision_feats2[0].cpu().numpy()
-                    print(indices)
-
-        text_feats = {
-            "prompt_embeds": new_feats.text_embeds.to(__device__),
-            "text_encoder_hidden_states": torch.tensor(new_last_hidden_states).unsqueeze(0).to(__device__),
-            "text_mask": txt_items["attention_mask"].to(__device__),
-        }
-        return text_feats
-
-    def run(self,
-            image: dict[str, PIL.Image.Image],
-            keyword: str,
-            image2: dict[str, PIL.Image.Image],
-            keyword2: str,
-            text: str,
-    ):
-
-        # aug_feats = self.process_data(image["composite"], keyword, image2["composite"], keyword2, text)
+
+    def run(
+        self,
+        image: dict[str, PIL.Image.Image],
+        keyword: str,
+        image2: dict[str, PIL.Image.Image],
+        keyword2: str,
+        text: str,
+        seed: int,
+    ):
         sub_imgs = [image["composite"]]
-        if image2:
-            sub_imgs.append(image2["composite"])
         sun_keywords = [keyword]
-        if keyword2:
+        if keyword2 and keyword2 != "no subject":
             sun_keywords.append(keyword2)
+        if image2:
+            sub_imgs.append(image2["composite"])
         raw_data = {
             "prompt": text,
             "subject_images": sub_imgs,
-            "subject_keywords": sun_keywords
+            "subject_keywords": sun_keywords,
        }
-        image = self.inference(raw_data)
+        image = self.inference(raw_data, seed)
         return image
 
-def create_demo():
 
-
+def create_demo():
+    USAGE = """## To run the demo, you should:
     1. Upload your image.
     2. <span style='color: red;'>**Upload a masked subject image with white blankspace or whiten out manually using brush tool.**
     3. Input a Keyword i.e. 'Dog'
@@ -169,7 +116,7 @@ def create_demo():
     4-2. Input the Keyword i.e. 'Sunglasses'
     3. Input proper text prompts, such as "A photo of Dog" or "A Dog wearing sunglasses", Please use the same keyword in the prompt.
     4. Click the Run button.
-
+    """
 
     model = Model()
 
@@ -180,6 +127,8 @@ def create_demo():
 
         <p style="text-align: center; color: red;">This demo is currently hosted on either a small GPU or CPU. We will soon provide high-end GPU support.</p>
         <p style="text-align: center; color: red;">Please follow the instructions from here to run it locally: <a href="https://github.com/eclipse-t2i/lambda-eclipse-inference">GitHub Inference Code</a></p>
+
+        <a href="https://colab.research.google.com/drive/1VcqzXZmilntec3AsIyzCqlstEhX4Pa1o?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
         """
         )
         gr.Markdown(USAGE)
@@ -187,28 +136,41 @@ def create_demo():
         with gr.Column():
             with gr.Group():
                 gr.Markdown(
-
-
+                    "Upload your first masked subject image or mask out marginal space"
+                )
+                image = gr.ImageEditor(
+                    label="Input",
+                    type="pil",
+                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+                )
                 keyword = gr.Text(
-                    label=
+                    label="Keyword",
                     placeholder='e.g. "Dog", "Goofie"',
-                    info=
+                    info="Keyword for first subject",
+                )
                 gr.Markdown(
-
-
-
-                    label=
+                    "For Multi-Subject generation : Upload your second masked subject image or mask out marginal space"
+                )
+                image2 = gr.ImageEditor(
+                    label="Input",
+                    type="pil",
+                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+                )
+                keyword2 = gr.Text(
+                    label="Keyword",
                     placeholder='e.g. "Sunglasses", "Grand Canyon"',
-                    info=
+                    info="Keyword for second subject",
+                )
                 prompt = gr.Text(
-                    label=
+                    label="Prompt",
                    placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
-                    info=
+                    info="Keep the keywords used previously in the prompt",
+                )
 
-                run_button = gr.Button(
+                run_button = gr.Button("Run")
 
         with gr.Column():
-            result = gr.Image(label=
+            result = gr.Image(label="Result")
 
         inputs = [
             image,
@@ -217,18 +179,77 @@ def create_demo():
             keyword2,
             prompt,
         ]
-
+
         gr.Examples(
-            examples=[
-
+            examples=[
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
+                    "luffy",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "luffy holding a sword",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
+                    "luffy",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "luffy in the living room",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/teapot.jpg"),
+                    "teapot",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "teapot on a cobblestone street",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/trex.jpg"),
+                    "trex",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "trex near a river",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/cat.png"),
+                    "cat",
+                    os.path.join(
+                        os.path.dirname(__file__), "./assets/blue_sunglasses.png"
+                    ),
+                    "glasses",
+                    "A cat wearing glasses on a snowy field",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/statue.jpg"),
+                    "statue",
+                    os.path.join(os.path.dirname(__file__), "./assets/toilet.jpg"),
+                    "toilet",
+                    "statue sitting on a toilet",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/teddy.jpg"),
+                    "teddy",
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy_hat.jpg"),
+                    "hat",
+                    "a teddy wearing the hat at a beach",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/chair.jpg"),
+                    "chair",
+                    os.path.join(os.path.dirname(__file__), "./assets/table.jpg"),
+                    "table",
+                    "a chair and table in living room",
+                ],
+            ],
+            inputs=inputs,
             fn=model.run,
             outputs=result,
         )
-
+
         run_button.click(fn=model.run, inputs=inputs, outputs=result)
         return demo
 
 
-if __name__ ==
+if __name__ == "__main__":
     demo = create_demo()
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
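
For reference, a minimal sketch of how the updated Model API above can be exercised outside the Gradio UI. This is an assumption-laden illustration, not part of the commit: it presumes a CUDA device (the new inference() seeds its generator on "cuda"), that local copies of the assets referenced by the new gr.Examples entries exist, and that the Space's custom pipe_prior accepts plain PIL images in "subject_images" (in the app these come from gr.ImageEditor composites).

from PIL import Image

model = Model()  # Model as defined in app.py above

# Same raw_data layout that run() assembles from the Gradio inputs;
# prompt, keywords, and asset names are taken from the gr.Examples entries.
raw_data = {
    "prompt": "A cat wearing glasses on a snowy field",
    "subject_images": [
        Image.open("assets/cat.png"),               # assumed local copies of
        Image.open("assets/blue_sunglasses.png"),   # the Space's example assets
    ],
    "subject_keywords": ["cat", "glasses"],
}

# inference() seeds one torch.Generator and passes it to both the prior and
# the decoder, so a fixed (raw_data, seed) pair reproduces the same image.
image_a = model.inference(raw_data, seed=0)
image_b = model.inference(raw_data, seed=0)    # identical to image_a
image_c = model.inference(raw_data, seed=123)  # a different sample
image_a.save("cat_glasses_seed0.png")

Sharing one seeded generator across both pipeline stages is what makes the new seed argument meaningful: changing only the seed changes the sample, while fixed inputs plus a fixed seed stay reproducible.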