aiqcamp committed on
Commit
e2173ea
·
verified ·
1 Parent(s): bf2703f

Create app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +469 -0
app-backup.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import argparse
3
+ import os
4
+ import time
5
+ from os import path
6
+ import shutil
7
+ from datetime import datetime
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download
10
+ import gradio as gr
11
+ import torch
12
+ from diffusers import FluxPipeline
13
+ from diffusers.pipelines.stable_diffusion import safety_checker
14
+ from PIL import Image
15
+ from transformers import AutoProcessor, AutoModelForCausalLM
16
+ import subprocess
17
+
18
# Install Flash Attention at startup.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is handed to the pip process through its
# environment (presumably read by flash-attn's build script -- confirm there).
# The variable is merged into the current environment on purpose: a bare dict
# passed as `env` REPLACES the child's environment, dropping PATH, HOME, etc.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
22
+
23
# Setup and initialization code
# Model files are cached in a "models" folder next to this script; generated
# images go to "<PERSISTENT_DIR>/gallery" (PERSISTENT_DIR defaults to ".").
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
gallery_path = path.join(PERSISTENT_DIR, "gallery")

# Point the Hugging Face cache variables at the local cache folder.
# NOTE(review): these are set after transformers/diffusers/huggingface_hub
# were imported; confirm the libraries still honor them at this point.
os.environ.update(
    TRANSFORMERS_CACHE=cache_path,
    HF_HUB_CACHE=cache_path,
    HF_HOME=cache_path,
)

torch.backends.cuda.matmul.allow_tf32 = True

# Create gallery directory
if not path.exists(gallery_path):
    os.makedirs(gallery_path, exist_ok=True)
37
+
38
# Florence model initialization
# Both captioning checkpoints are loaded eagerly at import time and kept in
# dicts keyed by repository id; the models are switched to eval mode.
florence_models = {
    model_id: AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True
    ).eval()
    for model_id in (
        'gokaygokay/Florence-2-Flux-Large',
        'gokaygokay/Florence-2-Flux',
    )
}

# One processor per model, stored under the same keys as the models.
florence_processors = {
    model_id: AutoProcessor.from_pretrained(
        model_id,
        trust_remote_code=True
    )
    for model_id in florence_models
}
60
+
61
def filter_prompt(prompt):
    """Check a prompt against the blocked-keyword list.

    Args:
        prompt: Text entered by the user. ``None`` is treated as an empty
            prompt instead of raising ``AttributeError``.

    Returns:
        A ``(is_safe, text)`` tuple. When no blocked keyword occurs,
        ``is_safe`` is ``True`` and ``text`` is the prompt. Otherwise
        ``is_safe`` is ``False`` and ``text`` is a rejection message
        (in Korean).
    """
    inappropriate_keywords = [
        "sex"
    ]

    # A broader keyword list is kept for reference but is currently disabled:
    # "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
    # "erotic", "sensual", "seductive", "provocative", "intimate",
    # "violence", "gore", "blood", "death", "kill", "murder", "torture",
    # "drug", "suicide", "abuse", "hate", "discrimination"

    if prompt is None:
        prompt = ""

    # Case-insensitive substring match: a keyword embedded in a longer word
    # is rejected as well.
    prompt_lower = prompt.lower()
    if any(keyword in prompt_lower for keyword in inappropriate_keywords):
        return False, "부적절한 내용이 포함된 프롬프트입니다."

    return True, prompt
79
+
80
class timer:
    """Context manager that prints when a named step starts and how long it took.

    Args:
        method_name: Label used in the two printed messages.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return the instance so `with timer(...) as t:` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the block propagates.
        elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(elapsed, 2)}s")
89
+
90
# Model initialization
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

# Load the FLUX.1-dev pipeline in bfloat16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)

# Download the Hyper-SD 8-step LoRA and fuse it into the pipeline weights.
pipe.load_lora_weights(
    hf_hub_download(
        repo_id="ByteDance/Hyper-SD",
        filename="Hyper-FLUX.1-dev-8steps-lora.safetensors"
    )
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)

# Attach a Stable Diffusion safety checker as an attribute of the pipeline.
# NOTE(review): nothing in this file calls pipe.safety_checker -- confirm
# whether FluxPipeline actually uses it.
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
109
+
110
+
111
+
112
def save_image(image):
    """Write a generated image into the gallery folder as a PNG.

    Args:
        image: A ``PIL.Image.Image``; anything else is passed through
            ``Image.fromarray`` first.

    Returns:
        The path of the written file, or ``None`` when the gallery folder
        cannot be created, saving fails, or the file is missing afterwards.
        Failures are printed, never raised.
    """
    try:
        gallery_missing = not os.path.exists(gallery_path)
        if gallery_missing:
            try:
                os.makedirs(gallery_path, exist_ok=True)
            except Exception as e:
                print(f"Failed to create gallery directory: {str(e)}")
                return None

        # File name: timestamp plus 4 random bytes (hex) to avoid collisions
        # between images generated within the same second.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        suffix = os.urandom(4).hex()
        filepath = os.path.join(gallery_path, f"generated_{stamp}_{suffix}.png")

        try:
            if isinstance(image, Image.Image):
                pil_image = image
            else:
                pil_image = Image.fromarray(image)
            pil_image.save(filepath, "PNG", quality=100)

            # Double-check that the file really landed on disk.
            if os.path.exists(filepath):
                return filepath

            print(f"Warning: Failed to verify saved image at {filepath}")
            return None
        except Exception as e:
            print(f"Failed to save image: {str(e)}")
            return None

    except Exception as e:
        print(f"Error in save_image: {str(e)}")
        return None
146
+
147
def load_gallery():
    """Return the paths of the saved gallery images, newest first.

    The gallery folder is created when missing. Only regular files whose
    name ends in .png, .jpg or .jpeg (case-insensitive) are returned,
    ordered by modification time, most recent first.

    Returns:
        A list of file paths; empty when the folder holds no images or the
        folder itself cannot be read (the error is printed).
    """
    try:
        os.makedirs(gallery_path, exist_ok=True)

        dated_paths = []
        with os.scandir(gallery_path) as entries:
            for entry in entries:
                if not entry.name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue
                try:
                    # Skip directories that merely look like images.
                    if not entry.is_file():
                        continue
                    dated_paths.append((entry.stat().st_mtime, entry.path))
                except OSError:
                    # A file that vanished or is unreadable must not empty
                    # the whole gallery; just leave it out.
                    continue

        dated_paths.sort(key=lambda item: item[0], reverse=True)
        return [file_path for _, file_path in dated_paths]
    except Exception as e:
        print(f"Error loading gallery: {str(e)}")
        return []
162
+
163
@spaces.GPU
def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
    """Describe an uploaded image with one of the loaded Florence models.

    Args:
        image: Array accepted by ``Image.fromarray`` (the upload widget in
            this file is configured with ``type="numpy"``). ``None`` means
            nothing was uploaded.
        model_name: Key into ``florence_models`` / ``florence_processors``.

    Returns:
        The text stored under ``"<DESCRIPTION>"`` in the processor's
        post-processed output.

    Raises:
        gr.Error: If no image was provided.
    """
    # The upload is optional in the UI, so the button can be pressed with no
    # image; report that clearly instead of failing inside Image.fromarray.
    if image is None:
        raise gr.Error("Please upload an image before generating a caption.")

    image = Image.fromarray(image)
    task_prompt = "<DESCRIPTION>"
    prompt = task_prompt + "Describe this image in great detail."

    if image.mode != "RGB":
        image = image.convert("RGB")

    model = florence_models[model_name]
    processor = florence_processors[model_name]

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        repetition_penalty=1.10,
    )
    # Special tokens are kept because post_process_generation receives the
    # raw decoded text together with the task prompt.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    return parsed_answer["<DESCRIPTION>"]
186
+
187
@spaces.GPU
def process_and_save_image(height, width, steps, scales, prompt, seed):
    """Generate an image from a prompt, store it, and refresh the gallery.

    Args:
        height, width: Output size; converted with ``int``.
        steps: Number of inference steps; converted with ``int``.
        scales: Guidance scale; converted with ``float``.
        prompt: Text prompt, screened by ``filter_prompt`` first.
        seed: Seed for the random generator; converted with ``int``.

    Returns:
        ``(image, gallery_paths)``. ``image`` is ``None`` when the prompt is
        rejected or generation raises; the error is printed, not re-raised.
    """
    is_safe, filtered_prompt = filter_prompt(prompt)
    if not is_safe:
        gr.Warning("The prompt contains inappropriate content.")
        return None, load_gallery()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
        try:
            pipe_kwargs = dict(
                prompt=[filtered_prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256,
            )
            result = pipe(**pipe_kwargs)
            generated_image = result.images[0]

            # A failed save is only reported; the image is still returned.
            if save_image(generated_image) is None:
                print("Warning: Failed to save generated image")

            return generated_image, load_gallery()
        except Exception as e:
            print(f"Error in image generation: {str(e)}")
            return None, load_gallery()
214
+
215
def get_random_seed():
    """Draw a random integer seed from the range [0, 1000000)."""
    drawn = torch.randint(low=0, high=1000000, size=(1,))
    return drawn.item()
217
+
218
def update_seed():
    """Produce a fresh seed value by delegating to ``get_random_seed``."""
    fresh_seed = get_random_seed()
    return fresh_seed
220
+
221
# CSS styles
# Custom stylesheet passed to gr.Blocks: hides the footer, caps the page
# width, and styles the title, buttons, tabs, upload area and sections.
css = """
footer {display: none !important}
.gradio-container {
    max-width: 1200px;
    margin: auto;
}
.contain {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 12px;
    padding: 20px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 1em;
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.tabs {
    margin-top: 20px;
    border-radius: 10px;
    overflow: hidden;
}
.tab-nav {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
    padding: 10px;
}
.tab-nav button {
    color: white;
    border: none;
    padding: 10px 20px;
    margin: 0 5px;
    border-radius: 5px;
    transition: all 0.3s ease;
}
.tab-nav button.selected {
    background: rgba(255, 255, 255, 0.2);
}
.image-upload-container {
    border: 2px dashed #4B79A1;
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    transition: all 0.3s ease;
}
.image-upload-container:hover {
    border-color: #283E51;
    background: rgba(75, 121, 161, 0.1);
}
.primary-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    font-size: 1.2em !important;
    padding: 12px 20px !important;
    margin-top: 20px !important;
}
hr {
    border: none;
    border-top: 1px solid rgba(75, 121, 161, 0.2);
    margin: 20px 0;
}
.input-section {
    background: rgba(255, 255, 255, 0.03);
    border-radius: 12px;
    padding: 20px;
    margin-bottom: 20px;
}
.output-section {
    background: rgba(255, 255, 255, 0.03);
    border-radius: 12px;
    padding: 20px;
}
.example-images {
    display: grid;
    grid-template-columns: repeat(4, 1fr);
    gap: 10px;
    margin-bottom: 20px;
}
.example-images img {
    width: 100%;
    height: 150px;
    object-fit: cover;
    border-radius: 8px;
    cursor: pointer;
    transition: transform 0.2s;
}
.example-images img:hover {
    transform: scale(1.05);
}
"""
322
+
323
# Gradio UI: image upload + captioning on the left, generated image and the
# gallery of saved images on the right.
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation; confirm in particular that the seed controls belong inside
# the "Advanced Settings" accordion.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML('<div class="title">FLUX VisionReply</div>')
    gr.HTML('<div style="text-align: center; margin-bottom: 2em;">Upload an image(Image2Text2Image)</div>')

    with gr.Row():
        # Left column: input section
        with gr.Column(scale=3):
            # Image upload section
            input_image = gr.Image(
                label="Upload Image (Optional)",
                type="numpy",
                elem_classes=["image-upload-container"]
            )

            # Example image gallery
            example_images = [
                "5.jpg",
                "6.jpg",
                "2.jpg",
                "3.jpg",
                "1.jpg",
                "4.jpg",
            ]
            gr.Examples(
                examples=example_images,
                inputs=input_image,
                label="Example Images",
                examples_per_page=4
            )

            # Florence model selector - kept hidden
            florence_model = gr.Dropdown(
                choices=list(florence_models.keys()),
                label="Caption Model",
                value='gokaygokay/Florence-2-Flux-Large',
                visible=False
            )

            caption_button = gr.Button(
                "🔍 Generate Caption from Image",
                elem_classes=["generate-btn"]
            )

            # Divider
            gr.HTML('<hr style="margin: 20px 0;">')

            # Text prompt section
            prompt = gr.Textbox(
                label="Image Description",
                placeholder="Enter text description or use generated caption above...",
                lines=3
            )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=1152,
                        step=64,
                        value=1024
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=1152,
                        step=64,
                        value=1024
                    )

                with gr.Row():
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=6,
                        maximum=25,
                        step=1,
                        value=8
                    )
                    scales = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.0,
                        maximum=5.0,
                        step=0.1,
                        value=3.5
                    )

                seed = gr.Number(
                    label="Seed",
                    value=get_random_seed(),
                    precision=0
                )

                randomize_seed = gr.Button(
                    "🎲 Randomize Seed",
                    elem_classes=["generate-btn"]
                )

            generate_btn = gr.Button(
                "✨ Generate Image",
                elem_classes=["generate-btn", "primary-btn"]
            )

        # Right column: output section
        with gr.Column(scale=4):
            output = gr.Image(
                label="Generated Image",
                elem_classes=["output-image"]
            )

            # The initial value is given to the constructor as a callable so
            # the component processes it like any other value and re-reads
            # the gallery folder on page load, instead of assigning a raw
            # list to `.value` once after construction.
            gallery = gr.Gallery(
                label="Generated Images Gallery",
                value=load_gallery,
                show_label=True,
                columns=[4],
                rows=[2],
                height="auto",
                object_fit="cover",
                elem_classes=["gallery-container"]
            )

    # Event handlers
    caption_button.click(
        generate_caption,
        inputs=[input_image, florence_model],
        outputs=[prompt]
    )

    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=[output, gallery]
    )

    randomize_seed.click(
        update_seed,
        outputs=[seed]
    )

    # A second handler on the same button: every generation also rolls a new
    # seed into the Seed field.
    generate_btn.click(
        update_seed,
        outputs=[seed]
    )

if __name__ == "__main__":
    # allowed_paths lets Gradio serve the files stored under PERSISTENT_DIR.
    demo.launch(allowed_paths=[PERSISTENT_DIR])