import os
from io import BytesIO

import torch

import internals.util.prompt as prompt_util
from internals.data.dataAccessor import update_db
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.remove_background import (RemoveBackground,
                                                   RemoveBackgroundV2)
from internals.pipelines.replace_background import ReplaceBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.upscaler import Upscaler
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
from internals.util.commons import (construct_default_s3_url, upload_image,
                                    upload_images)
from internals.util.config import (num_return_sequences, set_configs_from_task,
                                   set_model_dir, set_root_dir)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.slack import Slack

# Global PyTorch performance flags: let cuDNN benchmark and cache the fastest
# convolution algorithms for the input shapes it sees, and allow TF32 for CUDA
# matmuls (faster, at slightly reduced precision, on GPUs that support it).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not referenced in the visible part of this file — confirm it is
# used elsewhere or remove it.
auto_mode = False

# Slack client; its `auto_send_alert` decorator wraps the task handlers below
# and `error_alert` is called from `predict_fn` on failure.
slack = Slack()

# Pipeline / helper singletons, created once at import time and shared by every
# request. Most of them are loaded in `model_fn`; `controlnet` loads its tile
# upscaler on demand (`load_model_by_task` / `tile_upscale`).
# NOTE(review): img2text, img_classifier and remove_background_v2 have no
# explicit load() call in this file — confirm they load themselves (e.g. in
# their constructors) or are loaded elsewhere.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
controlnet = ControlNet()
safety_checker = SafetyChecker()
high_res = HighRes()
object_removal = ObjectRemoval()
remove_background_v2 = RemoveBackgroundV2()
replace_background = ReplaceBackground()
img2text = Image2Text()
img_classifier = ImageClassifier()
avatar = Avatar()
lora_style = LoraStyle()


def get_patched_prompt_tile_upscale(task: Task):
    """Build the prompt used for a tile-upscale task.

    Thin adapter over ``prompt_util.get_patched_prompt_tile_upscale`` that
    supplies this module's shared avatar, LoRA-style, image-classifier and
    image-to-text helpers, in that order, after the task.
    """
    shared_helpers = (avatar, lora_style, img_classifier, img2text)
    return prompt_util.get_patched_prompt_tile_upscale(task, *shared_helpers)


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task's image with the ControlNet tile upscaler.

    Loads the tile-upscaler weights into the shared ``controlnet``, applies the
    task's LoRA style patch to ``controlnet.pipe``, runs the upscale and uploads
    the first resulting image to ``crecoAI/<taskId>_tile_upscaler.png``.

    Once the upscaler is loaded, ``lora_patcher.cleanup()`` and
    ``controlnet.cleanup()`` run even if inference or the upload raises.
    Without this, a failure would leave the LoRA patch applied to the shared
    pipeline for the next request (presumably what ``cleanup()`` undoes).

    Returns:
        dict with ``modified_prompts`` (the patched prompt),
        ``generated_image_url`` (value returned by ``upload_image``) and
        ``has_nsfw`` (as reported by the pipeline).
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_tile_upscaler()
    try:
        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        try:
            images, has_nsfw = controlnet.process_tile_upscaler(
                imageUrl=task.get_imageUrl(),
                seed=task.get_seed(),
                steps=task.get_steps(),
                width=task.get_width(),
                height=task.get_height(),
                prompt=prompt,
                resize_dimension=task.get_resize_dimension(),
                negative_prompt=task.get_negative_prompt(),
                guidance_scale=task.get_ti_guidance_scale(),
            )

            generated_image_url = upload_image(images[0], output_key)
        finally:
            # Undo the style patch before releasing the pipeline, same order as
            # the success path.
            lora_patcher.cleanup()
    finally:
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from the task's image and upload the cut-out.

    Uses the shared ``RemoveBackgroundV2`` instance. The result is uploaded to
    ``crecoAI/<taskId>_rmbg.png``.

    Returns:
        dict whose ``generated_image_url`` is the default S3 URL built from
        the upload key (not the value returned by ``upload_image``).
    """
    cutout = remove_background_v2.remove(task.get_imageUrl())

    s3_key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, s3_key)

    return {"generated_image_url": construct_default_s3_url(s3_key)}


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image from a text prompt.

    The prompt first has avatar code names added via
    ``avatar.add_code_names``. It is then either passed through
    ``prompt_modifier.modify`` (when the task enables prompt engineering) or
    repeated ``num_return_sequences`` times. The negative prompt is always
    repeated ``num_return_sequences`` times.

    ``clear_cuda()`` runs in a ``finally`` so cached GPU memory is released
    even when prompt modification, inpainting or the upload raises.

    Returns:
        dict with ``modified_prompts`` (the prompt list sent to the pipeline)
        and ``generated_image_urls`` (value returned by ``upload_images``).
    """
    try:
        prompt = avatar.add_code_names(task.get_prompt())
        if task.is_prompt_engineering():
            prompt = prompt_modifier.modify(prompt)
        else:
            prompt = [prompt] * num_return_sequences

        print({"prompts": prompt})

        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Erase the masked object from the task's image and upload the result.

    The first image from ``object_removal.process`` is uploaded to
    ``crecoAI/<taskId>_object_remove.png``. ``clear_cuda()`` runs in a
    ``finally`` so cached GPU memory is released even when processing or the
    upload raises.

    Returns:
        dict with key ``generated_image_urls`` holding the single value
        returned by ``upload_image`` (not a list, despite the plural key).
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    try:
        images = object_removal.process(
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            seed=task.get_seed(),
            width=task.get_width(),
            height=task.get_height(),
        )
        generated_image_url = upload_image(images[0], output_key)
    finally:
        clear_cuda()

    # NOTE: plural key with a single-URL value; kept unchanged for backward
    # compatibility with existing consumers of this response.
    return {"generated_image_urls": generated_image_url}


@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Generate new backgrounds behind the object in the task's image.

    The raw task prompt (without avatar code names) is passed through
    ``prompt_modifier.modify`` when prompt engineering is enabled. Otherwise
    it is repeated ``num_return_sequences`` times. The negative prompt is
    always repeated ``num_return_sequences`` times.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` (value
        returned by ``upload_images``) and ``has_nsfw``.
    """
    base_prompt = task.get_prompt()
    prompts = (
        prompt_modifier.modify(base_prompt)
        if task.is_prompt_engineering()
        else [base_prompt] * num_return_sequences
    )

    images, nsfw_flag = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompts,
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        extend_object=task.rbg_extend_object(),
        product_scale_width=task.get_image_scale(),
        conditioning_scale=task.rbg_controlnet_conditioning_scale(),
    )

    urls = upload_images(images, "_replace_bg", task.get_taskId())

    result = {}
    result["modified_prompts"] = prompts
    result["generated_image_urls"] = urls
    result["has_nsfw"] = nsfw_flag
    return result


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image with the anime or the real-photo upscaler.

    ``ModelType.ANIME`` tasks use ``upscaler.upscale_anime``; every other model
    type uses ``upscaler.upscale``. Both receive the same arguments. The raw
    result is wrapped in ``BytesIO`` and uploaded to
    ``crecoAI/<taskId>_upscale.png``.

    Returns:
        dict whose ``generated_image_url`` is the default S3 URL for that key.
    """
    output_key = f"crecoAI/{task.get_taskId()}_upscale.png"

    if task.get_modelType() == ModelType.ANIME:
        print("Using Anime model")
        run_upscale = upscaler.upscale_anime
    else:
        print("Using Real model")
        run_upscale = upscaler.upscale

    upscaled = run_upscale(
        image=task.get_imageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        face_enhance=task.get_face_enhance(),
        resize_dimension=task.get_resize_dimension(),
    )

    upload_image(BytesIO(upscaled), output_key)
    return {"generated_image_url": construct_default_s3_url(output_key)}


def model_fn(model_dir):
    """Load shared models and helpers once, before any request is served.

    Records the model and root directories in config, registers the failure
    handler, loads avatar and LoRA-style assets from ``model_dir``, then loads
    each pipeline. ``replace_background`` is loaded last and receives the
    already-constructed ``upscaler``, ``remove_background_v2`` and
    ``high_res`` instances.

    Args:
        model_dir: directory of model artifacts. It is passed to
            ``set_model_dir``, ``avatar.load_local``, ``lora_style.load`` and
            ``object_removal.load``.

    Returns:
        None. NOTE(review): in a SageMaker-style handler this value would be
        passed to ``predict_fn`` as ``pipe``; ``predict_fn`` here does not
        read it — confirm.
    """
    print("Logs: model loaded .... starts")

    set_model_dir(model_dir)
    # Root dir is derived from this module's own path.
    set_root_dir(__file__)

    FailureHandler.register()

    # Assets read from the local model directory.
    avatar.load_local(model_dir)
    lora_style.load(model_dir)

    prompt_modifier.load()
    safety_checker.load()

    object_removal.load(model_dir)
    upscaler.load()
    inpainter.load()
    high_res.load()

    # Composite pipeline wired to the instances loaded above.
    # NOTE(review): remove_background_v2 has no load() call in this file.
    replace_background.load(
        upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
    )

    print("Logs: model loaded ....")
    return


def load_model_by_task(task: Task):
    """Load task-specific weights, then apply the safety checker to ControlNet.

    Only ``TILE_UPSCALE`` tasks load extra weights here (the ControlNet tile
    upscaler). ``safety_checker.apply(controlnet)`` runs for every task type.
    """
    needs_tile_upscaler = task.get_type() == TaskType.TILE_UPSCALE
    if needs_tile_upscaler:
        controlnet.load_tile_upscaler()

    safety_checker.apply(controlnet)


@FailureHandler.clear
def predict_fn(data, pipe):
    """Handle one inference request and dispatch it to the matching task handler.

    Builds a ``Task`` from ``data``, applies per-task configuration, loads
    task-specific models, applies the safety checker to the shared pipelines,
    fetches avatars for the task's model id, then calls the handler for the
    task type.

    Args:
        data: raw request payload used to construct ``Task``.
        pipe: value produced by ``model_fn``; not used here.

    Returns:
        The handler's result dict. Returns ``None`` for ``SYSTEM_CMD`` tasks,
        and also when any exception is raised inside the dispatch. In that
        case the error and its traceback are printed, the error is reported
        to Slack, and ``controlnet.cleanup()`` is always run, even if the
        Slack alert itself fails.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Set set_environment
        set_configs_from_task(task)

        # Load model based on task
        load_model_by_task(task)

        # Apply safety checker based on environment
        safety_checker.apply(inpainter)
        safety_checker.apply(replace_background)
        safety_checker.apply(high_res)

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        elif task_type == TaskType.SYSTEM_CMD:
            # SECURITY(review): this runs the request's prompt as a shell
            # command, i.e. arbitrary code execution for anyone who can submit
            # a task. Confirm this task type is restricted to trusted callers,
            # or remove it.
            os.system(task.get_prompt())
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        import traceback

        # Keep the full stack trace in the logs; the one-line message alone
        # hides where the failure happened.
        traceback.print_exc()
        print(f"Error: {e}")
        try:
            slack.error_alert(task, e)
        finally:
            # Release ControlNet state even if the Slack alert fails.
            controlnet.cleanup()
        return None