"""Gradio app that generates images with a FLUX pipeline and optional LoRA adapters.

The generated image can optionally be uploaded to a Cloudflare R2 bucket.
"""

# NOTE(review): the original import order is preserved (only de-duplicated and
# split one-per-line); `spaces` is imported before `torch`, which is presumably
# intentional for the Spaces GPU runtime -- confirm before re-sorting.
import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
import json
import logging
from diffusers import DiffusionPipeline
from huggingface_hub import login
import time
from datetime import datetime
from io import BytesIO
from diffusers.models.attention_processor import AttnProcessor2_0
import torch.nn.functional as F
import boto3
import re
import diffusers

# Log in to the Hugging Face Hub.
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

print(diffusers.__version__)

# Initialization.
dtype = torch.float16  # adjust the data type as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # replace with your model

# Load the pipeline.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

MAX_SEED = 2**32 - 1


class calculateDuration:
    """Context manager that prints the wall-clock time spent inside its block.

    After the block exits, ``start_time``, ``end_time`` and ``elapsed_time``
    (all in seconds, from ``time.time()``) are available on the instance.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


# Image generation function.
@spaces.GPU
@torch.inference_mode()
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
    """Run the global pipeline once and return the first generated image.

    Args:
        prompt: Text prompt passed to the pipeline.
        steps: Number of inference steps.
        seed: Seed for the torch generator, making the run reproducible.
        cfg_scale: Guidance scale passed to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        progress: Callable used to report progress (a ``gr.Progress``).

    Returns:
        The first element of the pipeline output's ``images``.
    """
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)

    with calculateDuration("Generating image"):
        generated_image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    # gr.Progress expects a float in [0, 1]; the original passed 99.
    progress(0.99, "Generate success!")
    return generated_image


def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
    """Upload *image* as a PNG to a Cloudflare R2 bucket and return its object key.

    Args:
        image: Object exposing ``save(fileobj, "PNG")``.
        account_id: R2 account id, used to build the endpoint URL.
        access_key: R2 access key id.
        secret_key: R2 secret access key.
        bucket_name: Destination bucket.

    Returns:
        The object key, ``generated_images/<Y/m/d/HMS>_<random int>.png``.
    """
    # Credentials are deliberately NOT logged; the original printed both keys.
    print("upload_image_to_r2", account_id, bucket_name)

    connection_url = f"https://{account_id}.r2.cloudflarestorage.com"
    s3 = boto3.client(
        "s3",
        endpoint_url=connection_url,
        region_name="auto",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

    current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"

    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)  # rewind so the upload reads from the start
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file


def _parse_lora_specs(lora_strings_json):
    """Parse the LoRA JSON text into a list of dict specs; return [] if unusable."""
    if not lora_strings_json:
        return []
    try:
        parsed = json.loads(lora_strings_json)
    except (TypeError, ValueError) as err:
        # Best effort, as in the original: bad JSON means "no LoRA", not a crash.
        print(f"Ignoring invalid LoRA JSON: {err}")
        return []
    if not isinstance(parsed, list):
        print("Ignoring LoRA JSON: expected a list of objects")
        return []
    return [spec for spec in parsed if isinstance(spec, dict)]


def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height,
             lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket,
             progress=gr.Progress(track_tqdm=True)):
    """Load the requested LoRA adapters, generate an image, and optionally upload it.

    Args:
        prompt: Text prompt.
        lora_strings_json: JSON list of objects with ``repo``, ``weights`` and
            ``adapter_name`` keys; empty or invalid input means no LoRA is loaded.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        seed: Seed used when ``randomize_seed`` is false.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: Weight applied to every loaded adapter.
        upload_to_r2: When true, upload the image to R2.
        account_id: R2 account id.
        access_key: R2 access key id.
        secret_key: R2 secret access key.
        bucket: R2 bucket name.
        progress: Gradio progress tracker.

    Returns:
        Tuple ``(image, seed, result_json)`` where ``result_json`` is a JSON
        string describing the outcome.
    """
    lora_specs = _parse_lora_specs(lora_strings_json)
    if lora_specs:
        with calculateDuration("Loading LoRA weights"):
            pipe.unload_lora_weights()
            adapter_names = []
            for lora_info in lora_specs:
                lora_repo = lora_info.get("repo")
                weights = lora_info.get("weights")
                adapter_name = lora_info.get("adapter_name")
                if lora_repo and weights and adapter_name:
                    # Load the LoRA weights under the requested adapter name.
                    pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
                    adapter_names.append(adapter_name)
            # Only activate adapters when at least one entry was complete.
            if adapter_names:
                adapter_weights = [lora_scale] * len(adapter_names)
                pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    # Set random seed for reproducibility.
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)

    final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)

    # Always bind `result`; the original left it undefined when no image came back.
    if final_image is None:
        result = {"status": "error", "message": "Image generation failed"}
    elif upload_to_r2:
        with calculateDuration("Upload image"):
            url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
        result = {"status": "success", "message": "upload image success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}

    progress(1.0, "Completed!")
    return final_image, seed, json.dumps(result)


# Gradio interface.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Flux with LoRA")
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=2)
            lora_strings_json = gr.Text(
                label="LoRA Strings (JSON List)",
                placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]',
                lines=5,
            )
            run_button = gr.Button("Run", scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=512)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=512)
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)

                upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")

        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            seed_output = gr.Text(label="Seed")
            json_text = gr.Text(label="Result JSON")

    inputs = [
        prompt,
        lora_strings_json,
        cfg_scale,
        steps,
        randomize_seed,
        seed,
        width,
        height,
        lora_scale,
        upload_to_r2,
        account_id,
        access_key,
        secret_key,
        bucket,
    ]
    outputs = [result, seed_output, json_text]

    run_button.click(
        fn=run_lora,
        inputs=inputs,
        outputs=outputs,
    )

demo.queue().launch()