# ReVersion / inference.py
# Source: Ziqi's HuggingFace Space, commit 962e77d (page residue converted to comments
# so the file parses as Python).
from __future__ import annotations
import gc
import pathlib
import sys
import gradio as gr
import PIL.Image
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
# sys.path.insert(0, './ReVersion')
# below are original
import os
# import argparse
# import torch
from PIL import Image
# from diffusers import StableDiffusionPipeline
# sys.path.insert(0, './ReVersion')
# from templates.templates import inference_templates
import math
"""
Inference script for generating batch results
"""
def make_image_grid(imgs, rows, cols):
    """Paste a list of equally-sized PIL images into one rows x cols grid image.

    Args:
        imgs: non-empty list of PIL.Image objects, all assumed the same size
            (the size of imgs[0] is used for every cell).
        rows: number of grid rows.
        cols: number of grid columns. rows*cols may exceed len(imgs); trailing
            cells are simply left black, so callers with an odd image count
            (e.g. rows=2, cols=ceil(n/2)) no longer crash.

    Returns:
        A new 'RGB' PIL.Image of size (cols*w, rows*h).

    Raises:
        ValueError: if imgs is empty or the grid has fewer cells than images.
    """
    if not imgs:
        raise ValueError("imgs must be a non-empty list of images")
    if len(imgs) > rows * cols:
        # Explicit error instead of `assert` (asserts vanish under `python -O`).
        raise ValueError(f"grid {rows}x{cols} is too small for {len(imgs)} images")
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Fill row-major: image i goes to column i % cols, row i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def inference_fn(
    model_id: str,
    prompt: str,
    num_samples: int,
    guidance_scale: float,
    ddim_steps: int,
) -> PIL.Image.Image:
    """Run a fine-tuned Stable Diffusion checkpoint and return a grid of samples.

    Args:
        model_id: checkpoint directory name under 'experiments/'.
        prompt: text prompt; the relation token may be written as '<r>' or '<R>',
            and a '{}' placeholder (if present) is filled with '<R>'.
        num_samples: number of images to generate (must be >= 1).
        guidance_scale: classifier-free guidance scale passed to the pipeline.
        ddim_steps: number of denoising steps.

    Returns:
        A single PIL image containing all samples pasted into a grid.

    Raises:
        ValueError: if prompt is empty/None or num_samples < 1.
    """
    # Original code silently built an empty prompt list for prompt=None and then
    # hit a NameError on the undefined grid; fail fast instead.
    if not prompt:
        raise ValueError("prompt must be a non-empty string")
    if num_samples < 1:
        raise ValueError("num_samples must be >= 1")

    # Create the inference pipeline: fp16 on GPU, default fp32 on CPU.
    model_path = os.path.join('experiments', model_id)
    if torch.cuda.is_available():
        pipe = StableDiffusionPipeline.from_pretrained(
            model_path, torch_dtype=torch.float16).to('cuda')
    else:
        pipe = StableDiffusionPipeline.from_pretrained(model_path).to('cpu')

    # Normalize the relation token <R>: lowercase the whole prompt, restore the
    # capital token, and fill any '{}' template placeholder with it.
    prompt = prompt.lower().replace("<r>", "<R>").format("<R>")

    # Batch generation.
    images = pipe(
        prompt,
        num_inference_steps=ddim_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_samples,
    ).images

    # Choose grid dimensions that exactly fit num_samples: the original
    # rows=2, cols=ceil(n/2) combination tripped make_image_grid's size check
    # whenever num_samples was odd. Even counts use two rows, odd counts one.
    if num_samples % 2 == 0:
        rows, cols = 2, num_samples // 2
    else:
        rows, cols = 1, num_samples
    return make_image_grid(images, rows=rows, cols=cols)
if __name__ == "__main__":
inference_fn()