1inkusFace committed on
Commit
4843f86
·
verified ·
1 Parent(s): b810973

Update pipeline_stable_diffusion_3_ipa.py

Browse files
Files changed (1) hide show
  1. pipeline_stable_diffusion_3_ipa.py +13 -1
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1170,8 +1170,20 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
1170
  image_prompt_embeds_5 = self.encode_clip_image_emb(clip_image, device, dtype)
1171
  image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
1172
  image_prompt_embeds_list.append(image_prompt_embeds_5)
 
 
 
 
 
 
 
 
 
 
1173
 
1174
- clip_image_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
 
 
1175
 
1176
  # 4. Prepare timesteps
1177
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
1170
  image_prompt_embeds_5 = self.encode_clip_image_emb(clip_image, device, dtype)
1171
  image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
1172
  image_prompt_embeds_list.append(image_prompt_embeds_5)
1173
+
1174
+ # Concatenate the image embeddings
1175
+ concatenated_embeds = torch.cat(image_prompt_embeds_list, dim=1) # Concatenate along dimension 1
1176
+
1177
+ # Create a linear layer
1178
+ embedding_dim = concatenated_embeds.shape[-1] # Get the embedding dimension
1179
+ linear_layer = nn.Linear(embedding_dim * len(image_prompt_embeds_list), embedding_dim)
1180
+
1181
+ # Pass the concatenated embeddings through the linear layer
1182
+ combined_embeds = linear_layer(concatenated_embeds)
1183
 
1184
+ # Add a ReLU activation for non-linearity (optional)
1185
+ combined_embeds = torch.relu(combined_embeds)
1186
+ clip_image_embeds = clip_image_embeds #torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
1187
 
1188
  # 4. Prepare timesteps
1189
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)