1inkusFace commited on
Commit
10ce5fa
·
verified ·
1 Parent(s): 837ba80

Update ip_adapter/ip_adapter.py

Browse files
Files changed (1) hide show
  1. ip_adapter/ip_adapter.py +59 -2
ip_adapter/ip_adapter.py CHANGED
@@ -125,6 +125,10 @@ class IPAdapter:
125
  self,
126
  pil_image,
127
  prompt=None,
 
 
 
 
128
  negative_prompt=None,
129
  scale=1.0,
130
  num_samples=4,
@@ -163,11 +167,29 @@ class IPAdapter:
163
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
164
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
165
 
 
166
  with torch.inference_mode():
167
  prompt_embeds = self.pipe._encode_prompt(
168
  prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
169
  negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
170
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
172
  negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
173
 
@@ -204,6 +226,10 @@ class IPAdapterXL(IPAdapter):
204
  pil_image_4=None,
205
  pil_image_5=None,
206
  prompt=None,
 
 
 
 
207
  negative_prompt=None,
208
  text_scale=1.0,
209
  ip_scale=1.0,
@@ -280,11 +306,42 @@ class IPAdapterXL(IPAdapter):
280
  uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
281
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
282
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
283
-
 
 
 
284
  with torch.inference_mode():
285
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
286
  prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  prompt_embeds = prompt_embeds * text_scale
 
288
  prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
289
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
290
 
 
125
  self,
126
  pil_image,
127
  prompt=None,
128
+ prompt2=None,
129
+ prompt3=None,
130
+ prompt4=None,
131
+ prompt5=None,
132
  negative_prompt=None,
133
  scale=1.0,
134
  num_samples=4,
 
167
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
168
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
169
 
170
+ prompt_embeds_list=[]
171
  with torch.inference_mode():
172
  prompt_embeds = self.pipe._encode_prompt(
173
  prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
174
  negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
175
+ prompt_embeds_list.append(prompt_embeds)
176
+ if prompt2 is not None:
177
+ prompt_embeds = self.pipe._encode_prompt(
178
+ prompt2, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
179
+ prompt_embeds_list.append(prompt_embeds)
180
+ if prompt3 is not None:
181
+ prompt_embeds = self.pipe._encode_prompt(
182
+ prompt3, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
183
+ prompt_embeds_list.append(prompt_embeds)
184
+ if prompt4 is not None:
185
+ prompt_embeds = self.pipe._encode_prompt(
186
+ prompt4, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
187
+ prompt_embeds_list.append(prompt_embeds)
188
+ if prompt5 is not None:
189
+ prompt_embeds = self.pipe._encode_prompt(
190
+ prompt5, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
191
+ prompt_embeds_list.append(prompt_embeds)
192
+
193
  prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
194
  negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
195
 
 
226
  pil_image_4=None,
227
  pil_image_5=None,
228
  prompt=None,
229
+ prompt2=None,
230
+ prompt3=None,
231
+ prompt4=None,
232
+ prompt5=None,
233
  negative_prompt=None,
234
  text_scale=1.0,
235
  ip_scale=1.0,
 
306
  uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
307
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
308
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
309
+
310
+ prompt_embeds_list=[]
311
+ pooled_prompt_embeds_list=[]
312
+
313
  with torch.inference_mode():
314
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
315
  prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
316
+ prompt_embeds_list.append(prompt_embeds)
317
+ pooled_prompt_embeds_list.append(pooled_prompt_embeds)
318
+ if prompt2 is not None:
319
+ prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
320
+ prompt2, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
321
+ prompt_embeds_list.append(prompt_embeds)
322
+ pooled_prompt_embeds_list.append(pooled_prompt_embeds)
323
+ if prompt3 is not None:
324
+ prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
325
+ prompt3, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
326
+ prompt_embeds_list.append(prompt_embeds)
327
+ pooled_prompt_embeds_list.append(pooled_prompt_embeds)
328
+ if prompt4 is not None:
329
+ prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
330
+ prompt4, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
331
+ prompt_embeds_list.append(prompt_embeds)
332
+ pooled_prompt_embeds_list.append(pooled_prompt_embeds)
333
+ if prompt5 is not None:
334
+ prompt_embeds, negative_prompt_embeds_, pooled_prompt_embeds, negative_pooled_prompt_embeds_ = self.pipe.encode_prompt(
335
+ prompt5, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
336
+ prompt_embeds_list.append(prompt_embeds)
337
+ pooled_prompt_embeds_list.append(pooled_prompt_embeds)
338
+ prompt_embeds = torch.cat(prompt_embeds_list)
339
+ prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
340
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list)
341
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
342
+
343
  prompt_embeds = prompt_embeds * text_scale
344
+ pooled_prompt_embeds = pooled_prompt_embeds * text_scale
345
  prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
346
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
347