1inkusFace committed (verified)
Commit e1b3316 · Parent(s): 81c0ae2

Update pipeline_stable_diffusion_3_ipa.py

Files changed (1):
  1. pipeline_stable_diffusion_3_ipa.py (+26 -10)
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1147,31 +1147,47 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         if clip_image != None:
             print('Using primary image.')
             clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
-            clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
+            #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
+            clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
+            clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
+            clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_1 = clip_image_embeds_1 * scale_1
             image_prompt_embeds_list.append(clip_image_embeds_1)
         if clip_image_2 != None:
             print('Using secondary image.')
             clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
-            image_prompt_embeds_2 = self.encode_clip_image_emb(clip_image_2, device, dtype)
-            image_prompt_embeds_2 = image_prompt_embeds_2 * scale_2
+            #image_prompt_embeds_2 = self.encode_clip_image_emb(clip_image_2, device, dtype)
+            image_prompt_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
+            image_prompt_embeds_2 = image_prompt_embeds_2.to(device, dtype=dtype)
+            image_prompt_embeds_2 = self.image_encoder(image_prompt_embeds_2, output_hidden_states=True).hidden_states[-2]
+            image_prompt_embeds_2 = image_prompt_embeds_2 * scale_2
             image_prompt_embeds_list.append(image_prompt_embeds_2)
         if clip_image_3 != None:
             print('Using tertiary image.')
             clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
-            image_prompt_embeds_3 = self.encode_clip_image_emb(clip_image_3, device, dtype)
-            image_prompt_embeds_3 = image_prompt_embeds_3 * scale_3
+            #image_prompt_embeds_3 = self.encode_clip_image_emb(clip_image_3, device, dtype)
+            image_prompt_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
+            image_prompt_embeds_3 = image_prompt_embeds_3.to(device, dtype=dtype)
+            image_prompt_embeds_3 = self.image_encoder(image_prompt_embeds_3, output_hidden_states=True).hidden_states[-2]
+            image_prompt_embeds_3 = image_prompt_embeds_3 * scale_3
             image_prompt_embeds_list.append(image_prompt_embeds_3)
         if clip_image_4 != None:
             print('Using quaternary image.')
             clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
-            image_prompt_embeds_4 = self.encode_clip_image_emb(clip_image_4, device, dtype)
-            image_prompt_embeds_4 = image_prompt_embeds_4 * scale_4
+            #image_prompt_embeds_4 = self.encode_clip_image_emb(clip_image_4, device, dtype)
+            image_prompt_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
+            image_prompt_embeds_4 = image_prompt_embeds_4.to(device, dtype=dtype)
+            image_prompt_embeds_4 = self.image_encoder(image_prompt_embeds_4, output_hidden_states=True).hidden_states[-2]
+            image_prompt_embeds_4 = image_prompt_embeds_4 * scale_4
             image_prompt_embeds_list.append(image_prompt_embeds_4)
         if clip_image_5 != None:
             print('Using quinary image.')
             clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
-            image_prompt_embeds_5 = self.encode_clip_image_emb(clip_image_5, device, dtype)
-            image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
+            #image_prompt_embeds_5 = self.encode_clip_image_emb(clip_image_5, device, dtype)
+            image_prompt_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
+            image_prompt_embeds_5 = image_prompt_embeds_5.to(device, dtype=dtype)
+            image_prompt_embeds_5 = self.image_encoder(image_prompt_embeds_5, output_hidden_states=True).hidden_states[-2]
+            image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
             image_prompt_embeds_list.append(image_prompt_embeds_5)
 
         # Concatenate the image embeddings
@@ -1190,7 +1206,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
         #clip_image_embeds = torch.mean(torch.stack(image_prompt_embeds_list), dim=0) # working
 
-        clip_image_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+        clip_image_embeds = torch.cat([torch.zeros_like(image_prompt_embeds_list[0]), torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)])
         bs_embed, seq_len, _ = clip_image_embeds.shape
         clip_image_embeds = clip_image_embeds.repeat(1, 1, 1)
         clip_image_embeds = clip_image_embeds.view(2, -1)
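The resize call in each branch stretches the image to a square whose side is its longer edge; nothing is padded, so non-square inputs are distorted rather than letterboxed. A minimal PIL sketch of that existing call (the input size is illustrative):

from PIL import Image

img = Image.new("RGB", (640, 480))
side = max(img.size)            # 640
img = img.resize((side, side))  # (640, 640): the aspect ratio is stretched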
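All five branches repeat the same pattern: preprocess with the pipeline's clip_image_processor, run the image_encoder with output_hidden_states=True, keep the penultimate hidden state (hidden_states[-2], the layer IP-Adapter-style conditioning commonly uses), and multiply by a per-image scale. A sketch of that pattern as one function; encode_image_prompt is a hypothetical name, and processor/encoder stand in for the pipeline's clip_image_processor/image_encoder:

import torch

def encode_image_prompt(image, processor, encoder, scale, device, dtype):
    # Preprocess the PIL image into CLIP pixel values on the target device/dtype.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device, dtype=dtype)
    # Penultimate hidden state of the vision tower, weighted per image.
    hidden = encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
    return hidden * scale

With a helper like this, the five near-identical branches could collapse into a single loop over (image, scale) pairs, which also makes copy-paste slips between branches harder to introduce.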
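A shape sketch of the pooling line in the second hunk, assuming each entry of image_prompt_embeds_list has shape (1, seq_len, dim): the per-image embeddings are averaged into one prompt embedding, and a zero row is prepended, presumably as the unconditional half for classifier-free guidance (hence the later view(2, -1)). The sizes below are illustrative CLIP ViT values, not taken from this commit:

import torch

seq_len, dim = 257, 1280
image_prompt_embeds_list = [torch.randn(1, seq_len, dim) for _ in range(3)]

mean_embed = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)  # (1, seq_len, dim)
clip_image_embeds = torch.cat([torch.zeros_like(mean_embed), mean_embed])  # (2, seq_len, dim)
clip_image_embeds = clip_image_embeds.view(2, -1)                          # (2, seq_len * dim)
print(clip_image_embeds.shape)  # torch.Size([2, 328960])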