import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
import json
import logging
import diffusers
from diffusers import DiffusionPipeline
from huggingface_hub import login
import time
from datetime import datetime
from io import BytesIO
import torch.nn.functional as F
import boto3
import re

# Log in to the Hugging Face Hub only when a token is configured; calling
# login() without a token falls back to an interactive prompt, which cannot
# be answered in a headless Space.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("HF_TOKEN is not set; skipping Hugging Face Hub login")

print(diffusers.__version__)

# init
dtype = torch.float16  # use float16 for faster generation
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

# Load the pipeline once at import time; every request below reuses it.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

MAX_SEED = 2**32 - 1


class calculateDuration:
    """Context manager that prints when an activity starts and how long it took.

    Exceptions raised inside the ``with`` block are not suppressed
    (``__exit__`` returns None).
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
    """Run the shared FLUX pipeline once and return the first generated image.

    ``spaces.GPU(duration=120)`` presumably requests a ZeroGPU allocation of up
    to 120 s for this call (NOTE(review): confirm against the `spaces` docs).

    Args:
        prompt: text prompt.
        steps: number of inference steps (cast to int; sliders may send floats).
        seed: RNG seed (cast to int, as ``manual_seed`` requires an int).
        cfg_scale: guidance scale.
        width, height: output size in pixels (cast to int).
        progress: a ``gr.Progress`` tracker.
    """
    with calculateDuration(f"Make a new generator:{seed}"):
        pipe.to(device)
        generator = torch.Generator(device=device).manual_seed(int(seed))

    with calculateDuration("Generating image"):
        generated_image = pipe(
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]

    # gr.Progress expects a completion fraction in [0, 1], not a percentage.
    progress(0.99, desc="Generate image success!")
    return generated_image


def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
    """Upload ``image`` as a PNG to a Cloudflare R2 bucket and return its object key.

    The key is ``generated_images/<YYYY/MM/DD/HHMMSS>_<random int>.png``.
    Any error raised by boto3 (bad credentials, unreachable endpoint, ...)
    propagates to the caller.
    """
    # Never log credentials; only the non-secret identifiers are printed.
    print("upload_image_to_r2", account_id, bucket_name)
    connection_url = f"https://{account_id}.r2.cloudflarestorage.com"
    s3 = boto3.client(
        "s3",
        endpoint_url=connection_url,
        region_name="auto",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file


def _parse_lora_configs(lora_strings_json):
    """Parse the LoRA-config text box into a list of config dicts.

    Returns an empty list for blank or unparsable input. A single JSON object
    is accepted as a one-element list; entries that are not JSON objects are
    dropped.
    """
    if not lora_strings_json:
        return []
    try:
        parsed = json.loads(lora_strings_json)
    except (json.JSONDecodeError, TypeError):
        gr.Warning("Parse lora config json failed")
        print("parse lora config json failed")
        return []
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list):
        gr.Warning("Parse lora config json failed")
        print("parse lora config json failed: expected a JSON list of objects")
        return []
    configs = [item for item in parsed if isinstance(item, dict)]
    if len(configs) != len(parsed):
        print("Ignoring LoRA config entries that are not JSON objects")
    return configs


def _loaded_adapter_names():
    """Return the names of all LoRA adapters loaded into ``pipe`` (active or not).

    Best-effort: if the pipeline cannot report its adapters (e.g. no PEFT
    backend), an empty set is returned and the failure is logged.
    """
    try:
        loaded = set(pipe.get_active_adapters())
        # get_active_adapters() omits adapters loaded by an earlier request
        # but not currently active; reloading those would raise.
        for names in pipe.get_list_adapters().values():
            loaded.update(names)
    except (AttributeError, ValueError) as e:
        print(f"Could not list loaded LoRA adapters: {e}")
        return set()
    return loaded


def _set_lora_enabled(enabled):
    """Enable or disable all loaded LoRA layers on ``pipe``; failures are logged, not raised."""
    try:
        if enabled:
            pipe.enable_lora()
        else:
            pipe.disable_lora()
    except ValueError as e:
        print(f"Could not {'enable' if enabled else 'disable'} LoRA layers: {e}")


def _apply_loras(lora_configs):
    """Load (if needed) and activate exactly the LoRA adapters requested.

    Each config may contain ``repo``, ``weights`` (weight file name),
    ``adapter_name`` and ``adapter_weight`` (defaults to 1.0). Only adapters
    that are actually loaded are passed to ``set_adapters``, so a failed or
    incomplete entry is skipped instead of crashing the request. When no
    adapter is requested, previously loaded adapters are disabled so they do
    not leak into this generation.
    """
    loaded_adapters = _loaded_adapter_names()
    if not lora_configs:
        if loaded_adapters:
            _set_lora_enabled(False)
        return

    with calculateDuration("Loading LoRA weights"):
        print("loaded adapters", loaded_adapters)
        adapter_names = []
        adapter_weights = []
        for lora_info in lora_configs:
            lora_repo = lora_info.get("repo")
            weights = lora_info.get("weights")
            adapter_name = lora_info.get("adapter_name")
            adapter_weight = lora_info.get("adapter_weight")
            if adapter_weight is None:
                adapter_weight = 1.0

            if adapter_name in adapter_names:
                print(f"Adapter '{adapter_name}' requested more than once, ignoring duplicate.")
                continue
            if adapter_name in loaded_adapters:
                print(f"Adapter '{adapter_name}' is already loaded, skipping.")
            elif lora_repo and weights and adapter_name:
                try:
                    pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
                except (ValueError, OSError) as e:
                    # OSError covers missing repos/files from the Hub download.
                    print(f"Error loading LoRA adapter: {e}")
                    gr.Warning(f"Failed to load LoRA adapter '{adapter_name}'")
                    continue
                loaded_adapters.add(adapter_name)
            else:
                print(f"Incomplete LoRA config (needs repo, weights, adapter_name), skipping: {lora_info}")
                continue

            adapter_names.append(adapter_name)
            adapter_weights.append(adapter_weight)

        # set lora weights
        if adapter_names:
            # A previous request may have disabled the LoRA layers.
            _set_lora_enabled(True)
            pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
        elif loaded_adapters:
            _set_lora_enabled(False)


def _upload_result(image, account_id, access_key, secret_key, bucket):
    """Upload ``image`` to R2 and return a result dict describing the outcome (never raises)."""
    if not all((account_id, access_key, secret_key, bucket)):
        message = "R2 upload requested but account id, access key, secret key or bucket is missing"
        gr.Warning(message)
        return {"status": "failed", "message": message}
    try:
        with calculateDuration("Upload image"):
            url = upload_image_to_r2(image, account_id, access_key, secret_key, bucket)
    except Exception as e:
        # Handler boundary: report the failure instead of discarding the image.
        print("Upload error", e)
        gr.Warning(f"Upload image failed: {e}")
        return {"status": "failed", "message": f"upload image failed: {e}"}
    return {"status": "success", "message": "upload image success", "url": url}


def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height,
             upload_to_r2, account_id, access_key, secret_key, bucket,
             progress=gr.Progress(track_tqdm=True)):
    """Gradio click handler: apply LoRAs, generate an image, optionally upload it to R2.

    Returns:
        A tuple ``(image or None, seed used, result JSON string)``. The JSON
        has ``status`` ("success"/"failed"), ``message`` and, after a
        successful upload, ``url`` (the R2 object key).
    """
    print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
    gr.Info("Starting process")

    # Load LoRA weights
    _apply_loras(_parse_lora_configs(lora_strings_json))

    # Set random seed for reproducibility
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)
    seed = int(seed)  # sliders may deliver floats

    # Generate image
    error_message = ""
    try:
        print("Start applying for zeroGPU resources")
        final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
    except Exception as e:
        # Surface the error to the user but still return the result JSON;
        # gr.Error only has an effect when raised, which would abort the request.
        error_message = str(e)
        gr.Warning(error_message)
        print("Run error", e)
        final_image = None

    if final_image is not None:
        if upload_to_r2:
            result = _upload_result(final_image, account_id, access_key, secret_key, bucket)
        else:
            result = {"status": "success", "message": "Image generated but not uploaded"}
    else:
        result = {"status": "failed", "message": error_message}

    gr.Info("Completed!")
    progress(1.0, desc="Completed!")
    return final_image, seed, json.dumps(result)


# Gradio interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("flux-dev-multi-lora")
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=10)
            lora_strings_json = gr.Text(
                label="LoRA Configs (JSON List String)",
                placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]',
                lines=5,
            )
            run_button = gr.Button("Run", scale=0)
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                # Masked so the secret is not displayed in the browser.
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", type="password")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            seed_output = gr.Text(label="Seed")
            json_text = gr.Text(label="Result JSON")

    inputs = [
        prompt,
        lora_strings_json,
        cfg_scale,
        steps,
        randomize_seed,
        seed,
        width,
        height,
        upload_to_r2,
        account_id,
        access_key,
        secret_key,
        bucket,
    ]
    outputs = [result, seed_output, json_text]

    run_button.click(
        fn=run_lora,
        inputs=inputs,
        outputs=outputs,
    )

demo.queue().launch()