boomcheng commited on
Commit
0c65251
·
verified ·
1 Parent(s): 3edb685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -149
app.py CHANGED
@@ -1,165 +1,40 @@
1
  import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import torch
5
- from diffusers import ControlNetModel, UniPCMultistepScheduler
6
- from hico_pipeline import StableDiffusionControlNetMultiLayoutPipeline
7
 
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
-
10
- # Initialize model
11
- controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=torch.float16)
12
- pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
13
- "krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=torch.float16
14
- )
15
- pipe = pipe.to(device)
16
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
17
-
18
- MAX_SEED = np.iinfo(np.int32).max
19
-
20
- # Store objects
21
  object_classes_list = []
22
  object_bboxes_list = []
23
 
24
- # Function to add or update the prompt in the list
25
- def submit_prompt(prompt):
26
- if object_classes_list:
27
- object_classes_list[0] = prompt # Overwrite the first element if it exists
28
- else:
29
- object_classes_list.insert(0, prompt) # Add to the beginning if the list is empty
30
-
31
- if not object_bboxes_list:
32
- object_bboxes_list.insert(0, "0,0,512,512") # Add the default bounding box if the list is empty
33
-
34
- combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
35
- return combined_list, gr.update(interactive=False) # Make the prompt input non-editable
36
-
37
- # Function to add a new object with validation
38
- def add_object(object_class, bbox):
39
- try:
40
- # Split and convert bbox string into integers
41
- x1, y1, x2, y2 = map(int, bbox.split(","))
42
-
43
- # Validate the coordinates
44
- if x2 < x1 or y2 < y1:
45
- return "Error: x2 cannot be less than x1 and y2 cannot be less than y1.", []
46
- if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
47
- return "Error: Coordinates must be between 0 and 512.", []
48
-
49
- # If validation passes, add to the lists
50
- object_classes_list.append(object_class)
51
- object_bboxes_list.append(bbox)
52
- combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
53
- return combined_list
54
-
55
- except ValueError:
56
- return "Error: Invalid input format. Use x1,y1,x2,y2.", []
57
-
58
- # Function to generate images based on added objects
59
- def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
60
- img_width, img_height = 512, 512
61
- r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
62
- list_cond_image = []
63
-
64
- for bbox in object_bboxes_list:
65
- x1, y1, x2, y2 = map(int, bbox.split(","))
66
- cond_image = np.zeros_like(r_image, dtype=np.uint8)
67
- cond_image[y1:y2, x1:x2] = 255
68
- list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
69
-
70
- if randomize_seed or seed is None:
71
- seed = np.random.randint(0, MAX_SEED)
72
-
73
- generator = torch.manual_seed(seed)
74
-
75
- image = pipe(
76
- prompt=prompt,
77
- layo_prompt=object_classes_list,
78
- guess_mode=False,
79
- guidance_scale=guidance_scale,
80
- num_inference_steps=num_inference_steps,
81
- image=list_cond_image,
82
- fuse_type="avg",
83
- width=512,
84
- height=512
85
- ).images[0]
86
-
87
- return image, seed
88
 
89
- # Gradio UI
90
  with gr.Blocks() as demo:
91
- gr.Markdown("# Text-to-Image Generator with Object Addition")
92
-
93
- # Put prompt and submit button in the same row
94
  with gr.Group():
95
  with gr.Row():
96
- # Replace gr.Textbox with gr.Text for a single-line input field
97
  prompt = gr.Text(
98
- label="Prompt", # Label for the input field
99
- show_label=False, # Hide the label
100
- max_lines=1, # Single-line input
101
- placeholder="Enter your prompt here", # Placeholder text
102
- container=False # Remove the container background
103
  )
104
- # Replace the button with the simplified "Run" button
105
- submit_button = gr.Button("Submit Prompt", scale=0) # Add scale for button size
106
-
107
- # Always visible DataFrame
108
- objects_display = gr.Dataframe(
109
- headers=["Object Class", "Bounding Box"],
110
- value=[]
111
- )
112
-
113
- with gr.Row():
114
- object_class_input = gr.Textbox(label="Object Class", placeholder="Enter object class (e.g., Object_1)")
115
- bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2)", placeholder="Enter bounding box coordinates")
116
 
117
- add_button = gr.Button("Add Object")
 
118
 
119
- # Advanced settings in a collapsible accordion
120
- with gr.Accordion("Advanced Settings", open=False):
121
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
122
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
123
-
124
- with gr.Row():
125
- guidance_scale = gr.Slider(
126
- label="Guidance scale",
127
- minimum=0.0,
128
- maximum=10.0,
129
- step=0.1,
130
- value=7.5
131
- )
132
- num_inference_steps = gr.Slider(
133
- label="Number of inference steps",
134
- minimum=1,
135
- maximum=50,
136
- step=1,
137
- value=50
138
- )
139
-
140
- generate_button = gr.Button("Generate Image")
141
- result = gr.Image(label="Generated Image")
142
-
143
- # Submit the prompt and update the display
144
- submit_button.click(
145
- fn=submit_prompt,
146
- inputs=prompt,
147
- outputs=[objects_display, prompt] # Update both the display and prompt input
148
- )
149
-
150
- # Add object and update display
151
- add_button.click(
152
- fn=add_object,
153
- inputs=[object_class_input, bbox_input],
154
- outputs=[objects_display]
155
- )
156
 
157
- # Generate image based on added objects
158
- generate_button.click(
159
- fn=generate_image,
160
- inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
161
- outputs=[result, seed]
162
  )
163
 
164
- if __name__ == "__main__":
165
- demo.launch()
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
+ # Arrays to be cleared
 
 
 
 
 
 
 
 
 
 
 
 
4
  object_classes_list = []
5
  object_bboxes_list = []
6
 
7
+ # Function to clear all arrays
8
+ def clear_arrays():
9
+ object_classes_list.clear()
10
+ object_bboxes_list.clear()
11
+ return [], gr.update(value="", interactive=True) # Clear result and reset prompt input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
13
  with gr.Blocks() as demo:
 
 
 
14
  with gr.Group():
15
  with gr.Row():
16
+ # Prompt input and submit button
17
  prompt = gr.Text(
18
+ label="Prompt",
19
+ show_label=False,
20
+ max_lines=1,
21
+ placeholder="Enter your prompt here",
22
+ container=False
23
  )
24
+ submit_button = gr.Button("Submit Prompt", scale=0)
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Gallery to display results (for demonstration purposes)
27
+ result = gr.Gallery(label="Result", columns=3, show_label=False)
28
 
29
+ # Refresh button to clear arrays
30
+ refresh_button = gr.Button("Refresh")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ # Add functionality to the refresh button
33
+ refresh_button.click(
34
+ fn=clear_arrays, # Function to clear arrays
35
+ inputs=None,
36
+ outputs=[result, prompt] # Clear the result and reset the prompt input
37
  )
38
 
39
+ # Launch the Gradio app
40
+ demo.launch()