"""Inference script for generating batch results with a ReVersion-style relation token <R>."""

from __future__ import annotations

import math
import os

import PIL.Image
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

def make_image_grid(imgs, rows, cols):
    """Paste equally sized images into a single grid image.

    Cells are filled left-to-right, top-to-bottom; a partially filled
    last row is allowed (unfilled cells stay black), so odd sample
    counts no longer trip the assertion.
    """
    assert len(imgs) <= rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Image i lands in column i % cols and row i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
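# Worked example (sizes assumed for illustration): with six 512x512 samples,
# make_image_grid(imgs, rows=2, cols=3) returns one 1536x1024 canvas, and
# image i=4 is pasted at box=(4 % 3 * 512, 4 // 3 * 512) == (512, 512).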


def inference_fn(
        model_id: str,
        prompt: str,
        num_samples: int,
        guidance_scale: float,
        ddim_steps: int,
    ) -> PIL.Image.Image:

    # Create the inference pipeline: fp16 on GPU, full precision on CPU.
    model_path = os.path.join('experiments', model_id)
    if torch.cuda.is_available():
        pipe = StableDiffusionPipeline.from_pretrained(
            model_path, torch_dtype=torch.float16
        ).to('cuda')
    else:
        pipe = StableDiffusionPipeline.from_pretrained(model_path).to('cpu')

    # Single text prompt.
    prompt_list = [prompt] if prompt is not None else []

    for prompt in prompt_list:
        # Insert the learned relation token <R>: lower-case the prompt,
        # restore the token's casing, and fill any `{}` placeholder with it.
        prompt = prompt.lower().replace("<r>", "<R>").format("<R>")

        # Batch generation: one pipeline call yields `num_samples` images.
        images = pipe(
            prompt,
            num_inference_steps=ddim_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_samples,
        ).images

        # Arrange the samples in a two-row grid and return it.
        image_grid = make_image_grid(images, rows=2, cols=math.ceil(num_samples / 2))
        return image_grid
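
# The original file imported `gradio as gr`, which suggests inference_fn is
# meant to back a demo UI. Below is a minimal, hypothetical sketch of that
# wiring; the component labels, ranges, and defaults are illustrative
# assumptions, not the original app's layout.
def build_demo():
    import gradio as gr  # imported lazily so the script also runs without gradio

    return gr.Interface(
        fn=inference_fn,
        inputs=[
            gr.Textbox(label='model_id (folder under experiments/)'),
            gr.Textbox(label='prompt, e.g. "cat <R> stone"'),
            gr.Slider(1, 8, value=4, step=1, label='num_samples'),
            gr.Slider(1.0, 20.0, value=7.5, label='guidance_scale'),
            gr.Slider(10, 100, value=50, step=1, label='ddim_steps'),
        ],
        outputs=gr.Image(label='sample grid'),
    )
# To serve the demo: build_demo().launch()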

if __name__ == "__main__":
    # Example invocation; the experiment name and prompt below are
    # placeholders, not values shipped with the original script.
    inference_fn(
        model_id='painted_on',
        prompt='cat <R> stone',
        num_samples=4,
        guidance_scale=7.5,
        ddim_steps=50,
    )