File size: 2,726 Bytes
70ff35f
 
 
 
 
 
 
 
 
 
 
 
 
 
8913269
 
 
70ff35f
8913269
 
 
 
 
9f3d881
8913269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b0d680
 
8913269
 
 
1b0d680
8913269
 
1b0d680
8913269
 
ef070d7
 
8913269
 
 
 
 
 
 
ef070d7
 
 
8913269
 
 
43f386d
 
 
8913269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef070d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import annotations

import gc
import pathlib
import sys

import gradio as gr
import PIL.Image
import numpy as np

import torch
from diffusers import StableDiffusionPipeline
sys.path.insert(0, './ReVersion')

# below are original
import os
# import argparse

# import torch
from PIL import Image

# from diffusers import StableDiffusionPipeline
# sys.path.insert(0, './ReVersion')
# from templates.templates import inference_templates

import math

"""
Inference script for generating batch results
"""

def make_image_grid(imgs, rows, cols):
    """Paste ``imgs`` row-major into one ``rows`` x ``cols`` RGB image.

    Every cell is sized like the first image (``imgs[0].size``). Each image is
    pasted at the top-left corner of its cell, filling rows left to right.

    Args:
        imgs: Sequence of PIL images; must contain exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not imgs:
        raise ValueError("make_image_grid needs at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
        )

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid


def _safe_path_component(text):
    """Return ``text`` made safe to use as a single path component.

    Path separators are replaced with ``_`` so the text cannot create nested
    directories or, via a leading separator, make ``os.path.join`` discard the
    output root. The special names ``""``, ``"."`` and ``".."`` are replaced
    outright so they cannot refer to the current or parent directory.
    """
    for sep in ("/", os.sep, os.altsep):
        if sep:
            text = text.replace(sep, "_")
    if text in ("", ".", ".."):
        text = "_" * max(len(text), 1)
    return text


def inference_fn(
        model_id,
        prompt,
        num_samples,
        guidance_scale,
        device
    ):
    """Generate images for one relation prompt and save samples plus a grid.

    Loads ``model_id`` as a float16 ``StableDiffusionPipeline`` on ``device``
    and runs 50 inference steps. Outputs are written as follows:

    * each sample to ``experiments/<model_id>/inference/<prompt>/samples/NNNN.png``
    * the combined grid to ``experiments/<model_id>/inference/<prompt>/<prompt>.png``

    Path separators in the prompt are replaced with ``_`` in these names.

    Args:
        model_id: Model name/path given to ``StableDiffusionPipeline.from_pretrained``;
            also used as part of the output directory.
        prompt: Text prompt. It is lower-cased, any ``<r>`` is normalized to the
            relation token ``<R>``, and ``{}`` placeholders are filled with ``<R>``
            via ``str.format``. Other literal braces will make ``str.format`` raise.
            ``None`` means there is nothing to generate.
        num_samples: Number of images to generate; coerced to ``int`` in case a
            UI widget delivers a float.
        guidance_scale: Guidance scale forwarded to the pipeline.
        device: Device the pipeline is moved to (passed to ``.to``).

    Returns:
        The grid image built by ``make_image_grid``, or ``None`` when ``prompt``
        is ``None``.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    num_samples = int(num_samples)
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Nothing to generate: skip loading the (large) pipeline entirely.
    if prompt is None:
        return None

    # create inference pipeline
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

    # make directory to save images
    image_root_folder = os.path.join('experiments', model_id, 'inference')
    os.makedirs(image_root_folder, exist_ok=True)

    # insert relation prompt <R> (any case of "<r>" and "{}" placeholders)
    prompt = prompt.lower().replace("<r>", "<R>").format("<R>")

    # The prompt comes from the user, so sanitize it before using it in paths.
    prompt_name = _safe_path_component(prompt)
    prompt_folder = os.path.join(image_root_folder, prompt_name)
    image_folder = os.path.join(prompt_folder, 'samples')
    os.makedirs(image_folder, exist_ok=True)

    # batch generation
    images = pipe(prompt, num_inference_steps=50, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images

    # save generated images
    for idx, image in enumerate(images):
        image.save(os.path.join(image_folder, f"{idx:04d}.png"))

    # Two rows when the samples split evenly, otherwise one row, so that
    # rows * cols always equals the number of images handed to make_image_grid
    # (a fixed 2-row layout breaks for any odd sample count).
    rows = 2 if len(images) % 2 == 0 else 1
    cols = len(images) // rows
    image_grid = make_image_grid(images, rows=rows, cols=cols)

    # save a grid of images
    image_grid.save(os.path.join(prompt_folder, f'{prompt_name}.png'))

    return image_grid

if __name__ == "__main__":
    # Command-line entry point. inference_fn has five required parameters,
    # so they are collected from argv rather than calling it with no arguments.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate images for a single ReVersion relation prompt."
    )
    parser.add_argument("--model_id", required=True,
                        help="Model name or path to load with StableDiffusionPipeline.")
    parser.add_argument("--prompt", required=True,
                        help='Text prompt; "<R>" or "{}" marks the relation token.')
    parser.add_argument("--num_samples", type=int, default=10,
                        help="Number of images to generate.")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale passed to the pipeline.")
    parser.add_argument("--device", default="cuda",
                        help="Device to run the pipeline on.")
    args = parser.parse_args()

    inference_fn(
        model_id=args.model_id,
        prompt=args.prompt,
        num_samples=args.num_samples,
        guidance_scale=args.guidance_scale,
        device=args.device,
    )