Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import gc | |
import pathlib | |
import sys | |
import gradio as gr | |
import PIL.Image | |
import numpy as np | |
import torch | |
from diffusers import StableDiffusionPipeline | |
# sys.path.insert(0, './ReVersion') | |
# below are original | |
import os | |
# import argparse | |
# import torch | |
from PIL import Image | |
# from diffusers import StableDiffusionPipeline | |
# sys.path.insert(0, './ReVersion') | |
# from templates.templates import inference_templates | |
import math | |
""" | |
Inference script for generating batch results | |
""" | |
def make_image_grid(imgs, rows, cols):
    """Paste images into a ``rows`` x ``cols`` grid, filled row by row.

    Every cell is sized from the first image. Images are pasted at the
    top-left corner of their cell and are not resized.

    Args:
        imgs: Sequence of PIL images. It may hold fewer than ``rows * cols``
            images; the trailing cells are then left black.
        rows: Number of grid rows (must be >= 1).
        cols: Number of grid columns (must be >= 1).

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``rows``/``cols`` are not positive, ``imgs`` is empty,
            or there are more images than grid cells.
    """
    # Validate with real exceptions rather than `assert`, which is stripped
    # when Python runs with -O.
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be positive, got {rows}x{cols}")
    if not imgs:
        raise ValueError("imgs must contain at least one image")
    if len(imgs) > rows * cols:
        raise ValueError(f"{len(imgs)} images do not fit in a {rows}x{cols} grid")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
# Loading a pipeline (weights from disk, device transfer) is expensive, so the
# most recently used one is kept between calls instead of reloading it for
# every request. Maps model_id -> loaded pipeline; holds at most one entry.
_PIPELINE_CACHE = {}


def _load_pipeline(model_id: str) -> StableDiffusionPipeline:
    """Return the pipeline saved under ``experiments/<model_id>``, loading it on first use.

    Only one pipeline is cached. Requesting a different ``model_id`` drops the
    previously cached one before loading the new one.
    """
    pipe = _PIPELINE_CACHE.get(model_id)
    if pipe is not None:
        return pipe
    # Drop the old pipeline first so its weights can be freed before the next
    # model is loaded.
    _PIPELINE_CACHE.clear()
    gc.collect()
    model_path = os.path.join('experiments', model_id)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to('cuda')
    else:
        pipe = StableDiffusionPipeline.from_pretrained(model_path).to('cpu')
    _PIPELINE_CACHE[model_id] = pipe
    return pipe


def _normalize_prompt(prompt: str) -> str:
    """Lower-case ``prompt`` and canonicalise the relation token to ``<R>``.

    ``{}`` placeholders are also filled with ``<R>``, mirroring the template
    formatting the original script applied. Free-form text containing other
    braces (e.g. ``{x}`` or a lone ``{``) would make ``str.format`` raise, so
    such prompts are used without formatting instead.
    """
    text = prompt.lower().replace("<r>", "<R>")
    try:
        return text.format("<R>")
    except (IndexError, KeyError, ValueError, AttributeError, TypeError):
        return text


def inference_fn(
    model_id: str,
    prompt: str,
    num_samples: int,
    guidance_scale: float,
    ddim_steps: int,
) -> PIL.Image.Image:
    """Generate ``num_samples`` images for ``prompt`` and return them as one grid.

    Args:
        model_id: Name of the pipeline directory under ``./experiments``.
        prompt: Text prompt. It is lower-cased, and ``<r>`` is rewritten to
            the relation token ``<R>``.
        num_samples: Number of images to generate. It is coerced to ``int``,
            since UI widgets may deliver floats.
        guidance_scale: Forwarded to the pipeline's ``guidance_scale``.
        ddim_steps: Forwarded as ``num_inference_steps`` (coerced to ``int``).

    Returns:
        An image grid with 2 rows when the number of generated images is
        even, otherwise a single row, so every cell holds an image.

    Raises:
        ValueError: If ``prompt`` is None.
        RuntimeError: If the pipeline returned no images.
    """
    if prompt is None:
        raise ValueError("prompt must not be None")

    pipe = _load_pipeline(model_id)
    images = pipe(
        _normalize_prompt(prompt),
        num_inference_steps=int(ddim_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=int(num_samples),
    ).images
    if not images:
        raise RuntimeError("the pipeline returned no images")

    # A fixed 2-row layout only works for an even image count: with an odd
    # count, rows * cols would exceed the number of images. Fall back to a
    # single row in that case.
    rows = 2 if len(images) % 2 == 0 else 1
    return make_image_grid(images, rows=rows, cols=len(images) // rows)
if __name__ == "__main__":
    # Minimal command-line entry point. inference_fn has five required
    # parameters, so they are collected from argv, and the resulting grid is
    # written to disk.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate a grid of samples from a pipeline saved under ./experiments."
    )
    parser.add_argument("--model_id", required=True,
                        help="pipeline directory name inside ./experiments")
    parser.add_argument("--prompt", required=True,
                        help="text prompt; '<R>' marks the relation token")
    parser.add_argument("--num_samples", type=int, default=4,
                        help="number of images to generate")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="guidance scale forwarded to the pipeline")
    parser.add_argument("--ddim_steps", type=int, default=50,
                        help="number of denoising steps")
    parser.add_argument("--output", default="grid.png",
                        help="path where the image grid is saved")
    args = parser.parse_args()

    grid = inference_fn(
        args.model_id,
        args.prompt,
        args.num_samples,
        args.guidance_scale,
        args.ddim_steps,
    )
    grid.save(args.output)