Spaces:
Runtime error
Runtime error
Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +7 -3
clip_slider_pipeline.py
CHANGED
@@ -4,7 +4,7 @@ import random
|
|
4 |
from tqdm import tqdm
|
5 |
from constants import SUBJECTS, MEDIUMS
|
6 |
from PIL import Image
|
7 |
-
|
8 |
class CLIPSlider:
|
9 |
def __init__(
|
10 |
self,
|
@@ -214,7 +214,7 @@ class CLIPSliderXL(CLIPSlider):
|
|
214 |
):
|
215 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
216 |
# if pooler token only [-4,4] work well
|
217 |
-
|
218 |
text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
|
219 |
tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
|
220 |
with torch.no_grad():
|
@@ -282,9 +282,13 @@ class CLIPSliderXL(CLIPSlider):
|
|
282 |
|
283 |
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
284 |
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
285 |
-
|
|
|
286 |
torch.manual_seed(seed)
|
|
|
287 |
image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
288 |
**pipeline_kwargs).images[0]
|
|
|
|
|
289 |
|
290 |
return image
|
|
|
4 |
from tqdm import tqdm
|
5 |
from constants import SUBJECTS, MEDIUMS
|
6 |
from PIL import Image
|
7 |
+
import time
|
8 |
class CLIPSlider:
|
9 |
def __init__(
|
10 |
self,
|
|
|
214 |
):
|
215 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
216 |
# if pooler token only [-4,4] work well
|
217 |
+
start_time = time.time()
|
218 |
text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
|
219 |
tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
|
220 |
with torch.no_grad():
|
|
|
282 |
|
283 |
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
284 |
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
285 |
+
end_time = time.time()
|
286 |
+
print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
|
287 |
torch.manual_seed(seed)
|
288 |
+
start_time = time.time()
|
289 |
image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
290 |
**pipeline_kwargs).images[0]
|
291 |
+
end_time = time.time()
|
292 |
+
print(f"generation time - pipe: {end_time - start_time:.2f} ms")
|
293 |
|
294 |
return image
|