aiqcamp commited on
Commit
202b398
ยท
verified ยท
1 Parent(s): 18640ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -45
app.py CHANGED
@@ -1,33 +1,75 @@
1
- import random
2
- import os
3
- import uuid
4
- from datetime import datetime
5
  import gradio as gr
6
  import numpy as np
 
7
  import torch
8
  from diffusers import DiffusionPipeline
9
- from PIL import Image
10
  import warnings
 
 
 
11
 
12
  # ๊ฒฝ๊ณ  ๋ฉ”์‹œ์ง€ ์ˆจ๊ธฐ๊ธฐ
13
  warnings.filterwarnings('ignore', category=UserWarning)
14
 
15
- # ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
16
  SAVE_DIR = "saved_images"
17
  if not os.path.exists(SAVE_DIR):
18
  os.makedirs(SAVE_DIR, exist_ok=True)
19
 
20
- # ์žฅ์น˜ ์„ค์ •
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  # ๋ชจ๋ธ ๋กœ๋“œ
24
- repo_id = "black-forest-labs/FLUX.1-schnell"
25
- pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float32)
26
- pipeline = pipeline.to(device)
 
 
 
 
 
 
 
 
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  MAX_IMAGE_SIZE = 2048
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Enhanced examples with more detailed prompts and specific styling
32
  EXAMPLES = [
33
  {
@@ -257,43 +299,85 @@ def generate_diagram(prompt, width=1024, height=1024):
257
  except Exception as e:
258
  raise gr.Error(f"๋‹ค์ด์–ด๊ทธ๋žจ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
259
 
 
 
 
 
 
 
 
 
260
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
261
- demo = gr.Interface(
262
- fn=generate_diagram,
263
- inputs=[
264
- gr.Textbox(
265
- label="๋‹ค์ด์–ด๊ทธ๋žจ ํ”„๋กฌํ”„ํŠธ",
266
- placeholder="๋‹ค์ด์–ด๊ทธ๋žจ ๊ตฌ์กฐ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
267
- lines=10
268
- ),
269
- gr.Slider(
270
- label="๋„ˆ๋น„",
271
- minimum=512,
272
- maximum=MAX_IMAGE_SIZE,
273
- step=128,
274
- value=1024
275
- ),
276
- gr.Slider(
277
- label="๋†’์ด",
278
- minimum=512,
279
- maximum=MAX_IMAGE_SIZE,
280
- step=128,
281
- value=1024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  )
283
- ],
284
- outputs=gr.Image(label="์ƒ์„ฑ๋œ ๋‹ค์ด์–ด๊ทธ๋žจ"),
285
- title="๐ŸŽจ FLUX ๋‹ค์ด์–ด๊ทธ๋žจ ์ƒ์„ฑ๊ธฐ",
286
- description="FLUX AI๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์•„๋ฆ„๋‹ค์šด ์†๊ทธ๋ฆผ ์Šคํƒ€์ผ์˜ ๋‹ค์ด์–ด๊ทธ๋žจ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค",
287
- article="""
288
- ### ๋” ๋‚˜์€ ๊ฒฐ๊ณผ๋ฅผ ์œ„ํ•œ ํŒ
289
- - ๋ช…ํ™•ํ•œ ๊ณ„์ธต ๊ตฌ์กฐ ์‚ฌ์šฉ
290
- - ๋Œ€๊ด„ํ˜ธ ์•ˆ์— ์•„์ด์ฝ˜ ์„ค๋ช… ํฌํ•จ
291
- - ๊ฐ„๊ฒฐํ•˜๊ณ  ์˜๋ฏธ ์žˆ๋Š” ํ…์ŠคํŠธ ์‚ฌ์šฉ
292
- - ์ผ๊ด€๋œ ํ˜•์‹ ์œ ์ง€
293
- """,
294
- examples=GRADIO_EXAMPLES,
295
- cache_examples=True
296
- )
297
 
298
  # ์•ฑ ์‹คํ–‰
299
  if __name__ == "__main__":
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
+ import random
4
  import torch
5
  from diffusers import DiffusionPipeline
 
6
  import warnings
7
+ import os
8
+ from datetime import datetime
9
+ import uuid
10
 
11
  # ๊ฒฝ๊ณ  ๋ฉ”์‹œ์ง€ ์ˆจ๊ธฐ๊ธฐ
12
  warnings.filterwarnings('ignore', category=UserWarning)
13
 
14
+ # ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
15
  SAVE_DIR = "saved_images"
16
  if not os.path.exists(SAVE_DIR):
17
  os.makedirs(SAVE_DIR, exist_ok=True)
18
 
19
+ # ์žฅ์น˜ ๋ฐ dtype ์„ค์ •
20
+ dtype = torch.float32 if torch.cuda.is_available() else torch.float32
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  # ๋ชจ๋ธ ๋กœ๋“œ
24
+ pipe = DiffusionPipeline.from_pretrained(
25
+ "black-forest-labs/FLUX.1-schnell",
26
+ torch_dtype=dtype,
27
+ device_map="auto",
28
+ use_safetensors=True
29
+ ).to(device)
30
+
31
+ # ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™”
32
+ pipe.enable_attention_slicing()
33
+ if device == "cpu":
34
+ pipe.enable_sequential_cpu_offload()
35
 
36
  MAX_SEED = np.iinfo(np.int32).max
37
  MAX_IMAGE_SIZE = 2048
38
 
39
+ def generate_diagram(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
40
+ """FLUX AI๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹ค์ด์–ด๊ทธ๋žจ ์ƒ์„ฑ"""
41
+ try:
42
+ if randomize_seed:
43
+ seed = random.randint(0, MAX_SEED)
44
+ generator = torch.Generator(device=device).manual_seed(seed)
45
+
46
+ with torch.no_grad():
47
+ image = pipe(
48
+ prompt=prompt,
49
+ width=width,
50
+ height=height,
51
+ num_inference_steps=num_inference_steps,
52
+ generator=generator,
53
+ guidance_scale=0.0
54
+ ).images[0]
55
+
56
+ # ์ด๋ฏธ์ง€ ์ €์žฅ
57
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
58
+ unique_id = str(uuid.uuid4())[:8]
59
+ filename = f"diagram_{timestamp}_{unique_id}.png"
60
+ save_path = os.path.join(SAVE_DIR, filename)
61
+ image.save(save_path)
62
+
63
+ return image, seed
64
+
65
+ except Exception as e:
66
+ raise gr.Error(f"๋‹ค์ด์–ด๊ทธ๋žจ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
67
+ finally:
68
+ if torch.cuda.is_available():
69
+ torch.cuda.empty_cache()
70
+
71
+
72
+
73
  # Enhanced examples with more detailed prompts and specific styling
74
  EXAMPLES = [
75
  {
 
299
  except Exception as e:
300
  raise gr.Error(f"๋‹ค์ด์–ด๊ทธ๋žจ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
301
 
302
+ # CSS ์Šคํƒ€์ผ
303
+ css="""
304
+ #col-container {
305
+ margin: 0 auto;
306
+ max-width: 520px;
307
+ }
308
+ """
309
+
310
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
311
+ with gr.Blocks(css=css) as demo:
312
+ with gr.Column(elem_id="col-container"):
313
+ gr.Markdown("""# FLUX ๋‹ค์ด์–ด๊ทธ๋žจ ์ƒ์„ฑ๊ธฐ
314
+ FLUX AI๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์•„๋ฆ„๋‹ค์šด ์†๊ทธ๋ฆผ ์Šคํƒ€์ผ์˜ ๋‹ค์ด์–ด๊ทธ๋žจ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค
315
+ """)
316
+
317
+ with gr.Row():
318
+ prompt = gr.Text(
319
+ label="ํ”„๋กฌํ”„ํŠธ",
320
+ show_label=False,
321
+ max_lines=1,
322
+ placeholder="๋‹ค์ด์–ด๊ทธ๋žจ ๊ตฌ์กฐ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
323
+ container=False,
324
+ )
325
+ run_button = gr.Button("์ƒ์„ฑ", scale=0)
326
+
327
+ result = gr.Image(label="์ƒ์„ฑ๋œ ๋‹ค์ด์–ด๊ทธ๋žจ", show_label=False)
328
+
329
+ with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ •", open=False):
330
+ seed = gr.Slider(
331
+ label="์‹œ๋“œ",
332
+ minimum=0,
333
+ maximum=MAX_SEED,
334
+ step=1,
335
+ value=0,
336
+ )
337
+
338
+ randomize_seed = gr.Checkbox(label="๋žœ๋ค ์‹œ๋“œ", value=True)
339
+
340
+ with gr.Row():
341
+ width = gr.Slider(
342
+ label="๋„ˆ๋น„",
343
+ minimum=256,
344
+ maximum=MAX_IMAGE_SIZE,
345
+ step=32,
346
+ value=1024,
347
+ )
348
+
349
+ height = gr.Slider(
350
+ label="๋†’์ด",
351
+ minimum=256,
352
+ maximum=MAX_IMAGE_SIZE,
353
+ step=32,
354
+ value=1024,
355
+ )
356
+
357
+ num_inference_steps = gr.Slider(
358
+ label="์ถ”๋ก  ๋‹จ๊ณ„ ์ˆ˜",
359
+ minimum=1,
360
+ maximum=50,
361
+ step=1,
362
+ value=4,
363
+ )
364
+
365
+ # ์˜ˆ์ œ ์ถ”๊ฐ€
366
+ gr.Examples(
367
+ examples=EXAMPLES, # ์ด์ „์— ์ •์˜๋œ ์˜ˆ์ œ๋“ค ์‚ฌ์šฉ
368
+ fn=generate_diagram,
369
+ inputs=[prompt],
370
+ outputs=[result, seed],
371
+ cache_examples=True
372
  )
373
+
374
+ # ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
375
+ gr.on(
376
+ triggers=[run_button.click, prompt.submit],
377
+ fn=generate_diagram,
378
+ inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
379
+ outputs=[result, seed]
380
+ )
 
 
 
 
 
 
381
 
382
  # ์•ฑ ์‹คํ–‰
383
  if __name__ == "__main__":