import os
import re
import json
import random
import time

import gradio as gr
import spaces  # required for the @spaces.GPU decorator on Hugging Face Spaces
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor
from huggingface_hub import login

# Log in to the Hugging Face Hub
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

# Initialization
dtype = torch.float16  # adjust the dtype as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # replace with your model

# Load the pipeline
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

MAX_SEED = 2**32 - 1


class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


# Mappings for locations, offsets, and areas (all on a 90x90 base grid)
valid_locations = {  # x, y on a 90x90 grid
    'in the center': (45, 45),
    'on the left': (15, 45),
    'on the right': (75, 45),
    'on the top': (45, 15),
    'on the bottom': (45, 75),
    'on the top-left': (15, 15),
    'on the top-right': (75, 15),
    'on the bottom-left': (15, 75),
    'on the bottom-right': (75, 75)
}

valid_offsets = {  # x, y on a 90x90 grid
    'no offset': (0, 0),
    'slightly to the left': (-10, 0),
    'slightly to the right': (10, 0),
    'slightly to the upper': (0, -10),
    'slightly to the lower': (0, 10),
    'slightly to the upper-left': (-10, -10),
    'slightly to the upper-right': (10, -10),
    'slightly to the lower-left': (-10, 10),
    'slightly to the lower-right': (10, 10)
}

valid_areas = {  # w, h on a 90x90 grid
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60)
}


# Parse a free-form character-position description into location, offset, and area
def parse_character_position(character_position):
    # Build regex alternations from the valid keys
    location_pattern = '|'.join(re.escape(key) for key in valid_locations.keys())
    offset_pattern = '|'.join(re.escape(key) for key in valid_offsets.keys())
    area_pattern = '|'.join(re.escape(key) for key in valid_areas.keys())

    # Extract the location
    location_match = re.search(location_pattern, character_position, re.IGNORECASE)
    location = location_match.group(0) if location_match else 'in the center'

    # Extract the offset
    offset_match = re.search(offset_pattern, character_position, re.IGNORECASE)
    offset = offset_match.group(0) if offset_match else 'no offset'

    # Extract the area
    area_match = re.search(area_pattern, character_position, re.IGNORECASE)
    area = area_match.group(0) if area_match else 'a medium-sized square area'

    return {
        'location': location,
        'offset': offset,
        'area': area
    }
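
# Illustrative example (not part of the original script): how a free-form
# position string decomposes into the three recognized components. Unmatched
# components fall back to their defaults.
#
# >>> parse_character_position("on the left, slightly to the upper, a large vertical area")
# {'location': 'on the left', 'offset': 'slightly to the upper', 'area': 'a large vertical area'}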

# Build a flat attention mask that highlights a character's region
def create_attention_mask(image_width, image_height, location, offset, area):
    # Locations, offsets, and areas are defined on a 90x90 base grid
    # and scaled up to the actual image size below
    base_size = 90

    # Look up the location coordinates
    loc_x, loc_y = valid_locations.get(location, (45, 45))
    # Look up the offset
    offset_x, offset_y = valid_offsets.get(offset, (0, 0))
    # Look up the area size
    area_width, area_height = valid_areas.get(area, (60, 60))

    # Final center position on the base grid
    final_x = loc_x + offset_x
    final_y = loc_y + offset_y

    # Map coordinates and size to the actual image dimensions
    scale_x = image_width / base_size
    scale_y = image_height / base_size
    center_x = final_x * scale_x
    center_y = final_y * scale_y
    width = area_width * scale_x
    height = area_height * scale_y

    # Compute the top-left and bottom-right corners, clamped to the image
    x_start = int(max(center_x - width / 2, 0))
    y_start = int(max(center_y - height / 2, 0))
    x_end = int(min(center_x + width / 2, image_width))
    y_end = int(min(center_y + height / 2, image_height))

    # Create the mask
    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device=device)
    mask[y_start:y_end, x_start:x_end] = 1.0

    # Flatten to 1D, shape (image_height * image_width,)
    mask_flat = mask.view(-1)
    return mask_flat


# Custom cross-attention processor that boosts attention inside a character's region.
# Note: diffusers' AttentionProcessor is a Union type alias and cannot be subclassed,
# so this is a plain class that falls back to a stock AttnProcessor.
class CustomCrossAttentionProcessor:
    def __init__(self, masks, embeddings, adapter_names):
        self.masks = masks                  # one mask per character
        self.embeddings = embeddings        # one text embedding per character
        self.adapter_names = adapter_names  # one LoRA adapter name per character
        self.default_processor = AttnProcessor()  # fallback for unrelated layers

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # Look up the adapter_name attached to this attention module
        adapter_name = getattr(attn, 'adapter_name', None)
        if adapter_name is None or adapter_name not in self.adapter_names:
            # No matching adapter: run the stock attention computation
            return self.default_processor(attn, hidden_states, encoder_hidden_states, attention_mask)

        # Pick the mask that belongs to this adapter
        idx = self.adapter_names.index(adapter_name)
        mask = self.masks[idx]

        # Standard attention computation
        batch_size, sequence_length, _ = hidden_states.shape
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Reshape for multi-head attention (head_dim derived from the projection width)
        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Attention scores
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * attn.scale

        # Adjust the scores with the mask, broadcast across the key axis.
        # NOTE (original assumption): the flat mask length must equal the key
        # sequence length; see the downsampling sketch below for matching the
        # pixel-space mask to the token grid the attention layers actually see.
        mask_expanded = mask.view(1, 1, 1, -1)  # (1, 1, 1, key_len)
        attention_scores = attention_scores + mask_expanded * 1e6  # boost masked positions

        # Attention probabilities
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)

        # Context vectors
        context = torch.matmul(attention_probs, value)

        # Reshape back to the original layout
        context = context.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection (to_out is a ModuleList: [Linear, Dropout] in diffusers)
        hidden_states = attn.to_out[0](context)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


# Swap the cross-attention processors in the denoiser for the custom one.
# NOTE: this assumes a UNet-based pipeline whose blocks expose an `attn2`
# cross-attention module; FLUX pipelines expose `pipe.transformer` rather than
# `pipe.unet`, so this mapping would need to be adapted for FLUX.
def replace_attention_processors(pipe, masks, embeddings, adapter_names):
    custom_processor = CustomCrossAttentionProcessor(masks, embeddings, adapter_names)
    for name, module in pipe.unet.named_modules():
        if hasattr(module, 'attn2'):
            # Propagate the module's adapter_name (if any) onto the attention layer
            module.attn2.adapter_name = getattr(module, 'adapter_name', None)
            module.attn2.processor = custom_processor


# Generate an image from pre-computed prompt embeddings
@spaces.GPU
@torch.inference_mode()
def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress):
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed)

    with calculateDuration("Generating image"):
        # Generate image
        generated_image = pipe(
            prompt_embeds=prompt_embeddings,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    progress(0.99, desc="Generate success!")
    return generated_image
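
# Hedged sketch (not in the original script): create_attention_mask() builds its
# mask at pixel resolution, while the image-token sequences inside the attention
# layers live on the latent grid (pixel size divided by the VAE scale factor,
# typically 8 for UNet-based pipelines). A hypothetical helper like this would
# shrink the mask to that grid before replace_attention_processors() is called:
def downsample_mask_to_latent(mask_flat, image_width, image_height, vae_scale_factor=8):
    """Downsample a flat pixel-space mask to the flat latent token grid."""
    mask_2d = mask_flat.view(1, 1, image_height, image_width)
    latent_size = (image_height // vae_scale_factor, image_width // vae_scale_factor)
    mask_small = torch.nn.functional.interpolate(mask_2d, size=latent_size, mode="nearest")
    return mask_small.view(-1)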

# Main function
def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json,
             prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale,
             upload_to_r2, account_id, access_key, secret_key, bucket,
             progress=gr.Progress(track_tqdm=True)):

    # Parse the character prompts, positions, and LoRA strings
    try:
        character_prompts = json.loads(character_prompts_json)
        character_positions = json.loads(character_positions_json)
        lora_strings = json.loads(lora_strings_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {e}")

    # The numbers of prompts, positions, and LoRA strings must match
    if len(character_prompts) != len(character_positions) or len(character_prompts) != len(lora_strings):
        raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")

    # Number of characters
    num_characters = len(character_prompts)

    # Load LoRA weights
    with calculateDuration("Loading LoRA weights"):
        pipe.unload_lora_weights()
        adapter_names = []
        for lora_info in lora_strings:
            lora_repo = lora_info.get("repo")
            weights = lora_info.get("weights")
            adapter_name = lora_info.get("adapter_name")
            if lora_repo and weights and adapter_name:
                # Load this adapter's weights into the pipeline
                pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
                adapter_names.append(adapter_name)
                # Attach the adapter_name to the denoiser.
                # NOTE: setting a single attribute on the whole UNet keeps only the
                # last adapter name; a per-module mapping would be needed to route
                # each adapter to its own attention layers.
                setattr(pipe.unet, 'adapter_name', adapter_name)
            else:
                raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
        adapter_weights = [lora_scale] * len(adapter_names)
        # Activate the adapters with their weights
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    # The number of LoRA adapters must match the number of characters
    if len(adapter_names) != num_characters:
        raise ValueError("The number of LoRA adapters must match the number of characters.")

    # Set random seed for reproducibility
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)

    # Encode the prompts
    with calculateDuration("Encoding prompts"):
        # Encode the background prompt
        bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
        bg_embeddings = pipe.text_encoder(bg_text_input.input_ids)[0]

        # Encode each character prompt
        character_embeddings = []
        for prompt in character_prompts:
            char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
            char_embeddings = pipe.text_encoder(char_text_input.input_ids)[0]
            character_embeddings.append(char_embeddings)

        # Encode the interaction-details prompt
        details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
        details_embeddings = pipe.text_encoder(details_text_input.input_ids)[0]

        # Concatenate the background and interaction-detail embeddings along the sequence axis
        prompt_embeddings = torch.cat([bg_embeddings, details_embeddings], dim=1)

    # Parse the character positions
    character_infos = []
    for position_str in character_positions:
        info = parse_character_position(position_str)
        character_infos.append(info)

    # Build one mask per character.
    # NOTE: these masks are at pixel resolution; see the downsampling sketch above
    # for matching them to the attention layers' token grid.
    masks = []
    for info in character_infos:
        mask = create_attention_mask(width, height, info['location'], info['offset'], info['area'])
        masks.append(mask)

    # Install the custom attention processors
    replace_attention_processors(pipe, masks, character_embeddings, adapter_names)

    # Generate the image
    final_image = generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress)

    # Image-upload code (e.g., to R2) could be added here
    result = {"status": "success", "message": "Image generated"}
    progress(1.0, desc="Completed!")

    return final_image, seed, json.dumps(result)
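
# Illustrative inputs (hypothetical repo and file names), matching the JSON
# formats the UI below expects:
#   character_prompts:   '["a knight in silver armor", "a red dragon"]'
#   character_positions: '["on the left, a large vertical area",
#                          "on the right, a large vertical area"]'
#   lora_strings_json:   '[{"repo": "user/knight-lora", "weights": "knight.safetensors",
#                           "adapter_name": "knight"},
#                          {"repo": "user/dragon-lora", "weights": "dragon.safetensors",
#                           "adapter_name": "dragon"}]'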

# Gradio interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Flux with LoRA")
    with gr.Row():
        with gr.Column():
            prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
            character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
            character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
            lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
            prompt_details = gr.Text(label="Interaction Details", placeholder="Enter interaction details between characters", lines=2)
            run_button = gr.Button("Run", scale=0)
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=512)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=512)
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            seed_output = gr.Text(label="Seed")
            json_text = gr.Text(label="Result JSON")

    inputs = [
        prompt_bg, character_prompts, character_positions, lora_strings_json, prompt_details,
        cfg_scale, steps, randomize_seed, seed, width, height, lora_scale,
        upload_to_r2, account_id, access_key, secret_key, bucket
    ]
    outputs = [result, seed_output, json_text]

    run_button.click(
        fn=run_lora,
        inputs=inputs,
        outputs=outputs
    )

demo.queue().launch()