boomcheng committed on
Commit
0908d26
·
verified ·
1 Parent(s): 0c65251

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -22
app.py CHANGED
@@ -1,40 +1,173 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- # Arrays to be cleared
 
 
 
 
 
 
 
 
 
 
 
 
4
  object_classes_list = []
5
  object_bboxes_list = []
6
 
7
- # Function to clear all arrays
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def clear_arrays():
9
  object_classes_list.clear()
10
  object_bboxes_list.clear()
11
- return [], gr.update(value="", interactive=True) # Clear result and reset prompt input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- with gr.Blocks() as demo:
14
- with gr.Group():
15
  with gr.Row():
16
- # Prompt input and submit button
17
- prompt = gr.Text(
18
- label="Prompt",
19
- show_label=False,
20
- max_lines=1,
21
- placeholder="Enter your prompt here",
22
- container=False
23
  )
24
- submit_button = gr.Button("Submit Prompt", scale=0)
 
 
 
 
 
 
 
 
 
25
 
26
- # Gallery to display results (for demonstration purposes)
27
- result = gr.Gallery(label="Result", columns=3, show_label=False)
 
 
 
 
28
 
29
- # Refresh button to clear arrays
30
- refresh_button = gr.Button("Refresh")
 
 
 
 
31
 
32
- # Add functionality to the refresh button
33
  refresh_button.click(
34
- fn=clear_arrays, # Function to clear arrays
35
  inputs=None,
36
- outputs=[result, prompt] # Clear the result and reset the prompt input
 
 
 
 
 
 
 
37
  )
38
 
39
- # Launch the Gradio app
40
- demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ from diffusers import ControlNetModel, UniPCMultistepScheduler
6
+ from hico_pipeline import StableDiffusionControlNetMultiLayoutPipeline
7
 
8
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used when a GPU is present; on the CPU fallback the
# weights are loaded as float32, since float16 inference on CPU is generally
# unsupported or very slow in PyTorch.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Initialize model: HiCo ControlNet weights plugged into the multi-layout pipeline.
controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=torch_dtype)
pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
    "krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=torch_dtype
)
pipe = pipe.to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max

# Store objects: parallel lists, entry i of one belongs with entry i of the other.
# Bounding boxes are kept as "x1,y1,x2,y2" strings.
object_classes_list = []
object_bboxes_list = []
23
 
24
# Function to add or update the prompt in the list
def submit_prompt(prompt):
    """Record the prompt as the first entry of the shared object lists.

    The prompt replaces element 0 of ``object_classes_list`` (or becomes it
    when the list is empty). When ``object_bboxes_list`` is empty, the box
    ``"0,0,512,512"`` is stored as its first entry.

    Args:
        prompt: Text taken from the prompt textbox.

    Returns:
        A ``(rows, update)`` pair: ``rows`` is a list of
        ``[object_class, bbox]`` pairs built from the two shared lists, and
        ``update`` is a ``gr.update`` that makes the receiving component
        non-interactive.
    """
    if not object_classes_list:
        # Appending to an empty list is the same as inserting at index 0.
        object_classes_list.append(prompt)
    else:
        object_classes_list[0] = prompt

    if len(object_bboxes_list) == 0:
        object_bboxes_list.append("0,0,512,512")

    rows = [list(pair) for pair in zip(object_classes_list, object_bboxes_list)]
    lock_prompt = gr.update(interactive=False)
    return rows, lock_prompt
36
+
37
# Function to add a new object with validation
def add_object(object_class, bbox):
    """Validate a bounding box and append the object to the shared lists.

    Validation failures are raised as ``gr.Error`` so Gradio shows the message
    to the user. This handler is wired to a single output component, so
    returning a ``(message, [])`` tuple would have been passed to that
    component as its value instead of being displayed as an error.

    Args:
        object_class: Label for the object being added.
        bbox: Bounding box as a comma-separated ``"x1,y1,x2,y2"`` string of
            integers.

    Returns:
        A list of ``[object_class, bbox]`` pairs built from the shared lists,
        including the newly added object.

    Raises:
        gr.Error: If ``bbox`` is not four comma-separated integers, if
            ``x2 < x1`` or ``y2 < y1``, or if any coordinate lies outside
            the range 0..512.
    """
    # Keep the try body limited to the parsing step, the only part expected to fail.
    try:
        x1, y1, x2, y2 = map(int, bbox.split(","))
    except (AttributeError, ValueError):
        # AttributeError covers a missing (None) textbox value.
        raise gr.Error("Error: Invalid input format. Use x1,y1,x2,y2.") from None

    if x2 < x1 or y2 < y1:
        raise gr.Error("Error: x2 cannot be less than x1 and y2 cannot be less than y1.")
    if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
        raise gr.Error("Error: Coordinates must be between 0 and 512.")

    object_classes_list.append(object_class)
    object_bboxes_list.append(bbox)
    return [[cls, box] for cls, box in zip(object_classes_list, object_bboxes_list)]
51
+
52
# Function to generate images based on added objects
def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
    """Run the layout-conditioned pipeline over every stored object.

    One 512x512 mask image is built per entry in ``object_bboxes_list``: the
    box region is filled with 255 and the rest is 0. These masks are passed to
    the pipeline together with ``object_classes_list`` as the layout prompts.

    Args:
        prompt: Global text prompt forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of steps; cast to ``int`` because UI
            sliders may deliver a float.
        randomize_seed: When truthy, a fresh random seed is drawn.
        seed: Seed to use when ``randomize_seed`` is falsy and it is not None.

    Returns:
        ``(image, seed)``: the first generated image and the integer seed
        that was used.

    Raises:
        gr.Error: If no prompt or object has been stored yet, since there
            would be no conditioning images to hand to the pipeline.
    """
    if not object_bboxes_list:
        raise gr.Error("Submit a prompt or add at least one object before generating.")

    img_width, img_height = 512, 512

    list_cond_image = []
    for bbox in object_bboxes_list:
        x1, y1, x2, y2 = map(int, bbox.split(","))
        cond_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
        cond_image[y1:y2, x1:x2] = 255  # white rectangle marks the object region
        list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))

    if randomize_seed or seed is None:
        seed = np.random.randint(0, MAX_SEED)
    seed = int(seed)

    # Seeds the global RNG. The returned generator was never passed to the
    # pipeline in the original code, so it is not bound to a name here.
    torch.manual_seed(seed)

    image = pipe(
        prompt=prompt,
        layo_prompt=object_classes_list,
        guess_mode=False,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        image=list_cond_image,
        fuse_type="avg",
        width=img_width,
        height=img_height,
    ).images[0]

    return image, seed
82
+
83
# Function to clear all arrays and reset the UI
def clear_arrays():
    """Empty both shared object lists in place and return reset values.

    Returns:
        A pair ``([], update)``: an empty list, and a ``gr.update`` that sets
        the receiving component's value to ``""`` and makes it interactive.
    """
    for shared in (object_classes_list, object_bboxes_list):
        # Slice deletion empties the existing list object, as .clear() does.
        del shared[:]
    reset_prompt = gr.update(value="", interactive=True)
    return [], reset_prompt
88
+
89
# Gradio UI with custom CSS for orange buttons
css = """
button {
    background-color: orange !important;
    color: white !important;
    border: none !important;
    font-weight: bold;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Text-to-Image Generator with Object Addition")

    # Prompt and submit button share a row. Relative widths are set with the
    # constructor `scale` argument (5:1, mirroring the earlier 500/100 widths),
    # because the `.style(width=...)` call is deprecated/removed in newer
    # Gradio releases.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", scale=5)
        submit_button = gr.Button("Submit Prompt", scale=1)

    # Always visible DataFrame listing the stored [class, bbox] pairs
    objects_display = gr.Dataframe(
        headers=["Object Class", "Bounding Box"],
        value=[],
    )

    with gr.Row():
        object_class_input = gr.Textbox(label="Object Class", placeholder="Enter object class (e.g., Object_1)")
        bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2)", placeholder="Enter bounding box coordinates")

    add_button = gr.Button("Add Object")
    refresh_button = gr.Button("Refresh")  # Clears stored objects and unlocks the prompt

    # Advanced settings in a collapsible accordion
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=7.5,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=50,
            )

    generate_button = gr.Button("Generate Image")
    result = gr.Image(label="Generated Image")

    # Submit the prompt and update the display; the second output locks the prompt box
    submit_button.click(
        fn=submit_prompt,
        inputs=prompt,
        outputs=[objects_display, prompt],
    )

    # Add object and update display
    add_button.click(
        fn=add_object,
        inputs=[object_class_input, bbox_input],
        outputs=[objects_display],
    )

    # Refresh button to clear arrays and reset inputs
    refresh_button.click(
        fn=clear_arrays,
        inputs=None,
        outputs=[objects_display, prompt],
    )

    # Generate image based on added objects; the seed actually used is written
    # back to the seed slider
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()