import gradio as gr
import numpy as np
from PIL import Image
import torch
from diffusers import ControlNetModel, UniPCMultistepScheduler
from hico_pipeline import StableDiffusionControlNetMultiLayoutPipeline
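# StableDiffusionControlNetMultiLayoutPipeline is the HiCo pipeline shipped with this repo
# (hico_pipeline.py): it conditions generation on one mask/sub-caption pair per layout entry
# and fuses the ControlNet branches (see fuse_type below).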
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize model
controlnet = ControlNetModel.from_pretrained("qihoo360/HiCo_T2I", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
"krnl/realisticVisionV51_v51VAE", controlnet=[controlnet], torch_dtype=torch.float16
)
pipe = pipe.to(device)
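# Swap in the UniPC multistep scheduler for fast, high-quality sampling.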
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
MAX_SEED = np.iinfo(np.int32).max
# Store objects
object_classes_list = ["A photograph of a young woman wrapped in a towel wearing a pair of sunglasses", "a towel", "a young woman wrapped in a towel wearing a pair of sunglasses", "a pair of sunglasses"]
object_bboxes_list = ["0,0,512,512", "17,77,144,155", "16,28,157,155", "82,44,129,63"]
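# Index 0 always holds the global caption (overwritten by submit_prompt); the remaining entries
# are sub-captions paired element-wise with "x1,y1,x2,y2" boxes on the 512x512 canvas.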
# Function to add or update the prompt in the list
def submit_prompt(prompt):
    if object_classes_list:
        object_classes_list[0] = prompt  # Overwrite the global caption if one already exists
    else:
        object_classes_list.insert(0, prompt)  # Add to the beginning if the list is empty
    if not object_bboxes_list:
        object_bboxes_list.insert(0, "0,0,512,512")  # Default bounding box covering the full canvas
    combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    return combined_list, gr.update(interactive=False)  # Make the prompt input non-editable

# Function to add a new object with validation
def add_object(object_class, bbox):
    try:
        # Split and convert the bbox string into integers
        x1, y1, x2, y2 = map(int, bbox.split(","))
        # Validate the coordinates; raise gr.Error so invalid input is shown in the UI
        # instead of being written into the Dataframe output
        if x2 < x1 or y2 < y1:
            raise gr.Error("x2 cannot be less than x1 and y2 cannot be less than y1.")
        if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
            raise gr.Error("Coordinates must be between 0 and 512.")
        # If validation passes, add to the lists
        object_classes_list.append(object_class)
        object_bboxes_list.append(bbox)
    except ValueError:
        raise gr.Error("Invalid input format. Use x1,y1,x2,y2.")
    combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    return combined_list

# Function to generate images based on added objects
def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
    img_width, img_height = 512, 512
    r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
    # Build one binary mask per bounding box (white rectangle on a black 512x512 canvas)
    list_cond_image = []
    for bbox in object_bboxes_list:
        x1, y1, x2, y2 = map(int, bbox.split(","))
        cond_image = np.zeros_like(r_image, dtype=np.uint8)
        cond_image[y1:y2, x1:x2] = 255
        list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
    if randomize_seed or seed is None:
        seed = int(np.random.randint(0, MAX_SEED))
    generator = torch.manual_seed(seed)
    image = pipe(
        prompt=prompt,
        layo_prompt=object_classes_list,
        guess_mode=False,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        image=list_cond_image,
        fuse_type="avg",
        width=512,
        height=512,
        generator=generator,  # assumes the HiCo pipeline accepts the standard diffusers `generator` argument
    ).images[0]
    return image, seed

# Function to clear all arrays and reset the UI
def clear_arrays():
    object_classes_list.clear()
    object_bboxes_list.clear()
    return [], gr.update(value="", interactive=True)  # Clear the objects and reset the prompt

# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("# HiCo_T2I 512px")
gr.Markdown(" You can directly click **Generate Image** or customize it by first entering the global caption, followed by subcaptions and their corresponding coordinates.")
    # Put the prompt and submit button in the same row
    with gr.Group():
        with gr.Row():
            # Prompt input and submit button
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt here",
                container=False,
            )
            submit_button = gr.Button("Submit Prompt", scale=0)

    # Always-visible DataFrame listing the current captions and boxes
    objects_display = gr.Dataframe(
        headers=["Caption", "Bounding Box"],
        value=[[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
    )

    with gr.Row():
        object_class_input = gr.Textbox(label="Sub-caption", placeholder="Enter sub-caption (e.g., apple)")
        bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2, each between 0 and 512)", placeholder="Enter bounding box coordinates")
        add_button = gr.Button("Add")

    # Advanced settings in a collapsible accordion
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=7.5
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=50
            )

    generate_button = gr.Button("Generate Image")
    result = gr.Image(label="Generated Image")
    # Refresh button to clear the arrays and reset inputs (placed below the result)
    refresh_button = gr.Button("Refresh")
    # Submit the prompt and update the display
    submit_button.click(
        fn=submit_prompt,
        inputs=prompt,
        outputs=[objects_display, prompt]
    )

    # Add an object and update the display
    add_button.click(
        fn=add_object,
        inputs=[object_class_input, bbox_input],
        outputs=[objects_display]
    )

    # Generate the image from the prompt and the added objects
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
        outputs=[result, seed]
    )

    # Clear the arrays and reset the inputs
    refresh_button.click(
        fn=clear_arrays,
        inputs=None,
        outputs=[objects_display, prompt]
    )

if __name__ == "__main__":
    demo.launch()