File size: 6,607 Bytes
3f6feda
0908d26
 
 
 
 
3f6feda
0908d26
 
 
 
 
 
 
 
 
 
 
 
 
e59aef1
 
e86060e
0908d26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdb9a9a
0908d26
bdb9a9a
 
0908d26
 
 
 
bdb9a9a
 
0908d26
 
 
 
bdb9a9a
0908d26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b4c737
0908d26
 
38d1b63
 
 
 
 
 
bdb9a9a
 
9b4c737
5900058
9c77bf2
0908d26
bdb9a9a
 
 
38d1b63
bdb9a9a
38d1b63
 
 
 
9756dd2
e59aef1
bdb9a9a
38d1b63
0908d26
 
 
43c15c5
e59aef1
0908d26
 
 
e59aef1
 
0908d26
9b4c737
379b9be
0908d26
 
 
 
30cb293
3edb685
0908d26
 
 
 
 
 
3edb685
0908d26
 
 
 
 
 
 
 
 
 
dd1bf75
 
 
30cb293
0908d26
 
 
 
38d1b63
0908d26
30cb293
0908d26
 
 
 
 
 
30cb293
0908d26
 
 
 
 
30cb293
3f6feda
dd1bf75
 
 
 
 
 
 
0908d26
38d1b63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import gradio as gr
import numpy as np
from PIL import Image
import torch
from diffusers import ControlNetModel, UniPCMultistepScheduler
from hico_pipeline import StableDiffusionControlNetMultiLayoutPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU: float16 is poorly supported on CPU (some
# ops are not implemented for Half), so load the weights in float32 there.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Initialize model: the HiCo layout ControlNet attached to the
# krnl/realisticVisionV51_v51VAE base pipeline, sampled with UniPC.
controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=torch_dtype)
pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
    "krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=torch_dtype
)
pipe = pipe.to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Upper bound for random seeds (largest int32); also used as the seed slider's maximum.
MAX_SEED = np.iinfo(np.int32).max

# Layout state read and mutated by the event handlers. The two lists are
# index-aligned: entry i pairs a caption with an "x1,y1,x2,y2" pixel box on the
# 512x512 canvas. Entry 0 is the global caption with the full-canvas box
# "0,0,512,512"; the remaining entries are sub-captions for individual regions.
# NOTE(review): these lists are module-level, so every browser session shares
# and mutates the same layout; per-session gr.State would isolate users.
object_classes_list = ["A photograph of a young woman wrapped in a towel wearing a pair of sunglasses", "a towel", "a young woman wrapped in a towel wearing a pair of sunglasses", "a pair of sunglasses"]
object_bboxes_list = ["0,0,512,512", "17,77,144,155", "16,28,157,155", "82,44,129,63"]

# Function to add or update the prompt in the list
def submit_prompt(prompt):
    """Store ``prompt`` as the global caption (layout slot 0) and lock the prompt box.

    Slot 0 of ``object_classes_list`` / ``object_bboxes_list`` is the global
    caption paired with the full-canvas box ``"0,0,512,512"``. If that slot is
    present, its caption is overwritten. Otherwise a new full-canvas slot is
    inserted at the front of both lists. This covers the case where the layout
    was cleared and sub-captions were added before the prompt was submitted, so
    no sub-caption is overwritten.

    Args:
        prompt: Global caption text from the prompt textbox.

    Returns:
        A ``(rows, update)`` pair: the ``[caption, bbox]`` rows for the
        Dataframe, and a ``gr.update`` that makes the prompt textbox read-only.
    """
    full_canvas = "0,0,512,512"
    has_global_slot = (
        bool(object_classes_list)
        and bool(object_bboxes_list)
        and object_bboxes_list[0] == full_canvas
    )

    if has_global_slot:
        object_classes_list[0] = prompt
    else:
        # Insert into both lists together so captions and boxes stay index-aligned.
        object_classes_list.insert(0, prompt)
        object_bboxes_list.insert(0, full_canvas)

    combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    return combined_list, gr.update(interactive=False)  # Make the prompt input non-editable

# Function to add a new object with validation
def add_object(object_class, bbox):
    """Validate a sub-caption and its box, then append them to the layout lists.

    Args:
        object_class: Sub-caption text describing the region.
        bbox: Box as ``"x1,y1,x2,y2"`` integer pixel coordinates on the
            512x512 canvas.

    Returns:
        The updated ``[caption, bbox]`` rows for the Dataframe.

    Raises:
        gr.Error: If the caption is empty, or the box is malformed, out of
            range, or has zero/negative area. Gradio shows the message to the
            user, and the layout lists are left unchanged.
    """
    # This handler is wired to a single output (the Dataframe), so failures are
    # raised rather than returned. A returned (message, []) tuple would be
    # handed to the Dataframe as its value, and the message would never appear.
    if not object_class or not object_class.strip():
        raise gr.Error("Error: Sub-caption cannot be empty.")

    try:
        # Split and convert bbox string into integers; a wrong number of parts
        # also surfaces as ValueError from the unpacking.
        x1, y1, x2, y2 = map(int, (bbox or "").split(","))
    except ValueError:
        raise gr.Error("Error: Invalid input format. Use x1,y1,x2,y2.") from None

    # Reject zero-area boxes too: generate_image paints cond_image[y1:y2, x1:x2],
    # which would be an empty (all-black) mask.
    if x2 <= x1 or y2 <= y1:
        raise gr.Error("Error: x2 must be greater than x1 and y2 must be greater than y1.")
    if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
        raise gr.Error("Error: Coordinates must be between 0 and 512.")

    # Store a normalized string so the table shows one consistent format
    # regardless of spacing in the user's input.
    object_classes_list.append(object_class)
    object_bboxes_list.append(f"{x1},{y1},{x2},{y2}")
    return [[cls, box] for cls, box in zip(object_classes_list, object_bboxes_list)]

# Function to generate images based on added objects
def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
    """Render one 512x512 image from the global prompt and the stored layout.

    Each layout box becomes a binary RGB mask (white inside the box, black
    elsewhere). The masks are passed to the pipeline as ``image`` alongside the
    per-box captions as ``layo_prompt``.

    Args:
        prompt: Global caption from the prompt textbox. If empty and layout
            slot 0 is the full-canvas global slot, that slot's caption is used.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        randomize_seed: If true, a fresh random seed is drawn.
        seed: Seed to use when not randomizing; ``None`` also triggers a draw.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so the UI can display it.

    Raises:
        gr.Error: If the layout is empty (e.g. after Refresh with no prompt
            submitted).
    """
    img_width, img_height = 512, 512

    # Snapshot the shared module-level lists so captions and boxes come from
    # the same moment even if another event mutates them mid-generation.
    captions = list(object_classes_list)
    boxes = list(object_bboxes_list)
    if not boxes:
        raise gr.Error("Error: No layout to render. Submit a prompt first.")

    # The prompt textbox starts blank while the default layout already holds a
    # global caption in slot 0; use it rather than sending an empty prompt.
    if not (prompt or "").strip() and boxes[0] == "0,0,512,512":
        prompt = captions[0]

    list_cond_image = []
    for bbox in boxes:
        x1, y1, x2, y2 = map(int, bbox.split(","))
        cond_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
        cond_image[y1:y2, x1:x2] = 255
        list_cond_image.append(Image.fromarray(cond_image).convert("RGB"))

    if randomize_seed or seed is None:
        seed = np.random.randint(0, MAX_SEED)
    seed = int(seed)

    # A dedicated generator passed to the pipeline seeds this request only,
    # instead of reseeding torch's process-wide RNG shared by all requests.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    image = pipe(
        prompt=prompt,
        layo_prompt=captions,
        guess_mode=False,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        image=list_cond_image,
        fuse_type="avg",
        width=img_width,
        height=img_height,
        generator=generator,
    ).images[0]
    return image, seed

# Function to clear all arrays and reset the UI
def clear_arrays():
    """Reset the layout to empty and unlock the prompt textbox.

    Returns:
        An empty row list for the Dataframe and a ``gr.update`` that blanks the
        prompt textbox and makes it editable again.
    """
    for layout_list in (object_classes_list, object_bboxes_list):
        layout_list.clear()
    prompt_reset = gr.update(value="", interactive=True)
    return [], prompt_reset

# Gradio UI
# Gradio UI: components are laid out in the order they are created inside the
# `with` contexts below. Top to bottom: prompt row, layout table, sub-caption/box
# inputs, Add, advanced settings, Generate, result image, Refresh.
# NOTE(review): every handler wired here reads/mutates the module-level
# object_classes_list / object_bboxes_list, so all sessions share one layout.
with gr.Blocks() as demo:
    gr.Markdown("# HiCo_T2I 512px")
    gr.Markdown(" You can directly click **Generate Image** or customize it by first entering the global caption, followed by subcaptions and their corresponding coordinates.")


    # Put prompt and submit button in the same row
    with gr.Group():
        with gr.Row():
            # Prompt input and submit button.
            # Global caption: submit_prompt makes it non-interactive,
            # clear_arrays blanks it and makes it interactive again.
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt here",
                container=False,
                
            )
            submit_button = gr.Button("Submit Prompt", scale=0)

    # Always visible DataFrame: initialized from the layout lists at build time;
    # submit_prompt, add_object and clear_arrays each return refreshed rows.
    objects_display = gr.Dataframe(
        headers=["Caption", "Bounding Box"],
        value=[[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    )

    # Inputs for one sub-caption and its "x1,y1,x2,y2" box (validated in add_object).
    with gr.Row():
        object_class_input = gr.Textbox(label="Sub-caption", placeholder="Enter Sub-caption (e.g., apple)")
        bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2 and >=0 and <=512)", placeholder="Enter bounding box coordinates")

    add_button = gr.Button("Add")

    # Advanced settings in a collapsible accordion
    with gr.Accordion("Advanced Settings", open=False):
        # Also an output of generate_image, which writes back the seed it used.
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=7.5
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=50
            )

    generate_button = gr.Button("Generate Image")
    result = gr.Image(label="Generated Image")

    # Refresh button to clear arrays and reset inputs (moved below the result)
    refresh_button = gr.Button("Refresh")

    # Submit the prompt and update the display
    submit_button.click(
        fn=submit_prompt,
        inputs=prompt,
        outputs=[objects_display, prompt]
    )

    # Add object and update display
    add_button.click(
        fn=add_object,
        inputs=[object_class_input, bbox_input],
        outputs=[objects_display]
    )

    # Generate image based on added objects; the seed slider receives the seed used
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
        outputs=[result, seed]
    )

    # Refresh button to clear arrays and reset inputs
    refresh_button.click(
        fn=clear_arrays,
        inputs=None,
        outputs=[objects_display, prompt]
    )

# Launch the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()