boomcheng committed
Commit be7e4dd · verified · 1 Parent(s): a8c917d

Update app.py

Files changed (1): app.py +32 -52
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import numpy as np
-import random
 from PIL import Image
 import torch
 from diffusers import ControlNetModel, UniPCMultistepScheduler
@@ -10,45 +9,41 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Initialize model
 controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=torch.float16)
-print("ControlNet model loaded!")
+print("ControlNet model loaded!")
 pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
     "krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=torch.float16
 )
-print("Stable Diffusion pipeline loaded!")
+print("Stable Diffusion pipeline loaded!")
 pipe = pipe.to(device)
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
 MAX_SEED = np.iinfo(np.int32).max
 
-# Function to dynamically update object input fields
-def update_object_inputs(num_objects):
-    captions = [gr.Textbox(label=f"Subcaption for Object {i+1}", placeholder=f"Enter caption for Object {i+1}") for i in range(num_objects)]
-    bbox_coords = [gr.Textbox(label=f"Bounding Box for Object {i+1} (x1, y1, x2, y2)", placeholder="e.g., 50, 50, 150, 150") for i in range(num_objects)]
-    return captions + bbox_coords
-
-# Inference function
-def infer(prompt, num_objects, subcaptions, bboxes, guidance_scale, num_inference_steps, seed):
-    obj_class = ["Background"] + subcaptions
-    obj_bbox = [[0, 0, 512, 512]] + [list(map(int, bbox.split(','))) for bbox in bboxes]
-
+# Function to build layout inputs from user input
+def generate_user_data(object_classes, object_bboxes):
     img_width, img_height = 512, 512
     r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
-    list_cond_image = [np.zeros_like(r_image, dtype=np.uint8)]
-    for bbox in obj_bbox[1:]:
-        x1, y1, x2, y2 = bbox
+    list_cond_image = []
+
+    for bbox in object_bboxes:
+        x1, y1, x2, y2 = map(int, bbox.split(","))
         cond_image = np.zeros_like(r_image, dtype=np.uint8)
         cond_image[y1:y2, x1:x2] = 255
-        list_cond_image.append(cond_image)
-
-    list_cond_image_pil = [Image.fromarray(img).convert('RGB') for img in list_cond_image]
+        list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
+
+    return object_classes.split(","), list_cond_image
+
+# Inference function
+def infer(prompt, guidance_scale, num_inference_steps, randomize_seed, seed, object_classes, object_bboxes):
+    obj_classes, list_cond_image_pil = generate_user_data(object_classes, object_bboxes)
+    if randomize_seed or seed is None:
+        seed = np.random.randint(0, MAX_SEED)
 
-    if seed is None:
-        seed = random.randint(0, MAX_SEED)
     generator = torch.manual_seed(seed)
 
     image = pipe(
         prompt=prompt,
-        layo_prompt=obj_class,
+        layo_prompt=obj_classes,
         guess_mode=False,
         guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
@@ -62,42 +57,27 @@ def infer(prompt, num_objects, subcaptions, bboxes, guidance_scale, num_inference_steps, seed):
 
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("# Text-to-Image with Layout Control")
+    gr.Markdown("# Text-to-Image Generator with Manual Input")
 
-    # Global Caption and Object Number
     with gr.Row():
-        prompt = gr.Textbox(label="Global Caption", placeholder="Enter a global caption", value="123")
-        num_objects = gr.Slider(label="Number of Objects", minimum=1, maximum=5, step=1, value=1)
-
-    # Dynamic inputs for subcaptions and bounding boxes
-    subcaptions_column = gr.Column(visible=False)
-    bbox_column = gr.Column(visible=False)
-
-    # "Confirm" button
-    confirm_button = gr.Button("Confirm")
-
-    # Update inputs when the "Confirm" button is clicked
-    def on_confirm_click(n):
-        inputs = update_object_inputs(n)
-        return {subcaptions_column: inputs[:n], bbox_column: inputs[n:], "visible": True}
+        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
+        object_classes = gr.Textbox(label="Object Classes (comma-separated)", placeholder="e.g., Object_1,Object_2")
+        object_bboxes = gr.Textbox(label="Bounding Boxes (format: x1,y1,x2,y2; separated by commas)", placeholder="e.g., 50,50,150,150")
 
-    confirm_button.click(on_confirm_click, inputs=num_objects, outputs=[subcaptions_column, bbox_column])
-
-    # Advanced settings
-    with gr.Accordion("Advanced Settings", open=False):
+    with gr.Row():
         guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=7.5)
         num_inference_steps = gr.Slider(label="Number of Inference Steps", minimum=1, maximum=50, step=1, value=50)
-        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, interactive=True)
 
-    # Generate button and result image
-    generate_button = gr.Button("Generate Image")
-    result_image = gr.Image(label="Generated Image")
+    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+
+    run_button = gr.Button("Generate")
+    result = gr.Image(label="Generated Image")
 
-    # Link button to inference function
-    generate_button.click(
-        fn=infer,
-        inputs=[prompt, num_objects, subcaptions_column, bbox_column, guidance_scale, num_inference_steps, seed],
-        outputs=[result_image, seed]
+    run_button.click(
+        infer,
+        inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed, object_classes, object_bboxes],
+        outputs=[result, seed]
     )
 
 if __name__ == "__main__":
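A note on the new parsing in generate_user_data: object_bboxes arrives from its Textbox as a single string, so `for bbox in object_bboxes` iterates it character by character and `bbox.split(",")` will fail on real input. The label's "x1,y1,x2,y2;" format suggests boxes were meant to be split on ";" first. A minimal standalone sketch under that assumption (the helper name parse_layout_masks is hypothetical, not part of the commit):

import numpy as np
from PIL import Image

def parse_layout_masks(bbox_text, size=512):
    # Assumed format: one "x1,y1,x2,y2" group per box, boxes separated by ";".
    masks = []
    for box in bbox_text.split(";"):
        x1, y1, x2, y2 = map(int, box.split(","))
        mask = np.zeros((size, size, 3), dtype=np.uint8)
        mask[y1:y2, x1:x2] = 255  # white rectangle marks the object's region
        masks.append(Image.fromarray(mask).convert("RGB"))
    return masks

print(len(parse_layout_masks("50,50,150,150;200,60,300,160")))  # -> 2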
 
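The seed round-trip is the other notable change: outputs=[result, seed] implies infer returns the seed it actually used alongside the image, so the Seed slider reflects a randomized run and the run can be reproduced. A self-contained sketch of that logic (the helper name resolve_seed is hypothetical):

import numpy as np
import torch

MAX_SEED = np.iinfo(np.int32).max

def resolve_seed(randomize_seed, seed):
    # Draw a fresh seed when requested; np.random.randint's upper bound is exclusive.
    if randomize_seed or seed is None:
        seed = int(np.random.randint(0, MAX_SEED))
    # torch.manual_seed seeds the global RNG and returns the default Generator,
    # which can be passed to the pipeline for reproducible sampling.
    generator = torch.manual_seed(seed)
    return generator, seed

generator, seed = resolve_seed(True, None)
print(seed)  # reported back to the UI alongside the image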