import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time

from models import make_image_controlnet, make_inpainting, segment_image
from config import HEIGHT, WIDTH, POS_PROMPT, NEG_PROMPT, COLOR_MAPPING, map_colors, map_colors_rgb
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask

# Use the wide page layout so the input canvas and the output image, which
# main() places in two side-by-side columns, both fit on screen.
# NOTE: Streamlit expects set_page_config to be called before other st.*
# commands that render output, so it stays here at module level.
st.set_page_config(layout="wide")


def on_upload() -> None:
    """Load a newly uploaded file as the canvas background image.

    Used as the file uploader's ``on_change`` callback. Reads the upload from
    ``st.session_state['input_image']``, stores it (converted to RGB) under
    ``'initial_image'`` and discards all state derived from the previous image.
    Does nothing when the uploader has been cleared (value is ``None``).
    """
    uploaded = st.session_state.get('input_image')
    if uploaded is None:
        return

    st.session_state['initial_image'] = Image.open(uploaded).convert('RGB')

    # Strokes drawn over the previous image no longer apply: request an empty
    # canvas on the next run, mirroring move_image(remove_state=True).
    st.session_state['reset_canvas'] = True

    # Cached segmentation, its colour list and the last result all belong to
    # the previous image.
    for key in ('seg', 'unique_colors', 'output_image'):
        st.session_state.pop(key, None)


def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag.

    The flag is cleared unconditionally, so a requested reset affects exactly
    one script run.

    Returns:
        bool: True if a reset was requested since the last call, False otherwise.
    """
    needs_reset = bool(st.session_state.get('reset_canvas', False))
    st.session_state['reset_canvas'] = False
    return needs_reset


def move_image(source: Union[str, Image.Image],
               dest: str,
               rerun: bool = True,
               remove_state: bool = True) -> None:
    """Copy an image into a session-state slot.

    Args:
        source (Union[str, Image.Image]): the image itself, or the
            session-state key under which the source image is stored.
        dest (str): session-state key the image is written to.
        rerun (bool, optional): trigger a Streamlit rerun afterwards so the UI
            reflects the change. Defaults to True.
        remove_state (bool, optional): flag the canvas for reset and drop the
            cached segmentation (``'seg'``) and its colour list
            (``'unique_colors'``). Defaults to True.

    Raises:
        KeyError: if ``source`` is a key that is not present in session state.
    """
    source_image = source if isinstance(source, Image.Image) else st.session_state[source]

    if remove_state:
        st.session_state['reset_canvas'] = True
        for key in ('seg', 'unique_colors'):
            st.session_state.pop(key, None)

    st.session_state[dest] = source_image
    if rerun:
        # `st.experimental_rerun` was deprecated in favour of `st.rerun` and
        # removed in later Streamlit releases; prefer the new API and fall back
        # to the old one on versions that predate it.
        rerun_app = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_app()


def on_change_radio() -> None:
    """Request a canvas reset after the generation mode changes.

    Only raises the ``reset_canvas`` flag; check_reset_state consumes it on the
    next run.
    """
    st.session_state.reset_canvas = True


def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments for ``st_canvas``.

    Args:
        canvas_color: colour used for both the stroke and the fill.
        brush: stroke width.
        paint_mode: canvas drawing mode (the sidebar offers "freedraw" and
            "polygon").
        _reset_state: when truthy, pass an initial drawing with no objects so
            the canvas starts empty; otherwise pass None.

    Returns:
        dict: keyword arguments to unpack into ``st_canvas``.
    """
    empty_drawing = {'version': '4.4.0', 'objects': []}
    return {
        'fill_color': canvas_color,
        'stroke_color': canvas_color,
        'background_color': "#FFFFFF",
        'background_image': st.session_state.get('initial_image'),
        'stroke_width': brush,
        'initial_drawing': empty_drawing if _reset_state else None,
        'update_streamlit': True,
        'height': 512,
        'width': 512,
        'drawing_mode': paint_mode,
        'key': "canvas",
    }

def make_prompt_row():
    """Render the positive and negative prompt text inputs side by side.

    Their values are kept in session state under ``'positive_prompt'`` and
    ``'negative_prompt'``.
    """
    positive_col, negative_col = st.columns(2)
    positive_col.text_input(
        label="Positive prompt",
        value="a photograph of a room, interior design, 4k, high resolution",
        key='positive_prompt',
    )
    negative_col.text_input(label="Negative prompt", value="", key='negative_prompt')

def _make_paint_controls():
    """Render the painting-mode selector and, for freehand drawing, a stroke-width slider.

    Returns:
        tuple: ``(paint_mode, brush)``; polygon mode uses a fixed stroke width of 5.
    """
    paint_mode = st.sidebar.selectbox("Painting mode", ("freedraw", "polygon"))
    if paint_mode == "freedraw":
        brush = st.slider("Stroke width", 5, 140, 100, key='slider_seg')
    else:
        brush = 5
    return paint_mode, brush


def make_sidebar():
    """Render the sidebar controls and return the user's choices.

    Returns:
        tuple: ``(input_image, generation_mode, brush, color_chooser, paint_mode)``
            where ``input_image`` is the file uploader's value, ``generation_mode``
            is one of "Re-generate objects", "Segmentation conditioning" or
            "Inpainting", and the remaining values configure the drawing canvas.
    """
    with st.sidebar:
        # "jpeg" is listed explicitly so .jpeg files are accepted even on
        # Streamlit versions that do not treat jpg/jpeg as equivalent.
        input_image = st.file_uploader("", type=["png", "jpg", "jpeg"], key='input_image', on_change=on_upload)
        generation_mode = st.selectbox("Generation mode", ["Re-generate objects",
                                                           "Segmentation conditioning",
                                                           "Inpainting"], on_change=on_change_radio)

        if generation_mode == "Segmentation conditioning":
            paint_mode, brush = _make_paint_controls()

            category_chooser = st.sidebar.selectbox("Filter on category", list(
                COLOR_MAPPING_CATEGORY.keys()), index=0, key='category_chooser')

            # Only colours belonging to the chosen category are offered.
            chosen_colors = list(COLOR_MAPPING_CATEGORY[category_chooser].keys())

            color_chooser = st.sidebar.selectbox(
                "Choose a color", chosen_colors, index=0, format_func=map_colors, key='color_chooser'
            )

        elif generation_mode == "Re-generate objects":
            # Nothing is drawn in this mode: a fully transparent colour and a
            # zero-width brush keep the canvas visually inert.
            color_chooser = "rgba(0, 0, 0, 0.0)"
            paint_mode = 'freedraw'
            brush = 0

        else:
            paint_mode, brush = _make_paint_controls()
            color_chooser = "#000000"
    return input_image, generation_mode, brush, color_chooser, paint_mode


def make_output_image():
    """Render the output column: the latest generated image plus a reuse button.

    Shows ``st.session_state['output_image']``. A NumPy array is converted to a
    PIL image, and PIL images are resized to 512x512. When nothing has been
    generated yet, a blank white 512x512 placeholder is shown instead. The
    "Move to input image" button copies the generated image to
    ``'initial_image'`` and resets the canvas.
    """
    has_output = 'output_image' in st.session_state
    if has_output:
        output_image = st.session_state['output_image']
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)

        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((512, 512))
    else:
        output_image = Image.new('RGB', (512, 512), (255, 255, 255))

    st.write("#### Output image")
    st.image(output_image, width=512)
    if st.button("Move to input image"):
        if has_output:
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
        else:
            # Without this guard move_image would raise KeyError on the missing
            # 'output_image' entry; the placeholder must not become the input.
            st.warning("Generate an image first.")

def _store_output_image(result_image) -> None:
    """Save a generated image to ``st.session_state['output_image']``, converting NumPy arrays to PIL."""
    if isinstance(result_image, np.ndarray):
        result_image = Image.fromarray(result_image)
    st.session_state['output_image'] = result_image


def _segmentation_conditioning_mode(canvas_dict) -> None:
    """Segmentation conditioning: the user paints segmentation colours, then ControlNet regenerates."""
    canvas = st_canvas(**canvas_dict)

    if st.button("generate image", key='generate_button'):
        image = get_image()
        print("Preparing image segmentation")
        real_seg = segment_image(Image.fromarray(image))
        # Combine the user's painted strokes with the model segmentation.
        mask, seg = preprocess_seg_mask(canvas, real_seg)

        with st.spinner(text="Generating image"):
            print("Making image")
            result_image = make_image_controlnet(image=image,
                                                 mask_image=mask,
                                                 controlnet_conditioning_image=seg,
                                                 positive_prompt=st.session_state['positive_prompt'],
                                                 negative_prompt=st.session_state['negative_prompt'],
                                                 seed=random.randint(0, 100000)  # nosec
                                                 )[0]
            _store_output_image(result_image)


def _regenerate_objects_mode(canvas_dict) -> None:
    """Re-generate objects: the user picks segmented concepts, which are masked and regenerated."""
    # The canvas only displays the input image here; its drawing is not used.
    st_canvas(**canvas_dict)

    # Segment the input image once and cache it in session state; the cache is
    # dropped elsewhere when the input image changes.
    if 'seg' not in st.session_state:
        with st.spinner(text="Preparing image segmentation"):
            image = get_image()
            st.session_state['seg'] = np.array(segment_image(Image.fromarray(image)))

    if 'unique_colors' not in st.session_state:
        real_seg = st.session_state['seg']
        # Each distinct colour in the segmentation map is offered as one concept.
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state['unique_colors'] = [tuple(color) for color in unique_colors]

    chosen_colors = st.multiselect(
        label="Choose which concepts you want to regenerate in the image",
        options=st.session_state['unique_colors'],
        key='chosen_colors',
        default=st.session_state['unique_colors'],
        format_func=map_colors_rgb,
    )
    with st.expander("Explanation", expanded=False):
        st.write("This mode allows you to choose which objects you want to re-generate in the image. "
                 "Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
                 " to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
                 " the 'Move to input image' button."
                 )

    if st.button("generate image", key='generate_button'):
        image = get_image()
        print(chosen_colors)

        segmentation = st.session_state['seg']
        # Pixels whose segmentation colour was selected are set to 1 (in every
        # channel); all other pixels stay 0.
        mask = np.zeros_like(segmentation)
        for color in chosen_colors:
            mask[np.where((segmentation == color).all(axis=2))] = 1

        with st.spinner(text="Generating image"):
            result_image = make_image_controlnet(image=image,
                                                 mask_image=mask,
                                                 controlnet_conditioning_image=segmentation,
                                                 positive_prompt=st.session_state['positive_prompt'],
                                                 negative_prompt=st.session_state['negative_prompt'],
                                                 seed=random.randint(0, 100000)  # nosec
                                                 )[0]
            _store_output_image(result_image)


def _inpainting_mode(canvas_dict) -> None:
    """Inpainting: the user paints the region to replace, which is then inpainted."""
    image = get_image()

    canvas = st_canvas(**canvas_dict)

    if st.button("generate images", key='generate_button'):
        canvas_mask = canvas.image_data
        if canvas_mask is None:
            # NOTE(review): presumably image_data stays None until the canvas
            # component has returned data; np.array(None) would otherwise be
            # handed to get_mask.
            st.warning("Draw the area to inpaint on the input image first.")
            return
        if not isinstance(canvas_mask, np.ndarray):
            canvas_mask = np.array(canvas_mask)
        mask = get_mask(canvas_mask)

        with st.spinner(text="Generating new images"):
            print("Making image")
            result_image = make_inpainting(positive_prompt=st.session_state['positive_prompt'],
                                           image=image,
                                           mask_image=mask,
                                           negative_prompt=st.session_state['negative_prompt'],
                                           )[0]
            _store_output_image(result_image)


def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the input column: the drawing canvas and the controls for the chosen mode.

    Args:
        canvas_color: stroke/fill colour for the canvas.
        brush: stroke width.
        _reset_state: if truthy, the canvas starts with an empty drawing.
        generation_mode: "Segmentation conditioning", "Re-generate objects" or
            "Inpainting"; any other value renders only the heading.
        paint_mode: canvas drawing mode.

    Generated results are stored in ``st.session_state['output_image']``.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state
    )
    if generation_mode == "Segmentation conditioning":
        _segmentation_conditioning_mode(canvas_dict)
    elif generation_mode == "Re-generate objects":
        _regenerate_objects_mode(canvas_dict)
    elif generation_mode == "Inpainting":
        _inpainting_mode(canvas_dict)

def main():
    """Render the whole app: title, sidebar, and—once an image is uploaded—the editor and output columns."""
    st.write("## Controlnet sprint - interior design", unsafe_allow_html=True)

    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()

    # Nothing to edit until the user has uploaded an image.
    if st.session_state.get('input_image') is None:
        print("Image not present")
        st.success("Upload an image to start")
        return

    make_prompt_row()

    reset_requested = check_reset_state()

    editor_col, output_col = st.columns(2)
    with editor_col:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=reset_requested,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )

    with output_col:
        make_output_image()
            

# Script entry point. NOTE(review): presumably launched via `streamlit run`,
# which executes this file with __name__ == "__main__" — confirm.
if __name__ == "__main__":
    main()