import os
import datetime
import einops
import gradio as gr
from gradio_imageslider import ImageSlider
import numpy as np
import torch
import random
from PIL import Image
from pathlib import Path
from torchvision import transforms
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

from pytorch_lightning import seed_everything
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler

from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
from myutils.misc import load_dreambooth_lora, rand_name
from myutils.wavelet_color_fix import wavelet_color_fix
from annotator.retinaface import RetinaFaceDetection

# Model-variant switch: True imports the "pasd_light" UNet/ControlNet classes,
# False imports the full "pasd" ones (see the conditional imports below).
use_pasd_light = False
# NOTE(review): the RetinaFace detector is constructed at import time but is not
# used by inference() or the UI in this file — confirm it is still needed, since
# it adds startup time and memory.
face_detector = RetinaFaceDetection()

if use_pasd_light:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
else:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel

# Base Stable Diffusion checkpoint (scheduler, text encoder, tokenizer, VAE,
# feature extractor are loaded from here).
pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
# Trained PASD checkpoint; must contain "unet" and "controlnet" subfolders.
ckpt_path = "runs/pasd/checkpoint-100000"
# Personalized weights merged via load_dreambooth_lora(); alternatives:
#dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors"
dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
#dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
# All diffusion components run in half precision on the GPU.
weight_dtype = torch.float16
device = "cuda"

scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
# Inference only: disable gradient tracking on every model's parameters.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

# Apply the personalized weights before the dtype cast below (presumably so the
# merge happens at the checkpoint's original precision — confirm in myutils.misc).
unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)

text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)

# Project-local pipeline; safety checker explicitly disabled.
validation_pipeline = StableDiffusionControlNetPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor, 
        unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
    )
# Alternative to the tiled-VAE setup below:
#validation_pipeline.enable_vae_tiling()
# NOTE(review): _init_tiled_vae is a private helper of the project pipeline;
# presumably it decodes in 224-px tiles to bound GPU memory on large outputs.
validation_pipeline._init_tiled_vae(decoder_tile_size=224)

# ImageNet ResNet-50 used by inference() to append a class label to the prompt.
# No .to() call: it stays on the CPU in float32.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights)
resnet.eval()

def resize_image(image_path, target_height, resample=None):
    """Open an image and resize it to ``target_height`` pixels tall, keeping aspect ratio.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).
        target_height: Desired output height in pixels.
        resample: PIL resampling filter. ``None`` (the default) uses
            ``Image.LANCZOS``, matching the previous behavior.

    Returns:
        A new ``PIL.Image.Image`` of size ``(new_width, target_height)``. The
        file opened from ``image_path`` is closed when the ``with`` block exits.
    """
    if resample is None:
        resample = Image.LANCZOS
    with Image.open(image_path) as img:
        width, height = img.size
        # Compute from the integers and round: the old float-ratio-then-int()
        # path could truncate e.g. 511.999 down to 511. max(1, ...) keeps very
        # tall, narrow images from collapsing to width 0, which PIL rejects.
        new_width = max(1, round(width * target_height / height))
        return img.resize((new_width, target_height), resample)

def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
    """Super-resolve one image with the PASD pipeline and save input/result JPEGs.

    The image is resized to 512 px tall, optionally tagged with a ResNet-50
    ImageNet label, upscaled by ``upscale`` and refined by the diffusion
    pipeline.

    Args:
        input_image: File path of the uploaded image (the Gradio ``Image``
            component uses ``type="filepath"``), or ``None`` if nothing was
            uploaded.
        prompt: User prompt. The top ResNet-50 label is appended when its
            softmax probability is at least 0.1.
        a_prompt: Extra positive prompt appended after ``prompt``.
        n_prompt: Negative prompt.
        denoise_steps: Number of inference steps passed to the pipeline.
        upscale: Integer upscale factor applied to the 512-px-tall image.
        alpha: Conditioning scale passed to the pipeline.
        cfg: Classifier-free guidance scale.
        seed: RNG seed. A negative value (the UI slider allows -1) draws a
            fresh random seed.

    Returns:
        ``((input_path, result_path), result_path)``: the pair feeds the
        before/after ``ImageSlider`` and the single path feeds ``gr.File``.
        If the pipeline raises, the traceback is logged and a black 512x512
        image is saved as the result instead.

    Raises:
        gr.Error: If no image was uploaded.
    """
    import traceback

    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Slider values may arrive as floats; PIL sizes and step counts need ints.
    rscale = int(upscale)
    denoise_steps = int(denoise_steps)
    seed = int(seed)
    if seed < 0:
        # Use OS entropy: the global `random` state was reseeded by
        # seed_everything() on the previous call and would be predictable.
        seed = random.SystemRandom().randint(0, 2**31 - 1)

    input_image = resize_image(input_image, 512)

    # Microsecond resolution so back-to-back requests never overwrite each other's files.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")

    with torch.no_grad():
        seed_everything(seed)
        # seed_everything() seeds only the global RNGs; a newly created Generator
        # starts from its own default state, so without manual_seed the user's
        # seed would have no effect on the pipeline's noise.
        generator = torch.Generator(device=device).manual_seed(seed)

        input_image = input_image.convert('RGB')

        # Tag the prompt with the ResNet-50 ImageNet label when it is confident enough.
        batch = preprocess(input_image).unsqueeze(0)
        probs = resnet(batch).squeeze(0).softmax(0)
        class_id = probs.argmax().item()
        prompt_parts = [prompt]
        if probs[class_id].item() >= 0.1:
            prompt_parts.append(weights.meta["categories"][class_id])
        prompt_parts.append(a_prompt)
        # Skip empty parts so the prompt never ends with a stray ", ".
        prompt = ", ".join(part for part in prompt_parts if part)

        ori_width, ori_height = input_image.size
        input_image = input_image.resize((ori_width * rscale, ori_height * rscale))
        # Round each side down to a multiple of 8 (presumably required by the
        # VAE's 8x-downsampled latent — confirm); never go below 8 px.
        input_image = input_image.resize(
            (max(8, input_image.size[0] // 8 * 8), max(8, input_image.size[1] // 8 * 8))
        )
        width, height = input_image.size

        try:
            image = validation_pipeline(
                None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator,
                height=height, width=width, guidance_scale=cfg,
                negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
            ).images[0]
            # Color correction against the conditioning image (myutils.wavelet_color_fix).
            image = wavelet_color_fix(image, input_image)
            # Undo the multiple-of-8 rounding: output is exactly rscale x the 512-px image.
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception:
            # Deliberate best effort: keep the UI responsive with a blank result,
            # but log the full traceback instead of only the message.
            traceback.print_exc()
            image = Image.new(mode="RGB", size=(512, 512))

    result_path = f"result_{timestamp}.jpg"
    input_path = f"input_{timestamp}.jpg"
    image.save(result_path, 'JPEG')
    input_image.save(input_path, 'JPEG')

    return (input_path, result_path), result_path

# Page metadata. NOTE(review): title/description/article are not used by the UI
# code in this file; its header is hard-coded HTML instead.
title = "Pixel-Aware Stable Diffusion for Real-ISR"
description = (
    "Gradio Demo for PASD Real-ISR. "
    "To use it, simply upload your image, or click one of the examples to load them."
)
article = (
    "<a href='https://github.com/yangxy/PASD' target='_blank'>"
    "Github Repo Pytorch"
    "</a>"
)
# Example inputs (not currently wired into the UI):
#examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]

# Stylesheet: center the main column and cap its width at 720px.
css = "\n".join([
    "",
    "#col-container{",
    "    margin: 0 auto;",
    "    max-width: 720px;",
    "}",
    "",
])

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static page header.
        gr.HTML("""
        <h2 style="text-align: center;">
            PASD Magnify
        </h2>
        <p style="text-align: center;">
            Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
        </p>
        
        """)
        with gr.Row():
            # Left column: inputs and generation settings.
            with gr.Column():
                input_image = gr.Image(type="filepath", sources=["upload"], value="samples/frog.png")
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
                    neg_prompt = gr.Textbox(label="Negative Prompt",value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                    denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
                    upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
                    condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
                    classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
                    # -1 is allowed; inference() treats a negative seed as "random".
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                submit_btn = gr.Button("Submit")
            # Right column: before/after comparison and downloadable result.
            with gr.Column():
                b_a_slider = ImageSlider(label="B/A result", position=0.5)
                file_output = gr.File(label="Downloadable image result")

    # The input order must match inference()'s positional parameters, and the
    # outputs match its return value ((input_path, result_path), result_path).
    submit_btn.click(
        fn = inference,
        inputs = [
            input_image, prompt_in,
            added_prompt, neg_prompt,
            denoise_steps,
            upsample_scale, condition_scale,
            classifier_free_guidance, seed
        ],
        outputs = [
            b_a_slider,
            file_output
        ]
    )
demo.queue().launch()