import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
import json
import logging
import time
import re
import boto3
import torch.nn.functional as F
from datetime import datetime
from io import BytesIO
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0
from huggingface_hub import login
# Log in to the Hugging Face Hub
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

import diffusers
print(diffusers.__version__)

# Initialization
dtype = torch.float16  # adjust the data type as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # replace with your model

# Load the pipeline
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
MAX_SEED = 2**32 - 1
class calculateDuration:
def __init__(self, activity_name=""):
self.activity_name = activity_name
def __enter__(self):
self.start_time = time.time()
self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time = self.end_time - self.start_time
self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
if self.activity_name:
print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
else:
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
print(f"Activity: {self.activity_name}, End time: {self.start_time_formatted}")
# Image generation function
@spaces.GPU
@torch.inference_mode()
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
pipe.to(device)
generator = torch.Generator(device=device).manual_seed(seed)
with calculateDuration("Generating image"):
# Generate image
print(prompt, steps, seed, cfg_scale, width, height)
generated_image = pipe(
prompt=prompt,
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
).images[0]
    progress(0.99, "Generation finished")
return generated_image
def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
s3 = boto3.client(
's3',
endpoint_url=connectionUrl,
region_name='auto',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key
)
current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
buffer = BytesIO()
image.save(buffer, "PNG")
buffer.seek(0)
s3.upload_fileobj(buffer, bucket_name, image_file)
print("upload finish", image_file)
return image_file
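
# run_lora expects `lora_strings_json` to be a JSON list of LoRA configs, e.g.
# (field values below are illustrative; adapter_weight is optional):
#   [{"repo": "<hf-repo-id>", "weights": "<weight-file>", "adapter_name": "<name>", "adapter_weight": 1.0}]
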
def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
# Load LoRA weights
    # Parse the LoRA config JSON; fall back to None if it is missing or invalid
    lora_configs = None
    if lora_strings_json:
        try:
            lora_configs = json.loads(lora_strings_json)
        except json.JSONDecodeError:
            lora_configs = None
if lora_configs:
with calculateDuration("Loading LoRA weights"):
pipe.unload_lora_weights()
adapter_names = []
adapter_weights = []
for lora_info in lora_configs:
lora_repo = lora_info.get("repo")
weights = lora_info.get("weights")
adapter_name = lora_info.get("adapter_name")
adapter_weight = lora_info.get("adapter_weight")
if lora_repo and weights and adapter_name:
                    # Load the LoRA weights
pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
adapter_names.append(adapter_name)
adapter_weights.append(adapter_weight)
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
# Set random seed for reproducibility
if randomize_seed:
with calculateDuration("Set random seed"):
seed = random.randint(0, MAX_SEED)
# Generate image
final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
if final_image:
if upload_to_r2:
with calculateDuration("Upload image"):
url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
result = {"status": "success", "message": "upload image success", "url": url}
else:
result = {"status": "success", "message": "Image generated but not uploaded"}
    progress(1.0, "Completed!")
return final_image, seed, json.dumps(result)
# Gradio UI
css="""
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown("Flux with LoRA")
with gr.Row():
with gr.Column():
prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=2)
            lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1.0}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 0.8}]', lines=5)
run_button = gr.Button("Run", scale=0)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
with gr.Row():
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
with gr.Column():
result = gr.Image(label="Result", show_label=False)
seed_output = gr.Text(label="Seed")
json_text = gr.Text(label="Result JSON")
inputs = [
prompt,
lora_strings_json,
cfg_scale,
steps,
randomize_seed,
seed,
width,
height,
upload_to_r2,
account_id,
access_key,
secret_key,
bucket
]
outputs = [result, seed_output, json_text]
run_button.click(
fn=run_lora,
inputs=inputs,
outputs=outputs
)
demo.queue().launch()