Abdulrahman1989 commited on
Commit
d70cd87
·
1 Parent(s): 6765d50

fix SDXLImageGenerator

Browse files
Files changed (1) hide show
  1. SDXLImageGenerator.py +13 -8
SDXLImageGenerator.py CHANGED
@@ -1,6 +1,4 @@
1
- import torch
2
- from diffusers import DiffusionPipeline
3
- import time
4
 
5
  class SDXLImageGenerator:
6
  def __init__(self):
@@ -21,12 +19,19 @@ class SDXLImageGenerator:
21
  self.pipe.to(self.device)
22
 
23
  def generate_images(self, prompts):
24
- images = []
25
  start_time = time.time()
26
- for i, prompt in enumerate(prompts):
27
- gen_image = self.pipe(prompt=prompt).images[0]
28
- images.append(gen_image)
 
 
 
 
 
 
 
 
29
 
30
  end_time = time.time()
31
  print("Total Time SDXL: %4f seconds" % (end_time - start_time))
32
- return images
 
1
+ from io import BytesIO
 
 
2
 
3
  class SDXLImageGenerator:
4
  def __init__(self):
 
19
  self.pipe.to(self.device)
20
 
21
  def generate_images(self, prompts):
 
22
  start_time = time.time()
23
+
24
+ # Generate images in a batch
25
+ outputs = self.pipe(prompt=prompts)
26
+ images = outputs.images
27
+
28
+ # Convert images to PNG byte data
29
+ png_images = []
30
+ for image in images:
31
+ buffer = BytesIO()
32
+ image.save(buffer, format="PNG")
33
+ png_images.append(buffer.getvalue()) # PNG data in bytes
34
 
35
  end_time = time.time()
36
  print("Total Time SDXL: %4f seconds" % (end_time - start_time))
37
+ return png_images