aiqcamp commited on
Commit
0c60b80
·
verified ·
1 Parent(s): e2173ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -85
app.py CHANGED
@@ -23,7 +23,6 @@ subprocess.run('pip install flash-attn --no-build-isolation',
23
  # Setup and initialization code
24
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
  PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
26
- gallery_path = path.join(PERSISTENT_DIR, "gallery")
27
 
28
  os.environ["TRANSFORMERS_CACHE"] = cache_path
29
  os.environ["HF_HUB_CACHE"] = cache_path
@@ -31,10 +30,6 @@ os.environ["HF_HOME"] = cache_path
31
 
32
  torch.backends.cuda.matmul.allow_tf32 = True
33
 
34
- # Create gallery directory
35
- if not path.exists(gallery_path):
36
- os.makedirs(gallery_path, exist_ok=True)
37
-
38
  # Florence 모델 초기화
39
  florence_models = {
40
  'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained(
@@ -63,12 +58,6 @@ def filter_prompt(prompt):
63
  "sex"
64
  ]
65
 
66
- # "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
67
- # "erotic", "sensual", "seductive", "provocative", "intimate",
68
- # "violence", "gore", "blood", "death", "kill", "murder", "torture",
69
- # "drug", "suicide", "abuse", "hate", "discrimination"
70
- # ]
71
-
72
  prompt_lower = prompt.lower()
73
 
74
  for keyword in inappropriate_keywords:
@@ -107,59 +96,6 @@ pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretraine
107
  "CompVis/stable-diffusion-safety-checker"
108
  )
109
 
110
-
111
-
112
- def save_image(image):
113
- """Save the generated image and return the path"""
114
- try:
115
- if not os.path.exists(gallery_path):
116
- try:
117
- os.makedirs(gallery_path, exist_ok=True)
118
- except Exception as e:
119
- print(f"Failed to create gallery directory: {str(e)}")
120
- return None
121
-
122
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
123
- random_suffix = os.urandom(4).hex()
124
- filename = f"generated_{timestamp}_{random_suffix}.png"
125
- filepath = os.path.join(gallery_path, filename)
126
-
127
- try:
128
- if isinstance(image, Image.Image):
129
- image.save(filepath, "PNG", quality=100)
130
- else:
131
- image = Image.fromarray(image)
132
- image.save(filepath, "PNG", quality=100)
133
-
134
- if not os.path.exists(filepath):
135
- print(f"Warning: Failed to verify saved image at {filepath}")
136
- return None
137
-
138
- return filepath
139
- except Exception as e:
140
- print(f"Failed to save image: {str(e)}")
141
- return None
142
-
143
- except Exception as e:
144
- print(f"Error in save_image: {str(e)}")
145
- return None
146
-
147
- def load_gallery():
148
- try:
149
- os.makedirs(gallery_path, exist_ok=True)
150
-
151
- image_files = []
152
- for f in os.listdir(gallery_path):
153
- if f.lower().endswith(('.png', '.jpg', '.jpeg')):
154
- full_path = os.path.join(gallery_path, f)
155
- image_files.append((full_path, os.path.getmtime(full_path)))
156
-
157
- image_files.sort(key=lambda x: x[1], reverse=True)
158
- return [f[0] for f in image_files]
159
- except Exception as e:
160
- print(f"Error loading gallery: {str(e)}")
161
- return []
162
-
163
  @spaces.GPU
164
  def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
165
  image = Image.fromarray(image)
@@ -189,7 +125,7 @@ def process_and_save_image(height, width, steps, scales, prompt, seed):
189
  is_safe, filtered_prompt = filter_prompt(prompt)
190
  if not is_safe:
191
  gr.Warning("The prompt contains inappropriate content.")
192
- return None, load_gallery()
193
 
194
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
195
  try:
@@ -203,14 +139,10 @@ def process_and_save_image(height, width, steps, scales, prompt, seed):
203
  max_sequence_length=256
204
  ).images[0]
205
 
206
- saved_path = save_image(generated_image)
207
- if saved_path is None:
208
- print("Warning: Failed to save generated image")
209
-
210
- return generated_image, load_gallery()
211
  except Exception as e:
212
  print(f"Error in image generation: {str(e)}")
213
- return None, load_gallery()
214
 
215
  def get_random_seed():
216
  return torch.randint(0, 1000000, (1,)).item()
@@ -342,7 +274,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
342
  "3.jpg",
343
  "1.jpg",
344
  "4.jpg",
345
-
346
  ]
347
  gr.Examples(
348
  examples=example_images,
@@ -429,18 +360,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
429
  label="Generated Image",
430
  elem_classes=["output-image"]
431
  )
432
-
433
- gallery = gr.Gallery(
434
- label="Generated Images Gallery",
435
- show_label=True,
436
- columns=[4],
437
- rows=[2],
438
- height="auto",
439
- object_fit="cover",
440
- elem_classes=["gallery-container"]
441
- )
442
-
443
- gallery.value = load_gallery()
444
 
445
  # Event handlers
446
  caption_button.click(
@@ -452,7 +371,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
452
  generate_btn.click(
453
  process_and_save_image,
454
  inputs=[height, width, steps, scales, prompt, seed],
455
- outputs=[output, gallery]
456
  )
457
 
458
  randomize_seed.click(
 
23
  # Setup and initialization code
24
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
  PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
 
26
 
27
  os.environ["TRANSFORMERS_CACHE"] = cache_path
28
  os.environ["HF_HUB_CACHE"] = cache_path
 
30
 
31
  torch.backends.cuda.matmul.allow_tf32 = True
32
 
 
 
 
 
33
  # Florence 모델 초기화
34
  florence_models = {
35
  'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained(
 
58
  "sex"
59
  ]
60
 
 
 
 
 
 
 
61
  prompt_lower = prompt.lower()
62
 
63
  for keyword in inappropriate_keywords:
 
96
  "CompVis/stable-diffusion-safety-checker"
97
  )
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  @spaces.GPU
100
  def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
101
  image = Image.fromarray(image)
 
125
  is_safe, filtered_prompt = filter_prompt(prompt)
126
  if not is_safe:
127
  gr.Warning("The prompt contains inappropriate content.")
128
+ return None
129
 
130
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
131
  try:
 
139
  max_sequence_length=256
140
  ).images[0]
141
 
142
+ return generated_image
 
 
 
 
143
  except Exception as e:
144
  print(f"Error in image generation: {str(e)}")
145
+ return None
146
 
147
  def get_random_seed():
148
  return torch.randint(0, 1000000, (1,)).item()
 
274
  "3.jpg",
275
  "1.jpg",
276
  "4.jpg",
 
277
  ]
278
  gr.Examples(
279
  examples=example_images,
 
360
  label="Generated Image",
361
  elem_classes=["output-image"]
362
  )
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  # Event handlers
365
  caption_button.click(
 
371
  generate_btn.click(
372
  process_and_save_image,
373
  inputs=[height, width, steps, scales, prompt, seed],
374
+ outputs=[output]
375
  )
376
 
377
  randomize_seed.click(