# Spaces: Paused
import time

import torch
from diffusers import DiffusionPipeline
class SDXLImageGenerator:
    """Generate images from text prompts with the SDXL base 1.0 pipeline.

    The pipeline is loaded once at construction time and moved to the GPU
    when CUDA is available, otherwise to the CPU.

    Attributes:
        use_cuda: Whether ``torch.cuda.is_available()`` reported a CUDA device.
        device: The ``torch.device`` the pipeline was moved to.
        pipe: The loaded ``DiffusionPipeline``.
    """

    def __init__(self):
        """Load the SDXL pipeline and place it on the best available device."""
        # Check if cuda is available
        self.use_cuda = torch.cuda.is_available()
        # Set proper device based on cuda availability
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print("CUDA: ", self.device)

        # Half precision is only used on the GPU; on the CPU fallback the
        # pipeline is loaded in float32 instead, because fp16 inference on
        # CPU is generally unsupported or impractically slow.
        if self.use_cuda:
            precision_kwargs = {"torch_dtype": torch.float16, "variant": "fp16"}
        else:
            precision_kwargs = {"torch_dtype": torch.float32}

        # Load the pipeline
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            use_safetensors=True,
            **precision_kwargs,
        )
        self.pipe.to(self.device)

    def generate_images(self, prompts):
        """Generate one image per prompt and report the total elapsed time.

        Args:
            prompts: An iterable of prompt strings. A single ``str`` is
                treated as one prompt rather than iterated per character.

        Returns:
            A list holding the first image produced by the pipeline for each
            prompt, in the same order as ``prompts``. Empty input yields an
            empty list.
        """
        # A bare string is itself iterable; without this guard every
        # character would be submitted as its own prompt.
        if isinstance(prompts, str):
            prompts = [prompts]

        # perf_counter is monotonic, so the measured interval is not affected
        # by system clock adjustments.
        start_time = time.perf_counter()
        images = [self.pipe(prompt=prompt).images[0] for prompt in prompts]
        elapsed = time.perf_counter() - start_time

        print("Total Time SDXL: %.4f seconds" % elapsed)
        return images