# Source captured from a Hugging Face Space (status banner: "Running on Zero").
import os | |
import gradio as gr | |
import numpy as np | |
import random | |
import torch | |
import json | |
import logging | |
from diffusers import DiffusionPipeline | |
from huggingface_hub import login | |
import time | |
from datetime import datetime | |
from io import BytesIO | |
from diffusers.models.attention_processor import AttentionProcessor | |
import re | |
import json | |
# Authenticate against the Hugging Face Hub. The token is read from the
# environment; when it is absent the call is skipped, because ``login`` without
# a token falls back to an interactive prompt that cannot be answered in a
# headless deployment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Global inference configuration.
dtype = torch.float16  # adjust the data type as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # replace with your own model

# Load the pipeline once at import time and move it to the selected device.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

# Upper bound used for the seed slider and for random seed selection.
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that reports how long the enclosed block took.

    On exit, the duration measured with ``time.time()`` is stored in
    ``elapsed_time`` and printed to stdout. The message includes
    ``activity_name`` when a non-empty name was supplied.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # Record the start timestamp and hand the timer back to the caller.
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        finished_at = time.time()
        self.end_time = finished_at
        self.elapsed_time = finished_at - self.start_time

        # Anonymous timers use the short message form.
        if not self.activity_name:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
            return
        print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
# Lookup tables that translate position phrases into numbers. All values are
# expressed on a 90 x 90 reference grid. Key order is significant because the
# phrases are later joined into regex alternations.

# Anchor point (x, y) for each location phrase.
valid_locations = {
    "in the center": (45, 45),
    "on the left": (15, 45),
    "on the right": (75, 45),
    "on the top": (45, 15),
    "on the bottom": (45, 75),
    "on the top-left": (15, 15),
    "on the top-right": (75, 15),
    "on the bottom-left": (15, 75),
    "on the bottom-right": (75, 75),
}

# Displacement (dx, dy) applied to the anchor point for each offset phrase.
valid_offsets = {
    "no offset": (0, 0),
    "slightly to the left": (-10, 0),
    "slightly to the right": (10, 0),
    "slightly to the upper": (0, -10),
    "slightly to the lower": (0, 10),
    "slightly to the upper-left": (-10, -10),
    "slightly to the upper-right": (10, -10),
    "slightly to the lower-left": (-10, 10),
    "slightly to the lower-right": (10, 10),
}

# Region size (width, height) for each area phrase.
valid_areas = {
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60),
}
# Parse a character position description.
def parse_character_position(character_position):
    """Extract the location, offset and area phrases from a position string.

    The text is searched, case-insensitively, for one phrase from each of
    ``valid_locations``, ``valid_offsets`` and ``valid_areas``.

    Args:
        character_position: Free-text description such as
            ``"on the top-left, slightly to the right, a small square area"``.

    Returns:
        A dict with the keys ``'location'``, ``'offset'`` and ``'area'``. Each
        value is a key of the corresponding lookup table; when no phrase is
        found the defaults ``'in the center'``, ``'no offset'`` and
        ``'a medium-sized square area'`` are used.
    """

    def _find_phrase(choices, default):
        # Try longer phrases first: regex alternation accepts the first branch
        # that matches, so 'on the top' would otherwise shadow
        # 'on the top-left' and the compound phrase could never be returned.
        ordered = sorted(choices, key=len, reverse=True)
        pattern = '|'.join(re.escape(phrase) for phrase in ordered)
        match = re.search(pattern, character_position, re.IGNORECASE)
        if match is None:
            return default
        # The search ignores case, but callers look the result up in the
        # tables, so map the matched text back to the canonical key.
        canonical = {phrase.casefold(): phrase for phrase in choices}
        return canonical.get(match.group(0).casefold(), default)

    return {
        'location': _find_phrase(valid_locations, 'in the center'),
        'offset': _find_phrase(valid_offsets, 'no offset'),
        'area': _find_phrase(valid_areas, 'a medium-sized square area'),
    }
# Build the attention mask for one character.
def create_attention_mask(image_width, image_height, location, offset, area, device=None):
    """Return a flattened binary mask marking a rectangular region of the image.

    The rectangle is described on a 90 x 90 reference grid through the lookup
    tables ``valid_locations``, ``valid_offsets`` and ``valid_areas`` and is
    then scaled to the requested image size and clipped to the image bounds.

    Args:
        image_width: Width of the mask in pixels.
        image_height: Height of the mask in pixels.
        location: Key of ``valid_locations``; unknown keys fall back to the
            centre ``(45, 45)``.
        offset: Key of ``valid_offsets``; unknown keys fall back to ``(0, 0)``.
        area: Key of ``valid_areas``; unknown keys fall back to ``(60, 60)``.
        device: Torch device for the mask. Defaults to ``"cuda"`` when CUDA is
            available and ``"cpu"`` otherwise.

    Returns:
        A 1-D ``float32`` tensor of length ``image_height * image_width`` that
        is ``1.0`` inside the rectangle and ``0.0`` elsewhere.
    """
    # The position vocabulary is defined on a 90 x 90 grid.
    base_size = 90

    # UI sliders may deliver floats; tensor shapes and slices need integers.
    image_width = int(image_width)
    image_height = int(image_height)

    # Do not assume a GPU is present.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    loc_x, loc_y = valid_locations.get(location, (45, 45))
    offset_x, offset_y = valid_offsets.get(offset, (0, 0))
    area_width, area_height = valid_areas.get(area, (60, 60))

    # Centre of the region on the reference grid.
    final_x = loc_x + offset_x
    final_y = loc_y + offset_y

    # Map coordinates and sizes from the reference grid to image pixels.
    scale_x = image_width / base_size
    scale_y = image_height / base_size
    center_x = final_x * scale_x
    center_y = final_y * scale_y
    width = area_width * scale_x
    height = area_height * scale_y

    # Top-left and bottom-right corners, clipped to the image.
    x_start = int(max(center_x - width / 2, 0))
    y_start = int(max(center_y - height / 2, 0))
    x_end = int(min(center_x + width / 2, image_width))
    y_end = int(min(center_y + height / 2, image_height))

    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device=device)
    mask[y_start:y_end, x_start:x_end] = 1.0

    # Flatten to shape (image_height * image_width,).
    return mask.view(-1)
# Custom attention processor
class CustomCrossAttentionProcessor(AttentionProcessor):
    """Attention processor that biases attention scores with a per-adapter mask.

    When the attention module carries an ``adapter_name`` that is listed in
    ``adapter_names``, the mask at the same index is scaled by ``1e6`` and
    added to the attention scores before the softmax. Otherwise the call is
    delegated to the parent class.

    NOTE(review): this relies on ``AttentionProcessor`` being a subclassable
    base class that implements ``__call__`` - confirm against the installed
    diffusers version.
    """

    def __init__(self, masks, embeddings, adapter_names):
        super().__init__()
        self.masks = masks  # list with one mask per character
        # list with one embedding per character; stored here but not read by __call__
        self.embeddings = embeddings
        self.adapter_names = adapter_names  # list with the LoRA adapter name of each character

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        """Compute attention, adding the mask bias for known adapters.

        ``attention_mask`` and ``kwargs`` are only forwarded on the fallback
        path; the custom path below does not use them.
        """
        # Read the adapter name attached to the attention module, if any
        adapter_name = getattr(attn, 'adapter_name', None)
        if adapter_name is None or adapter_name not in self.adapter_names:
            # No (known) adapter name: run the default attention computation
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        # Look up the index that belongs to this adapter name
        idx = self.adapter_names.index(adapter_name)
        mask = self.masks[idx]
        # Standard attention computation
        batch_size, sequence_length, _ = hidden_states.shape
        query = attn.to_q(hidden_states)
        # NOTE(review): encoder_hidden_states defaults to None and is passed on
        # unchanged - presumably callers always supply it; verify.
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        # Reshape for multi-head attention
        # NOTE(review): assumes `attn` exposes `heads` and `head_dim` - TODO confirm
        query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
        # Compute the attention scores
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * attn.scale
        # Apply the mask to adjust the attention scores
        # Reshape the mask so that it broadcasts against attention_scores
        # Assumes key_len equals the length of the mask
        mask_expanded = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1, 1, 1, key_len)
        # Add the scaled mask to attention_scores
        # NOTE(review): if the scores are float16, a bias of 1e6 is outside the
        # representable range - confirm the dtype used here.
        attention_scores += mask_expanded * 1e6  # boost attention at the masked positions
        # Compute the attention probabilities
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        # Compute the context vectors
        context = torch.matmul(attention_probs, value)
        # Reshape back to (batch, tokens, heads * head_dim)
        context = context.transpose(1, 2).reshape(batch_size, -1, attn.heads * attn.head_dim)
        # Output projection
        # NOTE(review): assumes `attn.to_out` is directly callable - TODO confirm
        hidden_states = attn.to_out(context)
        return hidden_states
# Swap in the custom attention processor.
def replace_attention_processors(pipe, masks, embeddings, adapter_names):
    """Install one shared ``CustomCrossAttentionProcessor`` on every ``attn2``.

    Walks ``pipe.unet.named_modules()``. For each module that has an ``attn2``
    attribute, the module's own ``adapter_name`` (or ``None``) is copied onto
    ``attn2`` and the shared processor is assigned to ``attn2.processor``.
    """
    shared_processor = CustomCrossAttentionProcessor(masks, embeddings, adapter_names)
    for _module_path, block in pipe.unet.named_modules():
        if not hasattr(block, 'attn2'):
            continue
        # Propagate the adapter name so the processor can pick the right mask.
        inherited_name = getattr(block, 'adapter_name', None)
        block.attn2.adapter_name = inherited_name
        block.attn2.processor = shared_processor
# Generate the image.
def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress):
    """Run the global pipeline on pre-computed prompt embeddings.

    Args:
        prompt_embeddings: Tensor passed to the pipeline as ``prompt_embeds``.
        steps: Number of inference steps.
        seed: Seed for the random generator.
        cfg_scale: Guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        progress: Gradio progress callback, invoked once after generation.

    Returns:
        The first image produced by the pipeline.
    """
    # Fall back to the CPU when no GPU is present instead of failing on "cuda".
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(run_device)

    # Slider values may arrive as floats; the generator and pipeline need ints.
    generator = torch.Generator(device=run_device).manual_seed(int(seed))

    with calculateDuration("Generating image"):
        generated_image = pipe(
            prompt_embeds=prompt_embeddings,
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]

    # Gradio expects a float progress value between 0 and 1.
    progress(0.99, "Generate success!")
    return generated_image
# Main entry point.
def _encode_prompt(text):
    """Tokenize *text* and return the first output of the pipeline's text encoder."""
    tokens = pipe.tokenizer(text, return_tensors="pt")
    # Encoding is inference only; skip autograd bookkeeping.
    with torch.no_grad():
        return pipe.text_encoder(tokens.input_ids.to(device))[0]


def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from a background prompt and per-character LoRAs.

    Args:
        prompt_bg: Background / scene prompt.
        character_prompts_json: JSON list with one prompt per character.
        character_positions_json: JSON list with one position description per
            character (see ``parse_character_position``).
        lora_strings_json: JSON list of objects with the keys ``repo``,
            ``weights`` and ``adapter_name``, one per character.
        prompt_details: Prompt describing the interaction between characters.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        seed: Seed used when ``randomize_seed`` is false.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: Weight applied to every loaded adapter.
        upload_to_r2, account_id, access_key, secret_key, bucket: Accepted for
            the UI wiring; not used by this function yet.
        progress: Gradio progress tracker.

    Returns:
        A tuple ``(image, seed, result_json)``, where ``result_json`` is a JSON
        string describing the outcome.

    Raises:
        ValueError: If an input is not valid JSON, is not a list, the lists
            differ in length, or a LoRA entry lacks a required key.
    """
    # Parse the character prompts, positions and LoRA descriptions.
    try:
        character_prompts = json.loads(character_prompts_json)
        character_positions = json.loads(character_positions_json)
        lora_strings = json.loads(lora_strings_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {e}") from e

    for label, value in (
        ("Character prompts", character_prompts),
        ("Character positions", character_positions),
        ("LoRA strings", lora_strings),
    ):
        if not isinstance(value, list):
            raise ValueError(f"{label} must be a JSON list.")

    # Prompts, positions and LoRA entries must line up one-to-one.
    if not (len(character_prompts) == len(character_positions) == len(lora_strings)):
        raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")

    # Validate every LoRA entry before touching the pipeline, so a malformed
    # entry cannot leave it with only some of the adapters loaded.
    lora_specs = []
    for lora_info in lora_strings:
        lora_repo = lora_info.get("repo") if isinstance(lora_info, dict) else None
        weights = lora_info.get("weights") if isinstance(lora_info, dict) else None
        adapter_name = lora_info.get("adapter_name") if isinstance(lora_info, dict) else None
        if not (lora_repo and weights and adapter_name):
            raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
        lora_specs.append((lora_repo, weights, adapter_name))

    # Load LoRA weights.
    with calculateDuration("Loading LoRA weights"):
        pipe.unload_lora_weights()
        adapter_names = []
        for lora_repo, weights, adapter_name in lora_specs:
            pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            # Expose the adapter name as an attribute of the model.
            # NOTE(review): every iteration overwrites the same attribute, so
            # only the last adapter name remains - confirm this is intended.
            setattr(pipe.unet, 'adapter_name', adapter_name)
        adapter_weights = [lora_scale] * len(adapter_names)
        # Activate all adapters with their weights.
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    # Pick a random seed when requested.
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)
    # The seed slider may deliver a float.
    seed = int(seed)

    # Encode the prompts.
    with calculateDuration("Encoding prompts"):
        bg_embeddings = _encode_prompt(prompt_bg)
        character_embeddings = [_encode_prompt(prompt) for prompt in character_prompts]
        details_embeddings = _encode_prompt(prompt_details)
        # Concatenate background and interaction embeddings along dim 1.
        prompt_embeddings = torch.cat([bg_embeddings, details_embeddings], dim=1)

    # Turn each position description into a mask.
    character_infos = [parse_character_position(position_str) for position_str in character_positions]
    masks = [
        create_attention_mask(width, height, info['location'], info['offset'], info['area'])
        for info in character_infos
    ]

    # Replace the attention processors.
    replace_attention_processors(pipe, masks, character_embeddings, adapter_names)

    # Generate the image.
    final_image = generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress)

    # Image upload code can be added here.
    result = {"status": "success", "message": "Image generated"}
    # Gradio expects a float progress value between 0 and 1.
    progress(1.0, "Completed!")
    return final_image, seed, json.dumps(result)
# Gradio interface.
# NOTE(review): the "#col-container" rule only applies to a component created
# with elem_id="col-container"; none is defined in this layout.
css="""
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Flux with LoRA")
    with gr.Row():
        # Left column: prompts, run button and advanced settings.
        with gr.Column():
            prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
            character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
            character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
            lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
            prompt_details = gr.Text(label="Interaction Details", placeholder="Enter interaction details between characters", lines=2)
            run_button = gr.Button("Run", scale=0)
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=512)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=512)
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                # Mask the secret so that it is not displayed in clear text.
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", type="password")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
        # Right column: generated image and run metadata.
        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            seed_output = gr.Text(label="Seed")
            json_text = gr.Text(label="Result JSON")

    # The order must match the positional parameters of run_lora.
    inputs = [
        prompt_bg,
        character_prompts,
        character_positions,
        lora_strings_json,
        prompt_details,
        cfg_scale,
        steps,
        randomize_seed,
        seed,
        width,
        height,
        lora_scale,
        upload_to_r2,
        account_id,
        access_key,
        secret_key,
        bucket,
    ]
    outputs = [result, seed_output, json_text]
    run_button.click(
        fn=run_lora,
        inputs=inputs,
        outputs=outputs,
    )

demo.queue().launch()