ginipick commited on
Commit
0e99320
·
verified ·
1 Parent(s): 7d2ccac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -73
app.py CHANGED
@@ -3,6 +3,7 @@ import time
3
  from collections.abc import Sequence
4
  from typing import Any, cast
5
  import os
 
6
  from huggingface_hub import login, hf_hub_download
7
 
8
  import gradio as gr
@@ -19,35 +20,36 @@ from refiners.solutions import BoxSegmenter
19
  from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
20
  from diffusers import FluxPipeline
21
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
22
- import gc
23
 
 
 
24
  def clear_memory():
25
- """메모리 정리 함수"""
26
  gc.collect()
27
  try:
28
  if torch.cuda.is_available():
29
  with torch.cuda.device(0): # 명시적으로 device 0 사용
30
  torch.cuda.empty_cache()
31
- except:
32
  pass
33
 
34
- # GPU 설정
35
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 명시적으로 cuda:0 지정
36
-
37
- # GPU 설정을 try-except로 감싸기
38
  if torch.cuda.is_available():
39
  try:
40
  with torch.cuda.device(0):
41
  torch.cuda.empty_cache()
42
  torch.backends.cudnn.benchmark = True
43
  torch.backends.cuda.matmul.allow_tf32 = True
44
- except:
45
  print("Warning: Could not configure CUDA settings")
46
 
47
- # 번역 모델 초기화
 
48
  model_name = "Helsinki-NLP/opus-mt-ko-en"
49
  tokenizer = AutoTokenizer.from_pretrained(model_name)
50
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cpu')
 
51
  translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
52
 
53
  def translate_to_english(text: str) -> str:
@@ -67,6 +69,7 @@ BoundingBox = tuple[int, int, int, int]
67
  pillow_heif.register_heif_opener()
68
  pillow_heif.register_avif_opener()
69
 
 
70
  # HF 토큰 설정
71
  HF_TOKEN = os.getenv("HF_TOKEN")
72
  if HF_TOKEN is None:
@@ -77,7 +80,8 @@ try:
77
  except Exception as e:
78
  raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
79
 
80
- # 모델 초기화
 
81
  segmenter = BoxSegmenter(device="cpu")
82
  segmenter.device = device
83
  segmenter.model = segmenter.model.to(device=segmenter.device)
@@ -88,15 +92,14 @@ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_
88
  gd_model = gd_model.to(device=device)
89
  assert isinstance(gd_model, GroundingDinoForObjectDetection)
90
 
91
- # FLUX 파이프라인 초기화
 
92
  pipe = FluxPipeline.from_pretrained(
93
  "black-forest-labs/FLUX.1-dev",
94
  torch_dtype=torch.float16,
95
  use_auth_token=HF_TOKEN
96
  )
97
  pipe.enable_attention_slicing(slice_size="auto")
98
-
99
- # LoRA 가중치 로드
100
  pipe.load_lora_weights(
101
  hf_hub_download(
102
  "ByteDance/Hyper-SD",
@@ -105,14 +108,14 @@ pipe.load_lora_weights(
105
  )
106
  )
107
  pipe.fuse_lora(lora_scale=0.125)
108
-
109
- # GPU 설정을 try-except로 감싸기
110
  try:
111
  if torch.cuda.is_available():
112
- pipe = pipe.to("cuda:0") # 명시적으로 cuda:0 지정
113
  except Exception as e:
114
  print(f"Warning: Could not move pipeline to CUDA: {str(e)}")
115
 
 
 
116
  class timer:
117
  def __init__(self, method_name="timed process"):
118
  self.method = method_name
@@ -123,6 +126,8 @@ class timer:
123
  end = time.time()
124
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
125
 
 
 
126
  def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
127
  if not bboxes:
128
  return None
@@ -166,15 +171,12 @@ def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) -
166
  result.paste(img, (0, 0), mask_img)
167
  return result
168
 
169
-
170
  def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
171
- """이미지 크기를 8의 배수로 조정하는 함수"""
172
  new_width = ((width + 7) // 8) * 8
173
  new_height = ((height + 7) // 8) * 8
174
  return new_width, new_height
175
 
176
  def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
177
- """선택된 비율에 따라 이미지 크기 계산"""
178
  if aspect_ratio == "1:1":
179
  return base_size, base_size
180
  elif aspect_ratio == "16:9":
@@ -185,7 +187,9 @@ def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int,
185
  return base_size * 4 // 3, base_size
186
  return base_size, base_size
187
 
188
- @spaces.GPU(duration=20) # 40초에서 20초로 감소
 
 
189
  def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
190
  try:
191
  width, height = calculate_dimensions(aspect_ratio)
@@ -197,7 +201,7 @@ def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
197
  width = int(width * ratio)
198
  height = int(height * ratio)
199
  width, height = adjust_size_to_multiple_of_8(width, height)
200
-
201
  with timer("Background generation"):
202
  try:
203
  with torch.inference_mode():
@@ -211,7 +215,6 @@ def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
211
  except Exception as e:
212
  print(f"Pipeline error: {str(e)}")
213
  return Image.new('RGB', (width, height), 'white')
214
-
215
  return image
216
  except Exception as e:
217
  print(f"Background generation error: {str(e)}")
@@ -233,7 +236,6 @@ def create_position_grid():
233
  """
234
 
235
  def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size: tuple[int, int]) -> tuple[int, int]:
236
- """오브젝트의 위치 계산"""
237
  bg_width, bg_height = bg_size
238
  obj_width, obj_height = obj_size
239
 
@@ -252,28 +254,21 @@ def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size:
252
  return positions.get(position, positions["bottom-center"])
253
 
254
  def resize_object(image: Image.Image, scale_percent: float) -> Image.Image:
255
- """오브젝트 크기 조정"""
256
  width = int(image.width * scale_percent / 100)
257
  height = int(image.height * scale_percent / 100)
258
  return image.resize((width, height), Image.Resampling.LANCZOS)
259
 
260
  def combine_with_background(foreground: Image.Image, background: Image.Image,
261
- position: str = "bottom-center", scale_percent: float = 100) -> Image.Image:
262
- """전경과 배경 합성 함수"""
263
- # 배경 이미지 준비
264
  result = background.convert('RGBA')
265
-
266
- # 오브젝트 크기 조정
267
  scaled_foreground = resize_object(foreground, scale_percent)
268
-
269
- # 오브젝트 위치 계산
270
  x, y = calculate_object_position(position, result.size, scaled_foreground.size)
271
-
272
- # 합성
273
  result.paste(scaled_foreground, (x, y), scaled_foreground)
274
  return result
275
 
276
- @spaces.GPU(duration=30) # 120초에서 30초로 감소
 
 
277
  def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
278
  time_log: list[str] = []
279
  try:
@@ -294,6 +289,8 @@ def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Im
294
  print(f"GPU process error: {str(e)}")
295
  raise
296
 
 
 
297
  def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
298
  try:
299
  # 입력 이미지 크기 제한
@@ -302,8 +299,7 @@ def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str
302
  ratio = max_size / max(img.width, img.height)
303
  new_size = (int(img.width * ratio), int(img.height * ratio))
304
  img = img.resize(new_size, Image.LANCZOS)
305
-
306
- # CUDA 메모리 관리 수정
307
  try:
308
  if torch.cuda.is_available():
309
  current_device = torch.cuda.current_device()
@@ -311,19 +307,19 @@ def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str
311
  torch.cuda.empty_cache()
312
  except Exception as e:
313
  print(f"CUDA memory management failed: {e}")
314
-
315
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
316
  mask, bbox, time_log = _gpu_process(img, prompt)
317
  masked_alpha = apply_mask(img, mask, defringe=True)
318
-
319
  if bg_prompt:
320
  background = generate_background(bg_prompt, aspect_ratio)
321
  combined = background
322
  else:
323
  combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
324
-
325
  clear_memory()
326
-
327
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
328
  combined.save(temp.name)
329
  return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
@@ -335,15 +331,12 @@ def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str
335
  def on_change_bbox(prompts: dict[str, Any] | None):
336
  return gr.update(interactive=prompts is not None)
337
 
338
-
339
  def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
340
  return gr.update(interactive=bool(img and prompt))
341
 
342
-
343
-
344
  def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
345
- aspect_ratio: str = "1:1", position: str = "bottom-center",
346
- scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
347
  try:
348
  if img is None or prompt.strip() == "":
349
  raise gr.Error("Please provide both image and prompt")
@@ -379,7 +372,7 @@ def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
379
  raise gr.Error(str(e))
380
  finally:
381
  clear_memory()
382
-
383
  def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
384
  try:
385
  if img is None or box_input.strip() == "":
@@ -393,15 +386,11 @@ def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.I
393
  except:
394
  raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]")
395
 
396
- # Process the image
397
  results, _ = _process(img, bbox)
398
-
399
- # 합성된 이미지와 추출된 이미지만 반환
400
  return results[1], results[2]
401
  except Exception as e:
402
  raise gr.Error(str(e))
403
 
404
- # Event handler functions 수정
405
  def update_process_button(img, prompt):
406
  return gr.update(
407
  interactive=bool(img and prompt),
@@ -418,7 +407,7 @@ def update_box_button(img, box_input):
418
  except:
419
  return gr.update(interactive=False, variant="secondary")
420
 
421
-
422
  # CSS 정의
423
  css = """
424
  footer {display: none}
@@ -482,9 +471,7 @@ button.primary:hover {
482
  }
483
  """
484
 
485
- # UI 구성
486
- # UI 구성 부분에서 process_btn을 위로 이동하고 position_grid.click 부분 제거
487
-
488
  # UI 구성
489
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
490
  gr.HTML("""
@@ -493,7 +480,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
493
  <p>AI Integrated Image Creator: Extract objects, generate backgrounds, and adjust ratios and positions to create complete images with AI.</p>
494
  </div>
495
  """)
496
-
497
  with gr.Row():
498
  with gr.Column(scale=1):
499
  input_image = gr.Image(
@@ -521,7 +507,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
521
  visible=True,
522
  scale=1
523
  )
524
-
525
  with gr.Row(visible=False) as object_controls:
526
  with gr.Column(scale=1):
527
  with gr.Row():
@@ -545,17 +530,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
545
  step=5,
546
  label="Object Size (%)"
547
  )
548
-
549
  process_btn = gr.Button(
550
  "Process",
551
  variant="primary",
552
  interactive=False
553
  )
554
-
555
  # 각 버튼에 대한 클릭 이벤트 처리
556
  def update_position(new_position):
557
  return new_position
558
-
559
  btn_top_left.click(fn=lambda: update_position("top-left"), outputs=position)
560
  btn_top_center.click(fn=lambda: update_position("top-center"), outputs=position)
561
  btn_top_right.click(fn=lambda: update_position("top-right"), outputs=position)
@@ -565,7 +547,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
565
  btn_bottom_left.click(fn=lambda: update_position("bottom-left"), outputs=position)
566
  btn_bottom_center.click(fn=lambda: update_position("bottom-center"), outputs=position)
567
  btn_bottom_right.click(fn=lambda: update_position("bottom-right"), outputs=position)
568
-
569
  with gr.Column(scale=1):
570
  with gr.Row():
571
  combined_image = gr.Image(
@@ -581,7 +562,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
581
  type="pil",
582
  height=256
583
  )
584
-
585
  # Event bindings
586
  input_image.change(
587
  fn=update_process_button,
@@ -589,29 +569,24 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
589
  outputs=process_btn,
590
  queue=False
591
  )
592
-
593
  text_prompt.change(
594
  fn=update_process_button,
595
  inputs=[input_image, text_prompt],
596
  outputs=process_btn,
597
  queue=False
598
  )
599
-
600
  def update_controls(bg_prompt):
601
- """배경 프롬프트 입력 여부에 따라 컨트롤 표시 업데이트"""
602
  is_visible = bool(bg_prompt)
603
  return [
604
- gr.update(visible=is_visible), # aspect_ratio
605
- gr.update(visible=is_visible), # object_controls
606
  ]
607
-
608
  bg_prompt.change(
609
  fn=update_controls,
610
  inputs=bg_prompt,
611
  outputs=[aspect_ratio, object_controls],
612
  queue=False
613
  )
614
-
615
  process_btn.click(
616
  fn=process_prompt,
617
  inputs=[
@@ -625,12 +600,23 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
625
  outputs=[combined_image, extracted_image],
626
  queue=True
627
  )
628
-
629
-
630
- demo.queue(max_size=5) # 큐 크기 제한
 
 
 
 
 
 
 
 
 
 
 
631
  demo.launch(
632
  server_name="0.0.0.0",
633
  server_port=7860,
634
  share=False,
635
- max_threads=2 # 스레드 수 제한
636
  )
 
3
  from collections.abc import Sequence
4
  from typing import Any, cast
5
  import os
6
+ import gc
7
  from huggingface_hub import login, hf_hub_download
8
 
9
  import gradio as gr
 
20
  from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
21
  from diffusers import FluxPipeline
22
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
23
 
24
+ #############################################################
25
+ # 메모리 정리 함수
26
  def clear_memory():
 
27
  gc.collect()
28
  try:
29
  if torch.cuda.is_available():
30
  with torch.cuda.device(0): # 명시적으로 device 0 사용
31
  torch.cuda.empty_cache()
32
+ except Exception as e:
33
  pass
34
 
35
+ #############################################################
36
+ # GPU 설정 (Zero GPU 환경)
37
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
38
  if torch.cuda.is_available():
39
  try:
40
  with torch.cuda.device(0):
41
  torch.cuda.empty_cache()
42
  torch.backends.cudnn.benchmark = True
43
  torch.backends.cuda.matmul.allow_tf32 = True
44
+ except Exception as e:
45
  print("Warning: Could not configure CUDA settings")
46
 
47
+ #############################################################
48
+ # 번역 모델 초기화 (CPU에서 동작)
49
  model_name = "Helsinki-NLP/opus-mt-ko-en"
50
  tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+ # 번역 모델은 CPU에 올림
52
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cpu")
53
  translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
54
 
55
  def translate_to_english(text: str) -> str:
 
69
  pillow_heif.register_heif_opener()
70
  pillow_heif.register_avif_opener()
71
 
72
+ #############################################################
73
  # HF 토큰 설정
74
  HF_TOKEN = os.getenv("HF_TOKEN")
75
  if HF_TOKEN is None:
 
80
  except Exception as e:
81
  raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
82
 
83
+ #############################################################
84
+ # 객체 분할 모델 초기화
85
  segmenter = BoxSegmenter(device="cpu")
86
  segmenter.device = device
87
  segmenter.model = segmenter.model.to(device=segmenter.device)
 
92
  gd_model = gd_model.to(device=device)
93
  assert isinstance(gd_model, GroundingDinoForObjectDetection)
94
 
95
+ #############################################################
96
+ # FLUX 파이프라인 초기화 (Zero GPU용)
97
  pipe = FluxPipeline.from_pretrained(
98
  "black-forest-labs/FLUX.1-dev",
99
  torch_dtype=torch.float16,
100
  use_auth_token=HF_TOKEN
101
  )
102
  pipe.enable_attention_slicing(slice_size="auto")
 
 
103
  pipe.load_lora_weights(
104
  hf_hub_download(
105
  "ByteDance/Hyper-SD",
 
108
  )
109
  )
110
  pipe.fuse_lora(lora_scale=0.125)
 
 
111
  try:
112
  if torch.cuda.is_available():
113
+ pipe = pipe.to("cuda:0") # 명시적으로 cuda:0 이동
114
  except Exception as e:
115
  print(f"Warning: Could not move pipeline to CUDA: {str(e)}")
116
 
117
+ #############################################################
118
+ # 타이머 클래스
119
  class timer:
120
  def __init__(self, method_name="timed process"):
121
  self.method = method_name
 
126
  end = time.time()
127
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
128
 
129
+ #############################################################
130
+ # 유틸리티 함수들
131
  def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
132
  if not bboxes:
133
  return None
 
171
  result.paste(img, (0, 0), mask_img)
172
  return result
173
 
 
174
  def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
 
175
  new_width = ((width + 7) // 8) * 8
176
  new_height = ((height + 7) // 8) * 8
177
  return new_width, new_height
178
 
179
  def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
 
180
  if aspect_ratio == "1:1":
181
  return base_size, base_size
182
  elif aspect_ratio == "16:9":
 
187
  return base_size * 4 // 3, base_size
188
  return base_size, base_size
189
 
190
+ #############################################################
191
+ # 배경 생성 함수 (Zero GPU에 맞게 수정)
192
+ @spaces.GPU(duration=20)
193
  def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
194
  try:
195
  width, height = calculate_dimensions(aspect_ratio)
 
201
  width = int(width * ratio)
202
  height = int(height * ratio)
203
  width, height = adjust_size_to_multiple_of_8(width, height)
204
+
205
  with timer("Background generation"):
206
  try:
207
  with torch.inference_mode():
 
215
  except Exception as e:
216
  print(f"Pipeline error: {str(e)}")
217
  return Image.new('RGB', (width, height), 'white')
 
218
  return image
219
  except Exception as e:
220
  print(f"Background generation error: {str(e)}")
 
236
  """
237
 
238
  def calculate_object_position(position: str, bg_size: tuple[int, int], obj_size: tuple[int, int]) -> tuple[int, int]:
 
239
  bg_width, bg_height = bg_size
240
  obj_width, obj_height = obj_size
241
 
 
254
  return positions.get(position, positions["bottom-center"])
255
 
256
  def resize_object(image: Image.Image, scale_percent: float) -> Image.Image:
 
257
  width = int(image.width * scale_percent / 100)
258
  height = int(image.height * scale_percent / 100)
259
  return image.resize((width, height), Image.Resampling.LANCZOS)
260
 
261
  def combine_with_background(foreground: Image.Image, background: Image.Image,
262
+ position: str = "bottom-center", scale_percent: float = 100) -> Image.Image:
 
 
263
  result = background.convert('RGBA')
 
 
264
  scaled_foreground = resize_object(foreground, scale_percent)
 
 
265
  x, y = calculate_object_position(position, result.size, scaled_foreground.size)
 
 
266
  result.paste(scaled_foreground, (x, y), scaled_foreground)
267
  return result
268
 
269
+ #############################################################
270
+ # GPU 처리 함수 (Zero GPU에 맞게 수정)
271
+ @spaces.GPU(duration=30)
272
  def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
273
  time_log: list[str] = []
274
  try:
 
289
  print(f"GPU process error: {str(e)}")
290
  raise
291
 
292
+ #############################################################
293
+ # 전체 처리 함수
294
  def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
295
  try:
296
  # 입력 이미지 크기 제한
 
299
  ratio = max_size / max(img.width, img.height)
300
  new_size = (int(img.width * ratio), int(img.height * ratio))
301
  img = img.resize(new_size, Image.LANCZOS)
302
+
 
303
  try:
304
  if torch.cuda.is_available():
305
  current_device = torch.cuda.current_device()
 
307
  torch.cuda.empty_cache()
308
  except Exception as e:
309
  print(f"CUDA memory management failed: {e}")
310
+
311
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
312
  mask, bbox, time_log = _gpu_process(img, prompt)
313
  masked_alpha = apply_mask(img, mask, defringe=True)
314
+
315
  if bg_prompt:
316
  background = generate_background(bg_prompt, aspect_ratio)
317
  combined = background
318
  else:
319
  combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
320
+
321
  clear_memory()
322
+
323
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
324
  combined.save(temp.name)
325
  return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
 
331
  def on_change_bbox(prompts: dict[str, Any] | None):
332
  return gr.update(interactive=prompts is not None)
333
 
 
334
  def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
335
  return gr.update(interactive=bool(img and prompt))
336
 
 
 
337
  def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
338
+ aspect_ratio: str = "1:1", position: str = "bottom-center",
339
+ scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
340
  try:
341
  if img is None or prompt.strip() == "":
342
  raise gr.Error("Please provide both image and prompt")
 
372
  raise gr.Error(str(e))
373
  finally:
374
  clear_memory()
375
+
376
  def process_bbox(img: Image.Image, box_input: str) -> tuple[Image.Image, Image.Image]:
377
  try:
378
  if img is None or box_input.strip() == "":
 
386
  except:
387
  raise gr.Error("Invalid box format. Please provide [xmin, ymin, xmax, ymax]")
388
 
 
389
  results, _ = _process(img, bbox)
 
 
390
  return results[1], results[2]
391
  except Exception as e:
392
  raise gr.Error(str(e))
393
 
 
394
  def update_process_button(img, prompt):
395
  return gr.update(
396
  interactive=bool(img and prompt),
 
407
  except:
408
  return gr.update(interactive=False, variant="secondary")
409
 
410
+ #############################################################
411
  # CSS 정의
412
  css = """
413
  footer {display: none}
 
471
  }
472
  """
473
 
474
+ #############################################################
 
 
475
  # UI 구성
476
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
477
  gr.HTML("""
 
480
  <p>AI Integrated Image Creator: Extract objects, generate backgrounds, and adjust ratios and positions to create complete images with AI.</p>
481
  </div>
482
  """)
 
483
  with gr.Row():
484
  with gr.Column(scale=1):
485
  input_image = gr.Image(
 
507
  visible=True,
508
  scale=1
509
  )
 
510
  with gr.Row(visible=False) as object_controls:
511
  with gr.Column(scale=1):
512
  with gr.Row():
 
530
  step=5,
531
  label="Object Size (%)"
532
  )
 
533
  process_btn = gr.Button(
534
  "Process",
535
  variant="primary",
536
  interactive=False
537
  )
 
538
  # 각 버튼에 대한 클릭 이벤트 처리
539
  def update_position(new_position):
540
  return new_position
 
541
  btn_top_left.click(fn=lambda: update_position("top-left"), outputs=position)
542
  btn_top_center.click(fn=lambda: update_position("top-center"), outputs=position)
543
  btn_top_right.click(fn=lambda: update_position("top-right"), outputs=position)
 
547
  btn_bottom_left.click(fn=lambda: update_position("bottom-left"), outputs=position)
548
  btn_bottom_center.click(fn=lambda: update_position("bottom-center"), outputs=position)
549
  btn_bottom_right.click(fn=lambda: update_position("bottom-right"), outputs=position)
 
550
  with gr.Column(scale=1):
551
  with gr.Row():
552
  combined_image = gr.Image(
 
562
  type="pil",
563
  height=256
564
  )
 
565
  # Event bindings
566
  input_image.change(
567
  fn=update_process_button,
 
569
  outputs=process_btn,
570
  queue=False
571
  )
 
572
  text_prompt.change(
573
  fn=update_process_button,
574
  inputs=[input_image, text_prompt],
575
  outputs=process_btn,
576
  queue=False
577
  )
 
578
  def update_controls(bg_prompt):
 
579
  is_visible = bool(bg_prompt)
580
  return [
581
+ gr.update(visible=is_visible),
582
+ gr.update(visible=is_visible),
583
  ]
 
584
  bg_prompt.change(
585
  fn=update_controls,
586
  inputs=bg_prompt,
587
  outputs=[aspect_ratio, object_controls],
588
  queue=False
589
  )
 
590
  process_btn.click(
591
  fn=process_prompt,
592
  inputs=[
 
600
  outputs=[combined_image, extracted_image],
601
  queue=True
602
  )
603
+ # 예제 섹션 추가
604
+ with gr.Accordion("Show Example", open=True):
605
+ gr.Markdown("### Example")
606
+ with gr.Row():
607
+ with gr.Column():
608
+ gr.Markdown("**Upload Image(aa1.png)**")
609
+ gr.Image(value="aa1.png", label="Upload")
610
+ with gr.Column():
611
+ gr.Markdown("**Cut Object (aa2.png)**<br>(Prompt: 'text')", elem_classes="center")
612
+ gr.Image(value="aa2.png", label="Object")
613
+ with gr.Column():
614
+ gr.Markdown("**Generated Image (aa3.png)**<br>(Background Prompt: 'alps mountain')", elem_classes="center")
615
+ gr.Image(value="aa3.png", label="Output")
616
+ demo.queue(max_size=5)
617
  demo.launch(
618
  server_name="0.0.0.0",
619
  server_port=7860,
620
  share=False,
621
+ max_threads=2
622
  )