ginipick committed on
Commit
20d0a5c
·
verified ·
1 Parent(s): d30fa97

Create app-backup.py

Files changed (1)
  1. app-backup.py +418 -0
app-backup.py ADDED
@@ -0,0 +1,418 @@
import ast
import os
import tempfile
import time
from collections.abc import Sequence
from typing import Any, cast

from huggingface_hub import login, hf_hub_download

import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from diffusers import FluxPipeline
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from PIL import Image
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor

BoundingBox = tuple[int, int, int, int]

pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# HF token setup
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable")

try:
    login(token=HF_TOKEN)
except Exception as e:
    raise ValueError(f"Failed to login to Hugging Face: {str(e)}")

# Model initialization
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)

gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device)
assert isinstance(gd_model, GroundingDinoForObjectDetection)

# Initialize the FLUX pipeline
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        token=HF_TOKEN,
    )
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
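
# A note on the LoRA setup above, based on the ByteDance/Hyper-SD model card
# (an assumption, not re-verified here): the 8-step Hyper-FLUX LoRA is a
# timestep-distilled adapter, which is why generate_background() below can run
# with num_inference_steps=8, and the card suggests fusing it at a reduced
# scale on the order of 0.125.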

class timer:
    """Context manager that logs wall-clock time for a block."""

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")

def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    if not bboxes:
        return None
    for bbox in bboxes:
        assert len(bbox) == 4
        assert all(isinstance(x, int) for x in bbox)
    return (
        min(bbox[0] for bbox in bboxes),
        min(bbox[1] for bbox in bboxes),
        max(bbox[2] for bbox in bboxes),
        max(bbox[3] for bbox in bboxes),
    )
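
# For example: bbox_union([[10, 20, 50, 60], [30, 10, 70, 40]]) -> (10, 10, 70, 60),
# i.e. the smallest box enclosing every detected box.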

def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    # Round float (x1, y1, x2, y2) corners to ints and clamp them into the image bounds.
    x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
    return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)

def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    # Grounding DINO expects text queries terminated with a period.
    inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
    with no_grad():
        outputs = gd_model(**inputs)
    width, height = img.size
    results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        target_sizes=[(height, width)],
    )[0]
    assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
    bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
    return bbox_union(bboxes.numpy().tolist())
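
# post_process_grounded_object_detection filters boxes by score; this call relies
# on the processor's default box_threshold / text_threshold (assumed to be 0.25
# each in recent transformers releases). Pass them explicitly to tune recall.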

def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -> Image.Image:
    assert img.size == mask_img.size
    img = img.convert("RGB")
    mask_img = mask_img.convert("L")
    if defringe:
        # Mitigate edge halo effects via color decontamination
        rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
        foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
        img = Image.fromarray((foreground * 255).astype("uint8"))
    result = Image.new("RGBA", img.size)
    result.paste(img, (0, 0), mask_img)
    return result
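
# pymatting's estimate_foreground_ml solves for foreground colors given the alpha
# matte, which strips background color bleed from semi-transparent edge pixels.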

def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
    """Round the image dimensions up to the nearest multiples of 8."""
    new_width = ((width + 7) // 8) * 8
    new_height = ((height + 7) // 8) * 8
    return new_width, new_height
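
# Latent-diffusion VAEs typically downsample by a factor of 8, hence the rounding
# here; e.g. adjust_size_to_multiple_of_8(910, 512) -> (912, 512).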

def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
    """Compute image dimensions for the selected aspect ratio."""
    if aspect_ratio == "1:1":
        return base_size, base_size
    elif aspect_ratio == "16:9":
        return base_size * 16 // 9, base_size
    elif aspect_ratio == "9:16":
        return base_size, base_size * 16 // 9
    elif aspect_ratio == "4:3":
        return base_size * 4 // 3, base_size
    return base_size, base_size
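
# With the default base_size=512, "16:9" yields 512 * 16 // 9 = 910 x 512, which
# adjust_size_to_multiple_of_8 then rounds to 912 x 512 before generation.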

def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
    """Generate a background image with FLUX."""
    try:
        # Compute dimensions for the selected aspect ratio
        width, height = calculate_dimensions(aspect_ratio)

        # Round to multiples of 8
        width, height = adjust_size_to_multiple_of_8(width, height)

        with timer("Background generation"):
            image = pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=8,
                guidance_scale=4.0,
            ).images[0]

        return image
    except Exception as e:
        raise gr.Error(f"Background generation failed: {str(e)}")

def combine_with_background(foreground: Image.Image, background: Image.Image) -> Image.Image:
    """Composite the extracted foreground onto the generated background."""
    background = background.resize(foreground.size)
    return Image.alpha_composite(background.convert('RGBA'), foreground)
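
# Note that resize() stretches the background to the foreground's size, so a 16:9
# background placed behind a square cutout is distorted rather than cropped.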

# @spaces.GPU allocates a GPU for the duration of this call on ZeroGPU Spaces
@spaces.GPU
def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    time_log: list[str] = []
    if isinstance(prompt, str):
        t0 = time.time()
        bbox = gd_detect(img, prompt)
        time_log.append(f"detect: {time.time() - t0}")
        if not bbox:
            print(time_log[0])
            raise gr.Error("No object detected")
    else:
        bbox = prompt
    t0 = time.time()
    mask = segmenter(img, bbox)
    time_log.append(f"segment: {time.time() - t0}")
    return mask, bbox, time_log

def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    try:
        # Downscale very large inputs, rescaling a box prompt to match
        if img.width > 2048 or img.height > 2048:
            orig_res = max(img.width, img.height)
            img.thumbnail((2048, 2048))
            if isinstance(prompt, tuple):
                x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
                prompt = (x0, y0, x1, y1)

        mask, bbox, time_log = _gpu_process(img, prompt)
        masked_alpha = apply_mask(img, mask, defringe=True)

        if bg_prompt:
            background = generate_background(bg_prompt, aspect_ratio)
            combined = combine_with_background(masked_alpha, background)
        else:
            combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

        thresholded = mask.point(lambda p: 255 if p > 10 else 0)
        bbox = thresholded.getbbox()
        to_dl = masked_alpha.crop(bbox)

        temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        to_dl.save(temp, format="PNG")
        temp.close()

        return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)

    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}")

def on_change_bbox(prompts: dict[str, Any] | None):
    return gr.update(interactive=prompts is not None)

def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
    return gr.update(interactive=bool(img and prompt))

def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[Image.Image, Image.Image]:
    try:
        if img is None or prompt.strip() == "":
            raise gr.Error("Please provide both image and prompt")

        # Process the image
        results, _ = _process(img, prompt, bg_prompt, aspect_ratio)

        # Return only the combined image and the extracted object
        return results[1], results[2]
    except Exception as e:
        raise gr.Error(str(e))

def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
    try:
        if img is None or box_input.strip() == "":
            raise gr.Error("Please provide both image and bounding box coordinates")

        try:
            # ast.literal_eval parses the list without executing arbitrary code
            coords = ast.literal_eval(box_input)
            if not isinstance(coords, list) or len(coords) != 4:
                raise ValueError("Invalid box format")
            bbox = tuple(int(x) for x in coords)
        except Exception:
            raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]")

        # Process the image
        results, _ = _process(img, bbox)

        # Return only the combined image and the extracted object
        return results[1], results[2]
    except Exception as e:
        raise gr.Error(str(e))
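
# Expected box_input is a literal like "[100, 50, 400, 300]" in pixel
# coordinates (xmin, ymin, xmax, ymax).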

# Event handler functions
def update_process_button(img, prompt):
    return gr.update(
        interactive=bool(img and prompt),
        variant="primary" if bool(img and prompt) else "secondary"
    )

def update_box_button(img, box_input):
    try:
        if img and box_input:
            coords = ast.literal_eval(box_input)
            if isinstance(coords, list) and len(coords) == 4:
                return gr.update(interactive=True, variant="primary")
        return gr.update(interactive=False, variant="secondary")
    except Exception:
        return gr.update(interactive=False, variant="secondary")

# CSS definitions for the UI
css = """
footer {display: none}
.main-title {
    text-align: center;
    margin: 2em 0;
    padding: 1em;
    background: #f7f7f7;
    border-radius: 10px;
}
.main-title h1 {
    color: #2196F3;
    font-size: 2.5em;
    margin-bottom: 0.5em;
}
.main-title p {
    color: #666;
    font-size: 1.2em;
}
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.tabs {
    margin-top: 1em;
}
.input-group {
    background: white;
    padding: 1em;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.output-group {
    background: white;
    padding: 1em;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
button.primary {
    background: #2196F3;
    border: none;
    color: white;
    padding: 0.5em 1em;
    border-radius: 4px;
    cursor: pointer;
    transition: background 0.3s ease;
}
button.primary:hover {
    background: #1976D2;
}
"""

# UI definition
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
    <div class="main-title">
        <h1>🎨 Image Object Extractor</h1>
        <p>Extract objects from images using text prompts</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                interactive=True
            )
            text_prompt = gr.Textbox(
                label="Object to Extract",
                placeholder="Enter what you want to extract...",
                interactive=True
            )
            with gr.Row():
                bg_prompt = gr.Textbox(
                    label="Background Prompt (optional)",
                    placeholder="Describe the background...",
                    interactive=True,
                    scale=3
                )
                aspect_ratio = gr.Dropdown(
                    choices=["1:1", "16:9", "9:16", "4:3"],
                    value="1:1",
                    label="Aspect Ratio",
                    interactive=True,
                    visible=True,
                    scale=1
                )
            process_btn = gr.Button(
                "Process",
                variant="primary",
                interactive=False
            )

        with gr.Column(scale=1):
            with gr.Row():
                combined_image = gr.Image(
                    label="Combined Result",
                    show_download_button=True,
                    type="pil",
                    height=512
                )
            with gr.Row():
                extracted_image = gr.Image(
                    label="Extracted Object",
                    show_download_button=True,
                    type="pil",
                    height=256
                )

    # Event bindings
    input_image.change(
        fn=update_process_button,
        inputs=[input_image, text_prompt],
        outputs=process_btn,
        queue=False
    )

    text_prompt.change(
        fn=update_process_button,
        inputs=[input_image, text_prompt],
        outputs=process_btn,
        queue=False
    )

    # Show the aspect-ratio dropdown only when a background prompt is given
    def update_aspect_ratio(bg_prompt):
        return gr.update(visible=bool(bg_prompt))

    bg_prompt.change(
        fn=update_aspect_ratio,
        inputs=bg_prompt,
        outputs=aspect_ratio,
        queue=False
    )

    process_btn.click(
        fn=process_prompt,
        inputs=[input_image, text_prompt, bg_prompt, aspect_ratio],
        outputs=[combined_image, extracted_image],
        queue=True
    )

demo.queue(max_size=30, api_open=False)
demo.launch()
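
# queue(max_size=30) caps pending jobs, and api_open=False is meant to close the
# queue's programmatic API endpoints (an assumption about the Gradio version this
# Space pins); demo.launch() then blocks and serves the UI.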