Abdulrahman1989 commited on
Commit
343346a
·
1 Parent(s): 17397c2

Add SDXLImageGenerator class

Browse files
Files changed (1) hide show
  1. SDXLImageGenerator.py +30 -0
SDXLImageGenerator.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DiffusionPipeline
3
+ import time
4
+
5
+ class SDXLImageGenerator:
6
+ def __init__(self):
7
+ # Check if cuda is available
8
+ self.use_cuda = torch.cuda.is_available()
9
+ # Set proper device based on cuda availability
10
+ self.device = torch.device("cuda" if self.use_cuda else "cpu")
11
+
12
+ # Load the pipeline
13
+ self.pipe = DiffusionPipeline.from_pretrained(
14
+ "stabilityai/stable-diffusion-xl-base-1.0",
15
+ torch_dtype=torch.float16,
16
+ use_safetensors=True,
17
+ variant="fp16"
18
+ )
19
+ self.pipe.to(self.device)
20
+
21
+ def generate_images(self, prompts):
22
+ images = []
23
+ start_time = time.time()
24
+ for i, prompt in enumerate(prompts):
25
+ gen_image = self.pipe(prompt=prompt).images[0]
26
+ images.append(gen_image)
27
+
28
+ end_time = time.time()
29
+ print("Total Time SDXL: %4f seconds" % (end_time - start_time))
30
+ return images