jiuface's picture
history blame
15.8 kB
import os
import gradio as gr
import numpy as np
import random
import torch
import json
import logging
from diffusers import DiffusionPipeline
from huggingface_hub import login
import time
from datetime import datetime
from io import BytesIO
from diffusers.models.attention_processor import AttentionProcessor
import re
import json
# Hugging Face Hub access token, read from the environment (None when unset).
# NOTE(review): `login` is imported above but no login call is visible in
# this part of the file — confirm the token is actually used.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Initialisation
dtype = torch.float16 # Adjust the data type as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev" # Replace with your own model
# Load the pipeline once at import time and move it to the selected device
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
# Largest seed value; used as the upper bound for random seeds and the seed slider
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that prints the wall-clock time spent in its body.

    Args:
        activity_name: Optional label included in the printed message.

    After the ``with`` block exits, ``start_time``, ``end_time`` and
    ``elapsed_time`` (seconds, from ``time.time()``) are available on the
    instance.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        # Print exactly one message: the labelled form when a name was given,
        # the generic form otherwise.
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Returning None lets any exception raised in the body propagate.
# Lookup tables for the location, offset and area phrases
# Location phrase -> (x, y) centre on a 90x90 reference grid
valid_locations = {  # x, y in 90*90
    'in the center': (45, 45),
    'on the left': (15, 45),
    'on the right': (75, 45),
    'on the top': (45, 15),
    'on the bottom': (45, 75),
    'on the top-left': (15, 15),
    'on the top-right': (75, 15),
    'on the bottom-left': (15, 75),
    'on the bottom-right': (75, 75),
}
# Offset phrase -> (dx, dy) shift on the 90x90 reference grid
valid_offsets = {  # x, y in 90*90
    'no offset': (0, 0),
    'slightly to the left': (-10, 0),
    'slightly to the right': (10, 0),
    'slightly to the upper': (0, -10),
    'slightly to the lower': (0, 10),
    'slightly to the upper-left': (-10, -10),
    'slightly to the upper-right': (10, -10),
    'slightly to the lower-left': (-10, 10),
    'slightly to the lower-right': (10, 10),
}
# Area phrase -> (width, height) on the 90x90 reference grid
valid_areas = {  # w, h in 90*90
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60),
}
# Parse a character position description
def parse_character_position(character_position):
    """Extract the location, offset and area phrases from a position string.

    Each vocabulary (``valid_locations``, ``valid_offsets``, ``valid_areas``)
    is searched case-insensitively in *character_position*.

    Args:
        character_position: Free text, e.g.
            ``"on the top-left, slightly to the right, a small square area"``.

    Returns:
        A dict with keys ``'location'``, ``'offset'`` and ``'area'``. Each
        value is a key of the corresponding vocabulary dict; when no phrase
        is found the defaults ``'in the center'``, ``'no offset'`` and
        ``'a medium-sized square area'`` are used.
    """

    def find_phrase(vocabulary, default):
        # Lower-cased phrase -> canonical key, so a case-insensitive hit such
        # as "On The Left" is returned as the real dictionary key.
        canonical = {key.lower(): key for key in vocabulary}
        # Longest phrases first: regex alternation takes the first
        # alternative that matches, so 'on the top' would otherwise win over
        # 'on the top-left' (likewise 'slightly to the upper' over
        # 'slightly to the upper-left').
        alternatives = sorted(vocabulary, key=len, reverse=True)
        pattern = '|'.join(re.escape(key) for key in alternatives)
        match = re.search(pattern, character_position, re.IGNORECASE)
        if match is None:
            return default
        return canonical.get(match.group(0).lower(), default)

    return {
        'location': find_phrase(valid_locations, 'in the center'),
        'offset': find_phrase(valid_offsets, 'no offset'),
        'area': find_phrase(valid_areas, 'a medium-sized square area'),
    }
# Build the mask for one character region
def create_attention_mask(image_width, image_height, location, offset, area, device=None):
    """Return a flattened binary mask marking a rectangular image region.

    The region is described on a 90x90 reference grid (centre from
    ``valid_locations`` plus a shift from ``valid_offsets``, size from
    ``valid_areas``), scaled to the image size and clipped to its borders.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        location: Key of ``valid_locations``; unknown keys fall back to the
            centre ``(45, 45)``.
        offset: Key of ``valid_offsets``; unknown keys fall back to ``(0, 0)``.
        area: Key of ``valid_areas``; unknown keys fall back to ``(60, 60)``.
        device: Torch device for the mask. Defaults to ``"cuda"`` when CUDA is
            available and ``"cpu"`` otherwise.

    Returns:
        A 1-D ``float32`` tensor of length ``image_height * image_width`` that
        is 1.0 inside the region and 0.0 elsewhere (row-major order).
    """
    # Reference grid size on which locations, offsets and areas are defined
    base_size = 90
    if device is None:
        # Do not assume a GPU: mirror the CPU fallback used when the pipeline
        # device is chosen.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    loc_x, loc_y = valid_locations.get(location, (45, 45))
    offset_x, offset_y = valid_offsets.get(offset, (0, 0))
    area_width, area_height = valid_areas.get(area, (60, 60))

    # Region centre on the reference grid
    final_x = loc_x + offset_x
    final_y = loc_y + offset_y

    # Map grid coordinates and sizes to actual pixels
    scale_x = image_width / base_size
    scale_y = image_height / base_size
    center_x = final_x * scale_x
    center_y = final_y * scale_y
    width = area_width * scale_x
    height = area_height * scale_y

    # Top-left and bottom-right corners, clipped to the image
    x_start = int(max(center_x - width / 2, 0))
    y_start = int(max(center_y - height / 2, 0))
    x_end = int(min(center_x + width / 2, image_width))
    y_end = int(min(center_y + height / 2, image_height))

    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device=device)
    mask[y_start:y_end, x_start:x_end] = 1.0
    # Flatten to shape (image_height * image_width,)
    return mask.view(-1)
# Custom attention processor
class CustomCrossAttentionProcessor(AttentionProcessor):
    """Attention processor that biases attention towards a per-adapter mask.

    The mask is chosen by the ``adapter_name`` attribute of the attention
    module: its index in ``adapter_names`` selects the entry of ``masks``.
    Modules without a known ``adapter_name`` are delegated to the parent
    processor.

    NOTE(review): the fallback relies on ``super().__call__``; confirm that
    ``AttentionProcessor`` from ``diffusers.models.attention_processor`` is a
    subclassable class that implements ``__call__``.
    """

    def __init__(self, masks, embeddings, adapter_names):
        self.masks = masks  # list with one mask per character
        self.embeddings = embeddings  # list with one embedding per character (stored, not read here)
        self.adapter_names = adapter_names  # list with the LoRA adapter name of each character

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        """Run attention, adding a large bias on the masked key positions.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``/
                ``to_out``, ``heads``, ``head_dim`` and ``scale``.
            hidden_states: Query-side input; the first dimension is the batch.
            encoder_hidden_states: Key/value-side input; when ``None`` the
                layer attends over ``hidden_states`` itself.
            attention_mask: Only forwarded to the parent processor.

        Returns:
            The output of ``attn.to_out`` applied to the attended values.
        """
        adapter_name = getattr(attn, 'adapter_name', None)
        if adapter_name is None or adapter_name not in self.adapter_names:
            # No (known) adapter on this module: use the default computation.
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)

        mask = self.masks[self.adapter_names.index(adapter_name)]

        # Without encoder states there is nothing to project for key/value;
        # fall back to self-attention instead of failing inside to_k/to_v.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = hidden_states.shape[0]
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Reshape to (batch, heads, tokens, head_dim) for multi-head attention
        query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * attn.scale

        # Broadcast the mask over (batch, heads, query) as (1, 1, 1, key_len).
        # NOTE(review): this requires the mask length to equal the number of
        # key tokens — confirm the masks passed in have that length.
        mask_expanded = mask.to(device=attention_scores.device, dtype=torch.float32)
        mask_expanded = mask_expanded.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        # Add the bias in float32: 1e6 is not representable in float16 (it
        # becomes inf), and inf scores turn the softmax output into NaN.
        attention_scores = attention_scores.float() + mask_expanded * 1e6

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        # Back to the working dtype before multiplying with the values
        attention_probs = attention_probs.to(value.dtype)

        context = torch.matmul(attention_probs, value)
        # Merge the heads again: (batch, tokens, heads * head_dim)
        context = context.transpose(1, 2).reshape(batch_size, -1, attn.heads * attn.head_dim)

        # Output projection
        # NOTE(review): assumes attn.to_out is directly callable — confirm.
        hidden_states = attn.to_out(context)
        return hidden_states
# Install the custom attention processor
def replace_attention_processors(pipe, masks, embeddings, adapter_names):
    """Attach one shared ``CustomCrossAttentionProcessor`` to every ``attn2``.

    Walks all sub-modules of ``pipe.unet``; for each module that has an
    ``attn2`` attribute, copies the module's ``adapter_name`` (``None`` when
    absent) onto ``attn2`` and replaces ``attn2.processor``.

    Args:
        pipe: Pipeline object exposing a ``unet`` module.
        masks: Per-character masks, passed to the processor.
        embeddings: Per-character embeddings, passed to the processor.
        adapter_names: Per-character LoRA adapter names, passed to the
            processor.

    NOTE(review): this reads ``pipe.unet`` — confirm the loaded pipeline
    exposes that attribute. Also confirm that the individual blocks carry an
    ``adapter_name``; when they do not, ``attn2.adapter_name`` is ``None`` and
    the processor takes its fallback path.
    """
    custom_processor = CustomCrossAttentionProcessor(masks, embeddings, adapter_names)
    for module in pipe.unet.modules():
        if hasattr(module, 'attn2'):
            # The processor looks the adapter name up on the attention module
            module.attn2.adapter_name = getattr(module, 'adapter_name', None)
            module.attn2.processor = custom_processor
# Image generation
def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress):
    """Run the module-level pipeline on pre-computed prompt embeddings.

    Args:
        prompt_embeddings: Prompt embedding tensor passed as ``prompt_embeds``.
        steps: Number of inference steps.
        seed: Seed for the random generator.
        cfg_scale: Guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        progress: Callable invoked as ``progress(99, "Generate success!")``
            once generation has finished.

    Returns:
        The first image of the pipeline output.

    NOTE(review): some pipelines need more than ``prompt_embeds`` when the
    text prompt is omitted (e.g. pooled embeddings) — confirm against the
    pipeline that is actually loaded.
    """
    # Seed on the GPU when there is one, otherwise on the CPU
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=generator_device).manual_seed(seed)

    with calculateDuration("Generating image"):
        # Generate image
        generated_image = pipe(
            prompt_embeds=prompt_embeddings,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    progress(99, "Generate success!")
    return generated_image
# Main entry point
def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with a LoRA adapter and an attention mask per character.

    Args:
        prompt_bg: Background/scene prompt.
        character_prompts_json: JSON list with one prompt per character.
        character_positions_json: JSON list with one position description per
            character (parsed by ``parse_character_position``).
        lora_strings_json: JSON list of objects with the keys ``"repo"``,
            ``"weights"`` and ``"adapter_name"``, one per character.
        prompt_details: Prompt describing the interaction details; its
            embedding is concatenated after the background embedding.
        cfg_scale: Guidance scale forwarded to image generation.
        steps: Number of inference steps forwarded to image generation.
        randomize_seed: When truthy, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]``.
        seed: Seed used when ``randomize_seed`` is falsy.
        width: Image width in pixels.
        height: Image height in pixels.
        lora_scale: Weight applied to every loaded adapter.
        upload_to_r2, account_id, access_key, secret_key, bucket: Accepted
            but not used by this function (no upload is performed here).
        progress: Progress callback.

    Returns:
        A tuple ``(image, seed, result_json)`` where ``result_json`` is a JSON
        string with ``status`` and ``message``.

    Raises:
        ValueError: If an input is not valid JSON, the three lists differ in
            length, or a LoRA entry lacks one of the required keys.
    """
    # Parse the character prompts, positions and LoRA descriptions
    try:
        character_prompts = json.loads(character_prompts_json)
        character_positions = json.loads(character_positions_json)
        lora_strings = json.loads(lora_strings_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {e}") from e

    # Prompts, positions and LoRA entries must line up one-to-one
    if len(character_prompts) != len(character_positions) or len(character_prompts) != len(lora_strings):
        raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")

    num_characters = len(character_prompts)

    # Load LoRA weights
    with calculateDuration("Loading LoRA weights"):
        invalid_lora_message = "Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys."
        adapter_names = []
        for lora_info in lora_strings:
            if not isinstance(lora_info, dict):
                raise ValueError(invalid_lora_message)
            lora_repo = lora_info.get("repo")
            weights = lora_info.get("weights")
            adapter_name = lora_info.get("adapter_name")
            if not (lora_repo and weights and adapter_name):
                raise ValueError(invalid_lora_message)
            pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
            # Remember the adapter so it can be activated and matched to its
            # character below.
            adapter_names.append(adapter_name)
            # NOTE(review): this overwrites a single attribute on pipe.unet
            # for every adapter — confirm the intended per-block assignment.
            setattr(pipe.unet, 'adapter_name', adapter_name)

        # Activate all adapters with the same weight
        adapter_weights = [lora_scale] * len(adapter_names)
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    if len(adapter_names) != num_characters:
        raise ValueError("The number of LoRA adapters must match the number of characters.")

    # Set random seed for reproducibility
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)

    def encode_prompt(text):
        # Tokenise on the pipeline's device (not a hard-coded "cuda") and
        # return the first output of the text encoder.
        text_input = pipe.tokenizer(text, return_tensors="pt").to(device)
        return pipe.text_encoder(text_input.input_ids.to(device))[0]

    with calculateDuration("Encoding prompts"):
        bg_embeddings = encode_prompt(prompt_bg)
        character_embeddings = [encode_prompt(prompt) for prompt in character_prompts]
        details_embeddings = encode_prompt(prompt_details)
        # Background first, interaction details second, along dimension 1
        prompt_embeddings = torch.cat([bg_embeddings, details_embeddings], dim=1)

    # One parsed position and one mask per character
    character_infos = [parse_character_position(position_str) for position_str in character_positions]
    masks = [
        create_attention_mask(width, height, info['location'], info['offset'], info['area'])
        for info in character_infos
    ]

    replace_attention_processors(pipe, masks, character_embeddings, adapter_names)

    # Generate image
    final_image = generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress)

    # Image upload code can be added here
    result = {"status": "success", "message": "Image generated"}
    progress(100, "Completed!")
    return final_image, seed, json.dumps(result)
# Gradio UI
#col-container {
margin: 0 auto;
max-width: 640px;
# Build the Gradio Blocks UI.
# NOTE(review): no assignment to `css` is visible above this line — confirm it is defined before use.
with gr.Blocks(css=css) as demo:
gr.Markdown("Flux with LoRA")
with gr.Row():
# Input widgets: prompts, per-character JSON lists and the run button
with gr.Column():
prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
prompt_details = gr.Text(label="Interaction Details", placeholder="Enter interaction details between characters", lines=2)
run_button = gr.Button("Run", scale=0)
# Generation parameters and R2 upload settings
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5)
with gr.Row():
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=512)
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=512)
with gr.Row():
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
# Output widgets: generated image, seed used and the result JSON
with gr.Column():
result = gr.Image(label="Result", show_label=False)
seed_output = gr.Text(label="Seed")
json_text = gr.Text(label="Result JSON")
# NOTE(review): the `inputs` list is not closed in the visible part of the file — its
# entries presumably follow the parameter order of run_lora; confirm.
inputs = [
outputs = [result, seed_output, json_text]