Update app.py
app.py CHANGED
```diff
@@ -1,6 +1,5 @@
 import gradio as gr
 import numpy as np
-import random
 from PIL import Image
 import torch
 from diffusers import ControlNetModel, UniPCMultistepScheduler
@@ -10,45 +9,41 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Initialize model
 controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=torch.float16)
-print("ControlNet 模型加载完成!")
+print("ControlNet 模型加载完成!")
 pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
     "krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=torch.float16
 )
-print("Stable Diffusion 管道加载完成!")
+print("Stable Diffusion 管道加载完成!")
 pipe = pipe.to(device)
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
 MAX_SEED = np.iinfo(np.int32).max
 
-# Function to …
-def update_object_inputs(num_objects):
-    captions = [gr.Textbox(label=f"Subcaption for Object {i+1}", placeholder=f"Enter caption for Object {i+1}") for i in range(num_objects)]
-    bbox_coords = [gr.Textbox(label=f"Bounding Box for Object {i+1} (x1, y1, x2, y2)", placeholder="e.g., 50, 50, 150, 150") for i in range(num_objects)]
-    return captions + bbox_coords
-
-# Inference function
-def infer(prompt, num_objects, subcaptions, bboxes, guidance_scale, num_inference_steps, seed):
-    obj_class = ["Background"] + subcaptions
-    obj_bbox = [[0, 0, 512, 512]] + [list(map(int, bbox.split(','))) for bbox in bboxes]
-
+# Function to generate images based on user input
+def generate_user_data(object_classes, object_bboxes):
     img_width, img_height = 512, 512
     r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
-    list_cond_image = [
-        …
-        …
+    list_cond_image = []
+
+    for bbox in object_bboxes:
+        x1, y1, x2, y2 = map(int, bbox.split(","))
         cond_image = np.zeros_like(r_image, dtype=np.uint8)
         cond_image[y1:y2, x1:x2] = 255
-        list_cond_image.append(cond_image)
-
-    …
+        list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
+
+    return object_classes.split(","), list_cond_image
+
+# Inference function
+def infer(prompt, guidance_scale, num_inference_steps, randomize_seed, seed, object_classes, object_bboxes):
+    obj_classes, list_cond_image_pil = generate_user_data(object_classes, object_bboxes)
+    if randomize_seed or seed is None:
+        seed = np.random.randint(0, MAX_SEED)
 
-    if seed is None:
-        seed = random.randint(0, MAX_SEED)
     generator = torch.manual_seed(seed)
 
     image = pipe(
         prompt=prompt,
-        layo_prompt=…
+        layo_prompt=obj_classes,
         guess_mode=False,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
@@ -62,42 +57,27 @@ def infer(prompt, num_objects, subcaptions, bboxes, guidance_scale, num_inference_steps, seed):
 
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("# Text-to-Image with …
+    gr.Markdown("# Text-to-Image Generator with Manual Input")
 
-    # Global Caption and Object Number
     with gr.Row():
-        prompt = gr.Textbox(label="…
-        …
-        …
-    # Dynamic inputs for subcaptions and bounding boxes
-    subcaptions_column = gr.Column(visible=False)
-    bbox_column = gr.Column(visible=False)
-
-    # "确定" (Confirm) button
-    confirm_button = gr.Button("确定")
-
-    # Update inputs when the "确定" button is clicked
-    def on_confirm_click(n):
-        inputs = update_object_inputs(n)
-        return {subcaptions_column: inputs[:n], bbox_column: inputs[n:], "visible": True}
+        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
+        object_classes = gr.Textbox(label="Object Classes (comma-separated)", placeholder="e.g., Object_1,Object_2")
+        object_bboxes = gr.Textbox(label="Bounding Boxes (format: x1,y1,x2,y2; separated by commas)", placeholder="e.g., 50,50,150,150")
 
-
-
-    # Advanced settings
-    with gr.Accordion("Advanced Settings", open=False):
+    with gr.Row():
         guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=7.5)
         num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=50)
-        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, interactive=True)
 
-    …
-    …
-    …
+        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+
+    run_button = gr.Button("Generate")
+    result = gr.Image(label="Generated Image")
 
-    …
-    …
-    …
-    …
-        outputs=[result_image, seed]
+    run_button.click(
+        infer,
+        inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed, object_classes, object_bboxes],
+        outputs=[result, seed]
     )
 
 if __name__ == "__main__":
```
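Taken together, the commit swaps the dynamic per-object widgets (subcaption and bounding-box textboxes built by `update_object_inputs` and revealed by a "确定" (Confirm) button) for two plain comma-separated textboxes, factors the mask construction into a `generate_user_data` helper, and adds progress prints (the Chinese strings say the ControlNet model and the Stable Diffusion pipeline finished loading). Lines shown as `…` were truncated in the page capture.

Each bounding box becomes a white-on-black condition image that the multi-layout ControlNet pipeline consumes. Note that, as committed, `for bbox in object_bboxes:` iterates over the *characters* of the textbox string, so `x1, y1, x2, y2 = map(int, bbox.split(","))` would raise on real input. Below is a minimal standalone sketch of the intended mask construction, assuming boxes are separated by `;` and coordinates by `,` (both separators are assumptions; the UI label is ambiguous on this), and `build_condition_images` is a hypothetical name:

```python
import numpy as np
from PIL import Image

# Hypothetical helper sketching the mask construction from the diff:
# each box becomes a 512x512 black image with a white rectangle over the
# object's layout region. The separators (";" between boxes, "," between
# coordinates) are assumptions, not taken from the committed code.
def build_condition_images(bbox_text, size=512):
    masks = []
    for bbox in bbox_text.split(";"):
        x1, y1, x2, y2 = map(int, bbox.split(","))
        mask = np.zeros((size, size, 3), dtype=np.uint8)
        mask[y1:y2, x1:x2] = 255  # white rectangle marks the layout region
        masks.append(Image.fromarray(mask).convert("RGB"))
    return masks

# Example: two layout regions
masks = build_condition_images("50,50,150,150;200,80,300,200")
print(len(masks), masks[0].size)  # -> 2 (512, 512)
```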
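The seeding change drops `import random` in favor of NumPy: the old `if seed is None: seed = random.randint(0, MAX_SEED)` fallback becomes `np.random.randint(0, MAX_SEED)` guarded by a "Randomize Seed" checkbox. A sketch of that pattern in isolation, assuming (as `outputs=[result, seed]` suggests) that `infer` returns the generated image together with the seed it actually used; `resolve_seed` is a hypothetical name:

```python
import numpy as np
import torch

MAX_SEED = np.iinfo(np.int32).max  # same bound the app uses

# Hypothetical helper isolating the seed logic from infer(): draw a fresh
# seed when randomization is requested (or no seed was given), then seed
# torch so the diffusion pass is reproducible.
def resolve_seed(randomize_seed, seed):
    if randomize_seed or seed is None:
        seed = int(np.random.randint(0, MAX_SEED))
    generator = torch.manual_seed(seed)  # returns the (re)seeded torch.Generator
    return seed, generator

seed, generator = resolve_seed(randomize_seed=True, seed=0)
print(seed)  # echo the drawn seed so the run can be reproduced later
```

Returning the seed is what makes `outputs=[result, seed]` useful: Gradio writes the returned value back into the Seed slider, so a user can re-run the exact same generation with "Randomize Seed" unticked.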