import gradio as gr
import json
import torch
import base64
import rembg
import numpy as np
from io import BytesIO
from PIL import Image
from diffusers import (
    DiffusionPipeline, 
    EulerDiscreteScheduler, 
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    DEISMultistepScheduler,
    UniPCMultistepScheduler
)
import spaces

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)

# Initialize the base model
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.to("cuda")

def image_to_base64(image: Image.Image) -> str:
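    """Encode a PIL image as a base64-encoded PNG string."""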
    buffered = BytesIO()
    image.save(buffered, format="PNG")  # PNG preserves the alpha channel produced by remove_bg
    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return img_base64
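
# For reference, a minimal sketch of the inverse operation on the client side
# (assuming `b64_string` holds the string returned by this app):
#
#     decoded = Image.open(BytesIO(base64.b64decode(b64_string)))
#     decoded.save("logo.png")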

def remove_bg(image: Image):
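    """Remove the image background with rembg, crop to the foreground
    bounding box, and resize the result to 1024x1024."""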
    input_array_bg = np.array(image)
    # Apply background removal using rembg (returns an RGBA array)
    output_array_bg = rembg.remove(input_array_bg)
    # Create a PIL Image from the output array
    img = Image.fromarray(output_array_bg)

    # Use the alpha channel as the mask so dark foreground pixels are not
    # mistaken for background
    alpha = np.array(img.split()[-1])

    # Create a binary mask (non-background areas are True, background areas are False)
    binary_mask = alpha > 0

    # Find the bounding box of the non-background areas
    coords = np.argwhere(binary_mask)
    if coords.size == 0:
        # No foreground detected; fall back to resizing the full image
        return img.resize((1024, 1024), Image.LANCZOS)
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0) + 1

    # Crop the output image using the bounding box (left, upper, right, lower)
    cropped_output_image = img.crop((x0, y0, x1, y1))

    # Resize the cropped image to 1024x1024
    upscaled_image = cropped_output_image.resize((1024, 1024), Image.LANCZOS)
    return upscaled_image
    
def update_selection(evt: gr.SelectData):
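    """Update the prompt placeholder and info text when a LoRA is picked in the gallery."""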
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index
    )

@spaces.GPU
def run_lora(prompt, negative_prompt, cfg_scale, steps, scheduler, seed, width, height, lora_scale):
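    """Generate a logo with SDXL plus a hardcoded logo LoRA, strip its
    background, and return the result as a base64-encoded PNG string."""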
    # The LoRA gallery is currently disabled, so the logo LoRA below is hardcoded
    # instead of being looked up via `selected_index`.
    # selected_lora = loras[selected_index]
    # lora_path = selected_lora["repo"]
    # trigger_word = selected_lora["trigger_word"]

    # Load LoRA weights; the LoRA strength is applied at inference time via
    # cross_attention_kwargs rather than as a load-time argument.
    pipe.load_lora_weights("Abdullah-Habib/lora-logo-v1")
    # pipe.load_lora_weights("Abdullah-Habib/logolora", scale=1)
    # pipe.load_lora_weights("Abdullah-Habib/icon-lora", scale=0.5)

    # Set scheduler: map the UI name to the corresponding diffusers scheduler
    # class and any extra config kwargs, then rebuild it from the current config
    scheduler_config = pipe.scheduler.config
    scheduler_map = {
        "DPM++ 2M": (DPMSolverMultistepScheduler, {}),
        "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "DPM++ 2M SDE Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
        "DPM++ SDE": (DPMSolverSinglestepScheduler, {}),
        "DPM++ SDE Karras": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "DPM2": (KDPM2DiscreteScheduler, {}),
        "DPM2 Karras": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "DPM2 a": (KDPM2AncestralDiscreteScheduler, {}),
        "DPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "Euler": (EulerDiscreteScheduler, {}),
        "Euler a": (EulerAncestralDiscreteScheduler, {}),
        "Heun": (HeunDiscreteScheduler, {}),
        "LMS": (LMSDiscreteScheduler, {}),
        "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "DEIS": (DEISMultistepScheduler, {}),
        "UniPC": (UniPCMultistepScheduler, {}),
    }
    if scheduler in scheduler_map:
        scheduler_cls, scheduler_kwargs = scheduler_map[scheduler]
        pipe.scheduler = scheduler_cls.from_config(scheduler_config, **scheduler_kwargs)

    # Set random seed for reproducibility (Gradio sliders may deliver floats)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image; the LoRA strength chosen in the UI is applied here
    image = pipe(
        prompt=f"{prompt}, rounded square, logo, logoredmaf, icons",
        negative_prompt=negative_prompt,
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        cross_attention_kwargs={"scale": lora_scale},
    ).images[0]

    # Unload LoRA weights
    pipe.unload_lora_weights()
    image_without_bg = remove_bg(image)
    return image_to_base64(image_without_bg)

with gr.Blocks(theme=gr.themes.Soft()) as app:

    selected_index = gr.State(None)

    with gr.Row():
        with gr.Column(scale=2):
            result = gr.Text(label="Generated Image (base64-encoded PNG)")
            generate_button = gr.Button("Generate", variant="primary")

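        # The LoRA gallery below (and its select handler further down) is currently
        # disabled; run_lora uses a hardcoded logo LoRA instead.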
        # with gr.Column(scale=1):
        #     gallery = gr.Gallery(
        #         [(item["image"], item["title"]) for item in loras],
        #         label="LoRA Gallery",
        #         allow_preview=False,
        #         columns=2
        #     )

    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("")
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Please enter a prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=2, value="low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry")

        with gr.Column():
            with gr.Row():
                cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=30)
            
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
            
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=0, randomize=True)
                lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=1)
            
            scheduler = gr.Dropdown(
                label="Scheduler", 
                choices=[
                    "DPM++ 2M", "DPM++ 2M Karras", "DPM++ 2M SDE", "DPM++ 2M SDE Karras",
                    "DPM++ SDE", "DPM++ SDE Karras", "DPM2", "DPM2 Karras", "DPM2 a", "DPM2 a Karras",
                    "Euler", "Euler a", "Heun", "LMS", "LMS Karras", "DEIS", "UniPC"
                ],
                value="DPM++ 2M SDE Karras"
            )

    # gallery.select(update_selection, outputs=[prompt, selected_info, selected_index])
    
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, negative_prompt, cfg_scale, steps, scheduler, seed, width, height, lora_scale],
        outputs=[result]
    )

app.queue()
app.launch()