boomcheng commited on
Commit
30cb293
·
verified ·
1 Parent(s): f3a669b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -10
app.py CHANGED
@@ -24,22 +24,24 @@ object_bboxes_list = []
24
  # Function to add or update the prompt in the list
25
  def submit_prompt(prompt):
26
  if object_classes_list:
27
- object_classes_list[0] = prompt
28
  else:
29
- object_classes_list.insert(0, prompt)
30
 
31
  if not object_bboxes_list:
32
- object_bboxes_list.insert(0, "0,0,512,512")
33
 
34
  combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
35
- return combined_list, gr.update(interactive=False)
36
 
37
  # Function to add a new object with validation
38
  def add_object(object_class, bbox):
39
  try:
40
  x1, y1, x2, y2 = map(int, bbox.split(","))
41
- if x2 < x1 or y2 < y1 or x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
42
- return "Error: Invalid coordinates.", []
 
 
43
  object_classes_list.append(object_class)
44
  object_bboxes_list.append(bbox)
45
  combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
@@ -47,21 +49,52 @@ def add_object(object_class, bbox):
47
  except ValueError:
48
  return "Error: Invalid input format. Use x1,y1,x2,y2.", []
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Gradio UI with custom CSS
51
  css = """
52
  #custom-prompt {
53
  width: 400px; /* Set the width of the prompt input box */
54
  }
55
  #custom-button {
56
- width: 150px; /* Set the width of the submit button */
57
- height: 38px; /* Set the height of the submit button */
58
  }
59
  """
60
 
61
  with gr.Blocks(css=css) as demo:
62
  gr.Markdown("# Text-to-Image Generator with Object Addition")
63
 
64
- # Put prompt and submit button in the same row and apply custom styles
65
  with gr.Row():
66
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", elem_id="custom-prompt")
67
  submit_button = gr.Button("Submit Prompt", elem_id="custom-button")
@@ -72,7 +105,56 @@ with gr.Blocks(css=css) as demo:
72
  value=[]
73
  )
74
 
75
- # Add remaining UI components...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  if __name__ == "__main__":
78
  demo.launch()
 
24
  # Function to add or update the prompt in the list
25
  def submit_prompt(prompt):
26
  if object_classes_list:
27
+ object_classes_list[0] = prompt # Overwrite the first element if it exists
28
  else:
29
+ object_classes_list.insert(0, prompt) # Add to the beginning if the list is empty
30
 
31
  if not object_bboxes_list:
32
+ object_bboxes_list.insert(0, "0,0,512,512") # Add the default bounding box if the list is empty
33
 
34
  combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
35
+ return combined_list, gr.update(interactive=False) # Make the prompt input non-editable
36
 
37
  # Function to add a new object with validation
38
  def add_object(object_class, bbox):
39
  try:
40
  x1, y1, x2, y2 = map(int, bbox.split(","))
41
+ if x2 < x1 or y2 < y1:
42
+ return "Error: x2 cannot be less than x1 and y2 cannot be less than y1.", []
43
+ if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
44
+ return "Error: Coordinates must be between 0 and 512.", []
45
  object_classes_list.append(object_class)
46
  object_bboxes_list.append(bbox)
47
  combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
 
49
  except ValueError:
50
  return "Error: Invalid input format. Use x1,y1,x2,y2.", []
51
 
52
+ # Function to generate images based on added objects
53
+ def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
54
+ img_width, img_height = 512, 512
55
+ r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
56
+ list_cond_image = []
57
+
58
+ for bbox in object_bboxes_list:
59
+ x1, y1, x2, y2 = map(int, bbox.split(","))
60
+ cond_image = np.zeros_like(r_image, dtype=np.uint8)
61
+ cond_image[y1:y2, x1:x2] = 255
62
+ list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
63
+
64
+ if randomize_seed or seed is None:
65
+ seed = np.random.randint(0, MAX_SEED)
66
+
67
+ generator = torch.manual_seed(seed)
68
+
69
+ image = pipe(
70
+ prompt=prompt,
71
+ layo_prompt=object_classes_list,
72
+ guess_mode=False,
73
+ guidance_scale=guidance_scale,
74
+ num_inference_steps=num_inference_steps,
75
+ image=list_cond_image,
76
+ fuse_type="avg",
77
+ width=512,
78
+ height=512
79
+ ).images[0]
80
+
81
+ return image, seed
82
+
83
  # Gradio UI with custom CSS
84
  css = """
85
  #custom-prompt {
86
  width: 400px; /* Set the width of the prompt input box */
87
  }
88
  #custom-button {
89
+ width: 120px; /* Set the width of the submit button */
90
+ height: 40px; /* Set the height of the submit button */
91
  }
92
  """
93
 
94
  with gr.Blocks(css=css) as demo:
95
  gr.Markdown("# Text-to-Image Generator with Object Addition")
96
 
97
+ # Put prompt and submit button in the same row with custom styles
98
  with gr.Row():
99
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", elem_id="custom-prompt")
100
  submit_button = gr.Button("Submit Prompt", elem_id="custom-button")
 
105
  value=[]
106
  )
107
 
108
+ with gr.Row():
109
+ object_class_input = gr.Textbox(label="Object Class", placeholder="Enter object class (e.g., Object_1)")
110
+ bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2)", placeholder="Enter bounding box coordinates")
111
+
112
+ add_button = gr.Button("Add Object")
113
+
114
+ # Advanced settings in a collapsible accordion
115
+ with gr.Accordion("Advanced Settings", open=False):
116
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
117
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
118
+
119
+ with gr.Row():
120
+ guidance_scale = gr.Slider(
121
+ label="Guidance scale",
122
+ minimum=0.0,
123
+ maximum=10.0,
124
+ step=0.1,
125
+ value=7.5
126
+ )
127
+ num_inference_steps = gr.Slider(
128
+ label="Number of inference steps",
129
+ minimum=1,
130
+ maximum=50,
131
+ step=1,
132
+ value=50
133
+ )
134
+
135
+ generate_button = gr.Button("Generate Image")
136
+ result = gr.Image(label="Generated Image")
137
+
138
+ # Submit the prompt and update the display
139
+ submit_button.click(
140
+ fn=submit_prompt,
141
+ inputs=prompt,
142
+ outputs=[objects_display, prompt]
143
+ )
144
+
145
+ # Add object and update display
146
+ add_button.click(
147
+ fn=add_object,
148
+ inputs=[object_class_input, bbox_input],
149
+ outputs=[objects_display]
150
+ )
151
+
152
+ # Generate image based on added objects
153
+ generate_button.click(
154
+ fn=generate_image,
155
+ inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
156
+ outputs=[result, seed]
157
+ )
158
 
159
  if __name__ == "__main__":
160
  demo.launch()