import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
import json
import logging
from diffusers import DiffusionPipeline
from huggingface_hub import login
import time
from datetime import datetime
from io import BytesIO
# from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.attention_processor import AttnProcessor2_0
import torch.nn.functional as F
import time
import boto3
from io import BytesIO
import re
import json

# Authenticate with the Hugging Face Hub when a token is configured.
HF_TOKEN = os.environ.get("HF_TOKEN")
# login(token=None) falls back to an interactive prompt (widget/terminal),
# which cannot be answered in a headless process, so only log in with a token.
if HF_TOKEN:
    login(token=HF_TOKEN)
import diffusers
print(diffusers.__version__)

# Initialization
dtype = torch.float16  # Adjust as needed. NOTE(review): diffusers' FLUX examples use torch.bfloat16 — confirm float16 is intended.
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # Replace with your model

# Load the pipeline once at import time and move it to the target device.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

# Inclusive upper bound for random seeds (largest unsigned 32-bit value).
MAX_SEED = 2**32 - 1

class calculateDuration:
    """Context manager that prints wall-clock start/end times and elapsed seconds.

    On entry it records ``start_time`` (epoch seconds from ``time.time()``) and
    ``start_time_formatted`` (local time, ``%Y-%m-%d %H:%M:%S``). On exit it
    records ``end_time``, ``end_time_formatted`` and ``elapsed_time`` (seconds)
    and prints them. Exceptions raised inside the ``with`` block propagate.

    Args:
        activity_name: Label included in every printed line; may be empty.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))

        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

        # Report the END timestamp on the "End time" line.
        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")
        # Never suppress exceptions raised inside the with-block.
        return False

# Generate a single image with the global pipeline.
@spaces.GPU
@torch.inference_mode()
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
    """Run the global ``pipe`` once and return the first generated image.

    Args:
        prompt: Text prompt passed to the pipeline.
        steps: Number of inference steps (cast to ``int``; UI sliders may
            deliver numbers as floats).
        seed: Seed for the ``torch.Generator`` (cast to ``int`` because
            ``manual_seed`` requires an integer).
        cfg_scale: Value passed as ``guidance_scale``.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        progress: Progress callable (a ``gr.Progress`` when called from
            ``run_lora``); called once after generation completes.

    Returns:
        ``images[0]`` from the pipeline output.
    """
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    with calculateDuration("Generating image"):
        print(prompt, steps, seed, cfg_scale, width, height)
        generated_image = pipe(
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]

    # gr.Progress expects a completion fraction in [0, 1], not a percentage.
    progress(0.99, "Generate success!")
    return generated_image


def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
    """Upload ``image`` as a PNG to a Cloudflare R2 bucket and return its object key.

    The key has the form ``generated_images/YYYY/MM/DD/HHMMSS_<random>.png``,
    using the local time and a random integer in ``[0, MAX_SEED]``.

    Args:
        image: Object with a ``save(fp, format)`` method (expected to be the
            PIL image produced by the pipeline — confirm for other callers).
        account_id: R2 account id; used to build the S3-compatible endpoint.
        access_key: R2 access key id. Never logged.
        secret_key: R2 secret access key. Never logged.
        bucket_name: Destination bucket.

    Returns:
        The object key the image was uploaded under.
    """
    # Only non-secret identifiers are logged; credentials must not reach stdout.
    print("upload_image_to_r2", account_id, bucket_name)
    endpoint_url = f"https://{account_id}.r2.cloudflarestorage.com"

    s3 = boto3.client(
        's3',
        endpoint_url=endpoint_url,
        region_name='auto',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )

    current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
    with BytesIO() as buffer:
        image.save(buffer, "PNG")
        buffer.seek(0)
        s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)

    return image_file

def _parse_lora_configs(lora_strings_json):
    """Parse the LoRA JSON text into a list of config dicts (empty list when none/invalid)."""
    if not lora_strings_json or not lora_strings_json.strip():
        return []
    try:
        configs = json.loads(lora_strings_json)
    except (ValueError, TypeError) as err:
        # Best effort: bad JSON means "no LoRAs", but say so instead of hiding it.
        print(f"Ignoring invalid LoRA JSON: {err}")
        return []
    # Accept a single object as shorthand for a one-element list.
    if isinstance(configs, dict):
        configs = [configs]
    if not isinstance(configs, list):
        print("Ignoring LoRA JSON: expected a list of objects")
        return []
    return [cfg for cfg in configs if isinstance(cfg, dict)]


def _apply_loras(lora_configs):
    """Unload any previous adapters, then load and activate the requested ones."""
    with calculateDuration("Loading LoRA weights"):
        # Always reset first so a request without LoRAs does not inherit
        # adapters loaded by an earlier request on the shared pipeline.
        pipe.unload_lora_weights()
        adapter_names = []
        adapter_weights = []
        for lora_info in lora_configs:
            lora_repo = lora_info.get("repo")
            weights = lora_info.get("weights")
            adapter_name = lora_info.get("adapter_name")
            adapter_weight = lora_info.get("adapter_weight")
            if not (lora_repo and weights and adapter_name):
                continue
            pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            # Missing weight means full strength.
            adapter_weights.append(1.0 if adapter_weight is None else adapter_weight)
        # Nothing to activate when no entry was valid.
        if adapter_names:
            pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)


def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    """Gradio click handler: apply LoRAs, generate an image, optionally upload it.

    Args:
        prompt: Text prompt.
        lora_strings_json: JSON list of objects with ``repo``, ``weights``,
            ``adapter_name`` and optional ``adapter_weight`` keys. Empty or
            invalid input means no LoRAs.
        cfg_scale, steps, width, height: Forwarded to ``generate_image``.
        randomize_seed: When true, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        seed: Seed used when not randomizing.
        upload_to_r2: When true, upload the result via ``upload_image_to_r2``.
        account_id, access_key, secret_key, bucket: R2 upload settings.
        progress: Gradio progress tracker.

    Returns:
        ``(image, seed, result_json)``. ``seed`` is the seed actually used.
        ``result_json`` is a JSON string with ``status`` and ``message`` keys,
        plus ``url`` (the uploaded object key) after an upload.
    """
    _apply_loras(_parse_lora_configs(lora_strings_json))

    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)

    if final_image is None:
        result = {"status": "failed", "message": "Image generation failed"}
    elif upload_to_r2:
        with calculateDuration("Upload image"):
            url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
        result = {"status": "success", "message": "upload image success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}

    # gr.Progress expects a completion fraction in [0, 1].
    progress(1.0, "Completed!")

    return final_image, seed, json.dumps(result)

# Gradio interface
css="""
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Flux with LoRA")
    with gr.Row():

        # Left column: prompt, LoRA config, run button and advanced settings.
        with gr.Column():

            prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=2)
            lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)

            run_button = gr.Button("Run", scale=0)

            with gr.Accordion("Advanced Settings", open=False):

                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)

                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)

                upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                # Credentials are masked so they are not displayed on screen.
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here", type="password")
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", type="password")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")

        # Right column: generated image, seed actually used, and result JSON.
        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            seed_output = gr.Text(label="Seed")
            json_text = gr.Text(label="Result JSON")

    # Order must match run_lora's positional parameters.
    inputs = [
        prompt,
        lora_strings_json,
        cfg_scale,
        steps,
        randomize_seed,
        seed,
        width,
        height,
        upload_to_r2,
        account_id,
        access_key,
        secret_key,
        bucket
    ]

    # Order must match run_lora's returned tuple.
    outputs = [result, seed_output, json_text]

    run_button.click(
        fn=run_lora,
        inputs=inputs,
        outputs=outputs
    )

demo.queue().launch()