# FLUX-Vision / app.py
# Last committed by: aiqcamp
# Commit: 0c60b80 "Update app.py" (verified)
import spaces
import argparse
import os
import time
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import subprocess
# Install Flash Attention at startup. This is best effort: the return code is
# not checked, so a failed install does not stop the app.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE presumably tells flash-attn's setup to
# skip compiling its CUDA extension locally -- TODO confirm against its setup.py.
# The variable is layered on top of the current environment on purpose:
# subprocess's `env` *replaces* the child's environment, so passing only this
# key would strip PATH, HOME, virtualenv/pip settings, etc. from the pip process.
subprocess.run('pip install flash-attn --no-build-isolation',
               env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
               shell=True)
# Setup and initialization code.
# Model cache lives in a "models" directory next to this file.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
# Extra directory Gradio is allowed to serve files from (see demo.launch).
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
# Point the Hugging Face caches at cache_path.
# NOTE(review): huggingface_hub / transformers / diffusers are imported above
# this point and may read these variables at import time, in which case these
# assignments do not affect where downloads go. Consider moving them before the
# imports or passing cache_dir= explicitly -- confirm.
os.environ.update(dict.fromkeys(("TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_HOME"), cache_path))
# Allow TF32 for CUDA matmuls (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
# Florence-2 captioning models and their processors, loaded once at import time
# and keyed by Hub repo id. Dict order follows the tuple below; the (hidden)
# caption-model dropdown lists the keys in this order.
# NOTE(review): trust_remote_code=True executes Python code shipped in these
# repos. No .to(device) call is made here, so the models presumably stay on CPU
# -- confirm that is intended.
_FLORENCE_REPO_IDS = (
    'gokaygokay/Florence-2-Flux-Large',
    'gokaygokay/Florence-2-Flux',
)
florence_models = {
    repo_id: AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).eval()
    for repo_id in _FLORENCE_REPO_IDS
}
florence_processors = {
    repo_id: AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in _FLORENCE_REPO_IDS
}
def filter_prompt(prompt):
    """Screen *prompt* against a small keyword blocklist.

    The check is a case-insensitive *substring* match, so a blocked keyword
    also matches when it appears inside a longer word.

    Returns:
        ``(True, prompt)`` when no blocked keyword occurs, otherwise
        ``(False, message)`` where ``message`` is a rejection text
        (NOTE(review): the literal appears mis-encoded in this copy of the
        file -- verify the file's encoding).
    """
    blocked = ("sex",)
    lowered = prompt.lower()
    if any(word in lowered for word in blocked):
        return False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค."
    return True, prompt
class timer:
    """Context manager that logs when a step starts and how long it took.

    Example::

        with timer("inference"):
            ...

    Prints ``"<name> starts"`` on entry and ``"<name> took <seconds>s"`` on
    exit, with seconds rounded to 2 decimals. The exit line is printed even if
    the body raises, and the exception is not suppressed.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic, so the measurement cannot be skewed by
        # wall-clock adjustments the way time.time() can.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        return self  # enables `with timer(...) as t:`

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(elapsed, 2)}s")
        return False  # never swallow exceptions raised in the timed body
# Model initialization
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

# FLUX.1-dev base pipeline in bfloat16.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Hyper-SD 8-step LoRA (the UI's default step count is 8), loaded and then
# fused into the pipeline with lora_scale=0.125 before moving it to the GPU.
_hyper_lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipe.load_lora_weights(_hyper_lora_path)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)

# NOTE(review): nothing in this file ever invokes pipe.safety_checker, and
# FluxPipeline (unlike the Stable Diffusion pipelines) is not known to call one
# -- this likely only costs a download and memory. Confirm before relying on it
# for content filtering.
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
@spaces.GPU
def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
    """Describe an uploaded image in detail with a Florence-2 model.

    Args:
        image: Value of the ``gr.Image(type="numpy")`` input, i.e. a numpy
            array, or None when nothing has been uploaded. A PIL image is
            also accepted.
        model_name: Key into ``florence_models`` / ``florence_processors``.

    Returns:
        The text produced for the ``<DESCRIPTION>`` task, used as the prompt.

    Raises:
        gr.Error: if no image was provided.
    """
    # The upload widget is labelled optional; without this guard
    # Image.fromarray(None) fails with an opaque error instead of a clear message.
    if image is None:
        raise gr.Error("Please upload an image before generating a caption.")

    task_prompt = "<DESCRIPTION>"
    prompt = task_prompt + "Describe this image in great detail."

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != "RGB":
        image = image.convert("RGB")

    model = florence_models[model_name]
    processor = florence_processors[model_name]
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        repetition_penalty=1.10,
    )
    # Special tokens are kept, presumably because post_process_generation
    # parses the task output using them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed_answer[task_prompt]
@spaces.GPU
def process_and_save_image(height, width, steps, scales, prompt, seed):
    """Generate a single image from *prompt* with the FLUX pipeline.

    Despite its name, this function does not write anything to disk.

    Args:
        height, width: Output size in pixels (coerced with ``int()``).
        steps: Number of inference steps (coerced with ``int()``).
        scales: Guidance scale (coerced with ``float()``).
        prompt: Text prompt, screened by ``filter_prompt()`` first.
        seed: Seed for the generator (coerced with ``int()``).

    Returns:
        The first entry of ``pipe(...).images``. Returns None when the prompt
        is rejected or generation fails. In both cases a ``gr.Warning`` toast
        tells the user why the output is empty.
    """
    is_safe, filtered_prompt = filter_prompt(prompt)
    if not is_safe:
        gr.Warning("The prompt contains inappropriate content.")
        return None
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
        try:
            generated_image = pipe(
                prompt=[filtered_prompt],
                # torch.Generator() without a device argument is a CPU generator.
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]
        except Exception as e:
            # UI boundary: keep the app alive but surface the failure to the user
            # instead of silently returning an empty output. Argument coercion
            # (e.g. int(None) for a cleared seed box) sits inside the try so such
            # input errors are reported the same way.
            print(f"Error in image generation: {str(e)}")
            gr.Warning(f"Image generation failed: {e}")
            return None
    return generated_image
def get_random_seed():
    """Return a random integer seed in ``[0, 1_000_000)`` from torch's global RNG."""
    seed_tensor = torch.randint(low=0, high=1_000_000, size=(1,))
    return int(seed_tensor[0])


def update_seed():
    """Event handler: return a fresh random seed for the Seed input."""
    fresh_seed = get_random_seed()
    return fresh_seed
# CSS styles for the UI, passed to gr.Blocks(css=...) below.
# Selectors used by components in this file: .title (title HTML),
# .generate-btn / .primary-btn (buttons), .image-upload-container (upload
# widget), and hr (divider HTML). `footer` presumably targets Gradio's page footer.
# NOTE(review): .tabs, .tab-nav, .input-section, .output-section and
# .example-images are not referenced by any elem_classes/HTML in this file and
# may be dead; `.contain` is presumably a Gradio-internal class -- confirm
# before removing anything.
css = """
footer {display: none !important}
.gradio-container {
max-width: 1200px;
margin: auto;
}
.contain {
background: rgba(255, 255, 255, 0.05);
border-radius: 12px;
padding: 20px;
}
.generate-btn {
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
border: none !important;
color: white !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.title {
text-align: center;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 1em;
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.tabs {
margin-top: 20px;
border-radius: 10px;
overflow: hidden;
}
.tab-nav {
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
padding: 10px;
}
.tab-nav button {
color: white;
border: none;
padding: 10px 20px;
margin: 0 5px;
border-radius: 5px;
transition: all 0.3s ease;
}
.tab-nav button.selected {
background: rgba(255, 255, 255, 0.2);
}
.image-upload-container {
border: 2px dashed #4B79A1;
border-radius: 10px;
padding: 20px;
text-align: center;
transition: all 0.3s ease;
}
.image-upload-container:hover {
border-color: #283E51;
background: rgba(75, 121, 161, 0.1);
}
.primary-btn {
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
font-size: 1.2em !important;
padding: 12px 20px !important;
margin-top: 20px !important;
}
hr {
border: none;
border-top: 1px solid rgba(75, 121, 161, 0.2);
margin: 20px 0;
}
.input-section {
background: rgba(255, 255, 255, 0.03);
border-radius: 12px;
padding: 20px;
margin-bottom: 20px;
}
.output-section {
background: rgba(255, 255, 255, 0.03);
border-radius: 12px;
padding: 20px;
}
.example-images {
display: grid;
grid-template-columns: repeat(4, 1fr);
gap: 10px;
margin-bottom: 20px;
}
.example-images img {
width: 100%;
height: 150px;
object-fit: cover;
border-radius: 8px;
cursor: pointer;
transition: transform 0.2s;
}
.example-images img:hover {
transform: scale(1.05);
}
"""
# Gradio UI. The left column holds the inputs: image upload, captioning, prompt
# and advanced settings. The right column shows the generated image. Event
# wiring is at the bottom.
# NOTE(review): the indentation of this copy was lost and the nesting below is
# reconstructed. In particular, it is assumed that `seed`/`randomize_seed` sit
# inside the Advanced Settings accordion and `generate_btn` comes after it.
# Confirm against the original layout.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML('<div class="title">FLUX VisionReply</div>')
    gr.HTML('<div style="text-align: center; margin-bottom: 2em;">Upload an image(Image2Text2Image)</div>')
    with gr.Row():
        # Left column: input section
        with gr.Column(scale=3):
            # Image upload section; handlers receive the image as a numpy array.
            input_image = gr.Image(
                label="Upload Image (Optional)",
                type="numpy",
                elem_classes=["image-upload-container"]
            )
            # Example image gallery. The paths are relative, presumably to the
            # working directory; confirm these files ship with the Space.
            example_images = [
                "5.jpg",
                "6.jpg",
                "2.jpg",
                "3.jpg",
                "1.jpg",
                "4.jpg",
            ]
            gr.Examples(
                examples=example_images,
                inputs=input_image,
                label="Example Images",
                examples_per_page=4
            )
            # Florence model selector. It is hidden, so captioning always uses
            # the default value below.
            florence_model = gr.Dropdown(
                choices=list(florence_models.keys()),
                label="Caption Model",
                value='gokaygokay/Florence-2-Flux-Large',
                visible=False
            )
            caption_button = gr.Button(
                "๐Ÿ” Generate Caption from Image",
                elem_classes=["generate-btn"]
            )
            # Divider
            gr.HTML('<hr style="margin: 20px 0;">')
            # Text prompt section: typed by the user or filled by the caption button.
            prompt = gr.Textbox(
                label="Image Description",
                placeholder="Enter text description or use generated caption above...",
                lines=3
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=1152,
                        step=64,
                        value=1024
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=1152,
                        step=64,
                        value=1024
                    )
                with gr.Row():
                    # The default of 8 presumably matches the Hyper-SD 8-step LoRA.
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=6,
                        maximum=25,
                        step=1,
                        value=8
                    )
                    scales = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.0,
                        maximum=5.0,
                        step=0.1,
                        value=3.5
                    )
                # The initial seed is drawn once, when the UI is built.
                seed = gr.Number(
                    label="Seed",
                    value=get_random_seed(),
                    precision=0
                )
                randomize_seed = gr.Button(
                    "๐ŸŽฒ Randomize Seed",
                    elem_classes=["generate-btn"]
                )
            generate_btn = gr.Button(
                "โœจ Generate Image",
                elem_classes=["generate-btn", "primary-btn"]
            )
        # Right column: output section
        with gr.Column(scale=4):
            output = gr.Image(
                label="Generated Image",
                elem_classes=["output-image"]
            )
    # Event handlers
    # Caption the uploaded image into the prompt box.
    caption_button.click(
        generate_caption,
        inputs=[input_image, florence_model],
        outputs=[prompt]
    )
    # Generate an image from the current prompt and settings.
    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=[output]
    )
    randomize_seed.click(
        update_seed,
        outputs=[seed]
    )
    # NOTE(review): this second listener fires on the same click as generation.
    # Presumably generation uses the seed value read at click time while this
    # one replaces the field with a new random seed. If so, the seed shown
    # afterwards is not the one behind the displayed image. Confirm this is
    # intended; otherwise chain it with `.then(...)` or drop it.
    generate_btn.click(
        update_seed,
        outputs=[seed]
    )
if __name__ == "__main__":
    # Let Gradio serve files from PERSISTENT_DIR in addition to its defaults.
    launch_options = {"allowed_paths": [PERSISTENT_DIR]}
    demo.launch(**launch_options)