Spaces:
Starting
on
A10G
Starting
on
A10G
File size: 2,007 Bytes
70ff35f 7185c8b 70ff35f 8913269 70ff35f 8913269 9f3d881 8913269 7ca7690 74df120 7ca7690 8913269 74df120 8913269 43f386d 8913269 74df120 8913269 7185c8b 8913269 ef070d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from __future__ import annotations
import gc
import pathlib
import sys
import gradio as gr
import PIL.Image
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
# sys.path.insert(0, './ReVersion')
# below are original
import os
# import argparse
# import torch
from PIL import Image
# from diffusers import StableDiffusionPipeline
# sys.path.insert(0, './ReVersion')
# from templates.templates import inference_templates
import math
"""
Inference script for generating batch results
"""
def make_image_grid(imgs, rows, cols):
    """Tile equally-sized PIL images into a ``rows`` x ``cols`` RGB grid.

    Images are placed left-to-right, top-to-bottom. The cell size is taken
    from the first image, and each image is pasted at its cell's top-left
    corner. If fewer than ``rows * cols`` images are supplied, the remaining
    cells keep the fill colour of the new canvas.

    Args:
        imgs: Sequence of PIL images; must be non-empty.
        rows: Number of grid rows (>= 1).
        cols: Number of grid columns (>= 1).

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        ValueError: if ``imgs`` is empty, ``rows``/``cols`` is below 1, or
            there are more images than grid cells.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if not imgs:
        raise ValueError("make_image_grid needs at least one image")
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got rows={rows}, cols={cols}")
    if len(imgs) > rows * cols:
        raise ValueError(
            f"{len(imgs)} images do not fit in a {rows}x{cols} grid"
        )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def inference_fn(
    model_id: str,
    prompt: str,
    num_samples: int,
    guidance_scale: float,
    ddim_steps: int,
) -> PIL.Image.Image:
    """Generate ``num_samples`` images for one relation prompt and tile them.

    A ``StableDiffusionPipeline`` is loaded from ``experiments/<model_id>`` on
    every call. It uses ``torch.float16`` on CUDA when available, otherwise
    the default dtype on CPU. It is released again before returning.

    The prompt is lower-cased, any ``<r>`` is turned back into ``<R>``, and
    ``{}`` placeholders are filled with ``<R>``. NOTE: other braces in the
    prompt (e.g. ``{x}``) make ``str.format`` raise.

    Args:
        model_id: Sub-directory of ``experiments`` containing the pipeline.
        prompt: Text prompt; must not be None.
        num_samples: Number of images to generate (>= 1).
        guidance_scale: Guidance scale forwarded to the pipeline call.
        ddim_steps: Number of inference steps forwarded to the pipeline call.

    Returns:
        One PIL image tiling all samples: two rows when ``num_samples`` is
        even, a single row otherwise.

    Raises:
        ValueError: if ``prompt`` is None or ``num_samples`` < 1.
    """
    # Validate before the expensive pipeline load. Previously a None prompt
    # left ``image_grid`` unbound and crashed with UnboundLocalError.
    if prompt is None:
        raise ValueError("prompt must be a string, got None")
    n = int(num_samples)
    if n < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Insert the relation token <R>.
    relation_prompt = prompt.lower().replace("<r>", "<R>").format("<R>")

    model_path = os.path.join('experiments', model_id)
    if torch.cuda.is_available():
        pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to('cuda')
    else:
        pipe = StableDiffusionPipeline.from_pretrained(model_path).to('cpu')
    try:
        # Batch generation.
        images = pipe(
            relation_prompt,
            num_inference_steps=ddim_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=n,
        ).images
    finally:
        # The pipeline is reloaded on every call; drop it eagerly so its
        # weights do not linger on the device between requests.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # rows * cols must equal the image count exactly. The old layout
    # (rows=2, cols=ceil(n/2)) broke for every odd n, including n == 1.
    rows = 2 if n % 2 == 0 else 1
    image_grid = make_image_grid(images, rows=rows, cols=n // rows)
    print(image_grid)
    return image_grid
if __name__ == "__main__":
    # inference_fn() has five required parameters, so the previous bare call
    # always raised TypeError. Expose them as command-line arguments instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate a grid of images from a relation prompt."
    )
    parser.add_argument("model_id", help="Sub-directory of ./experiments holding the pipeline.")
    parser.add_argument("prompt", help="Text prompt; '<R>' (any case) or '{}' marks the relation token.")
    parser.add_argument("--num_samples", type=int, default=4, help="Number of images to generate.")
    parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale for the pipeline.")
    parser.add_argument("--ddim_steps", type=int, default=50, help="Number of inference steps.")
    parser.add_argument("--output", default=None, help="Optional path to save the grid image to.")
    args = parser.parse_args()

    grid = inference_fn(
        args.model_id,
        args.prompt,
        args.num_samples,
        args.guidance_scale,
        args.ddim_steps,
    )
    if args.output:
        grid.save(args.output)
|