Spaces:
Running
on
Zero
Running
on
Zero
1inkusFace
commited on
Update ip_adapter/ip_adapter.py
Browse files- ip_adapter/ip_adapter.py +59 -2
ip_adapter/ip_adapter.py
CHANGED
@@ -125,6 +125,10 @@ class IPAdapter:
|
|
125 |
self,
|
126 |
pil_image,
|
127 |
prompt=None,
|
|
|
|
|
|
|
|
|
128 |
negative_prompt=None,
|
129 |
scale=1.0,
|
130 |
num_samples=4,
|
@@ -163,11 +167,29 @@ class IPAdapter:
|
|
163 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
164 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
165 |
|
|
|
166 |
with torch.inference_mode():
|
167 |
prompt_embeds = self.pipe._encode_prompt(
|
168 |
prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
169 |
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
172 |
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
173 |
|
@@ -204,6 +226,10 @@ class IPAdapterXL(IPAdapter):
|
|
204 |
pil_image_4=None,
|
205 |
pil_image_5=None,
|
206 |
prompt=None,
|
|
|
|
|
|
|
|
|
207 |
negative_prompt=None,
|
208 |
text_scale=1.0,
|
209 |
ip_scale=1.0,
|
@@ -280,11 +306,42 @@ class IPAdapterXL(IPAdapter):
|
|
280 |
uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
|
281 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
282 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
283 |
-
|
|
|
|
|
|
|
284 |
with torch.inference_mode():
|
285 |
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
|
286 |
prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
prompt_embeds = prompt_embeds * text_scale
|
|
|
288 |
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
289 |
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
290 |
|
|
|
125 |
self,
|
126 |
pil_image,
|
127 |
prompt=None,
|
128 |
+
prompt2=None,
|
129 |
+
prompt3=None,
|
130 |
+
prompt4=None,
|
131 |
+
prompt5=None,
|
132 |
negative_prompt=None,
|
133 |
scale=1.0,
|
134 |
num_samples=4,
|
|
|
167 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
168 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
169 |
|
170 |
+
prompt_embeds_list=[]
|
171 |
with torch.inference_mode():
|
172 |
prompt_embeds = self.pipe._encode_prompt(
|
173 |
prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
174 |
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
|
175 |
+
prompt_embeds_list.append(prompt_embeds)
|
176 |
+
if prompt2 is not None:
|
177 |
+
prompt_embeds = self.pipe._encode_prompt(
|
178 |
+
prompt2, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
179 |
+
prompt_embeds_list.append(prompt_embeds)
|
180 |
+
if prompt3 is not None:
|
181 |
+
prompt_embeds = self.pipe._encode_prompt(
|
182 |
+
prompt3, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
183 |
+
prompt_embeds_list.append(prompt_embeds)
|
184 |
+
if prompt4 is not None:
|
185 |
+
prompt_embeds = self.pipe._encode_prompt(
|
186 |
+
prompt4, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
187 |
+
prompt_embeds_list.append(prompt_embeds)
|
188 |
+
if prompt5 is not None:
|
189 |
+
prompt_embeds = self.pipe._encode_prompt(
|
190 |
+
prompt5, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
191 |
+
prompt_embeds_list.append(prompt_embeds)
|
192 |
+
|
193 |
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
194 |
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
195 |
|
|
|
226 |
pil_image_4=None,
|
227 |
pil_image_5=None,
|
228 |
prompt=None,
|
229 |
+
prompt2=None,
|
230 |
+
prompt3=None,
|
231 |
+
prompt4=None,
|
232 |
+
prompt5=None,
|
233 |
negative_prompt=None,
|
234 |
text_scale=1.0,
|
235 |
ip_scale=1.0,
|
|
|
306 |
uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
|
307 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
308 |
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
309 |
+
|
310 |
+
prompt_embeds_list=[]
|
311 |
+
pooled_prompt_embeds_list=[]
|
312 |
+
|
313 |
with torch.inference_mode():
|
314 |
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
|
315 |
prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
316 |
+
prompt_embeds_list.append(prompt_embeds)
|
317 |
+
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
318 |
+
if prompt2 is not None:
|
319 |
+
prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
|
320 |
+
prompt2, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
321 |
+
prompt_embeds_list.append(prompt_embeds)
|
322 |
+
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
323 |
+
if prompt3 is not None:
|
324 |
+
prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
|
325 |
+
prompt3, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
326 |
+
prompt_embeds_list.append(prompt_embeds)
|
327 |
+
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
328 |
+
if prompt4 is not None:
|
329 |
+
prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
|
330 |
+
prompt4, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
331 |
+
prompt_embeds_list.append(prompt_embeds)
|
332 |
+
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
333 |
+
if prompt5 is not None:
|
334 |
+
prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
|
335 |
+
prompt5, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
336 |
+
prompt_embeds_list.append(prompt_embeds)
|
337 |
+
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
338 |
+
prompt_embeds = torch.cat(prompt_embeds_list)
|
339 |
+
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
340 |
+
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list)
|
341 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
|
342 |
+
|
343 |
prompt_embeds = prompt_embeds * text_scale
|
344 |
+
pooled_prompt_embeds = pooled_prompt_embeds * text_scale
|
345 |
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
346 |
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
347 |
|