# Edit Anything trained with Stable Diffusion + ControlNet + SAM  + BLIP2
from torchvision.utils import save_image
from PIL import Image
from pytorch_lightning import seed_everything
import subprocess
from collections import OrderedDict
import re
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import os
import requests
from io import BytesIO
from annotator.util import resize_image, HWC3, resize_points, get_bounding_box

import torch
from safetensors.torch import load_file
from collections import defaultdict
from diffusers import StableDiffusionControlNetPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler

from utils.stable_diffusion_controlnet import ControlNetModel2
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline, \
    StableDiffusionControlNetInpaintMixingPipeline, prepare_mask_image
# need the latest transformers
# pip install git+https://github.com/huggingface/transformers.git
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from diffusers import ControlNetModel, DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
import PIL.Image

# Segment-Anything init.
# pip install git+https://github.com/facebookresearch/segment-anything.git
try:
    from segment_anything import (
        sam_model_registry,
        SamAutomaticMaskGenerator,
        SamPredictor,
    )
except ImportError:
    import importlib
    import sys

    print("segment_anything not installed")
    # Use the running interpreter's pip (``python -m pip``) so the package is
    # installed into the same environment this module is imported from; a bare
    # ``pip`` found on PATH may belong to a different Python installation.
    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/segment-anything.git",
        ],
        check=True,
    )
    print(f"Install segment_anything {result}")
    # The import system caches directory listings; refresh them so the freshly
    # installed package is found by the import below.
    importlib.invalidate_caches()
    from segment_anything import (
        sam_model_registry,
        SamAutomaticMaskGenerator,
        SamPredictor,
    )
# Fetch the SAM ViT-H checkpoint into ./models on first run (requires `wget`).
if not os.path.exists("./models/sam_vit_h_4b8939.pth"):
    result = subprocess.run(
        [
            "wget",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            "-P",
            "models",
        ],
        check=True,
    )
    print(f"Download sam_vit_h_4b8939.pth {result}")

# Compute device shared by SAM, BLIP2 inputs and the model instances below.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# UI display name -> Hugging Face Hub repo id of the EditAnything SAM ControlNet.
# Insertion order is preserved (OrderedDict) so menus list models in this order.
config_dict = OrderedDict()
config_dict["LAION Pretrained(v0-4)-SD15"] = "shgao/edit-anything-v0-4-sd15"
config_dict["LAION Pretrained(v0-4)-SD21"] = "shgao/edit-anything-v0-4-sd21"
config_dict["LAION Pretrained(v0-3)-SD21"] = "shgao/edit-anything-v0-3"
config_dict["SAM Pretrained(v0-1)-SD21"] = "shgao/edit-anything-v0-1-1"


def init_sam_model(sam_generator=None, mask_predictor=None):
    """Return a ``(SamAutomaticMaskGenerator, SamPredictor)`` pair.

    Both arguments are returned untouched when both are given. Otherwise the
    ViT-H SAM checkpoint is loaded onto ``device`` and used to build whichever
    of the two is missing; a provided one is kept as-is.

    :param sam_generator: existing automatic mask generator, or None.
    :param mask_predictor: existing point-prompt predictor, or None.
    :return: tuple ``(sam_generator, mask_predictor)``.
    """
    if not (sam_generator is None or mask_predictor is None):
        return sam_generator, mask_predictor

    sam = sam_model_registry["default"](checkpoint="models/sam_vit_h_4b8939.pth")
    sam.to(device=device)

    if sam_generator is None:
        sam_generator = SamAutomaticMaskGenerator(sam)
    if mask_predictor is None:
        mask_predictor = SamPredictor(sam)
    return sam_generator, mask_predictor


def init_blip_processor():
    """Load and return the BLIP2 (OPT-2.7b) processor from the Hugging Face Hub."""
    return AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")


def init_blip_model():
    """Load the BLIP2 (OPT-2.7b) captioning model in float16.

    ``device_map="auto"`` lets transformers decide where the weights live.
    """
    load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    return Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", **load_kwargs
    )


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    # https://github.com/huggingface/diffusers/issues/2136
    """Get pipeline embeds for prompts longer than the tokenizer's max length.

    Both prompts are tokenized without truncation and the shorter one is padded
    to the longer one's token count, so neither prompt loses tokens and both
    embeddings have the same sequence length. The token ids are then encoded in
    windows of ``pipeline.tokenizer.model_max_length`` and concatenated.

    :param pipeline: diffusers pipeline exposing ``tokenizer`` and ``text_encoder``.
    :param prompt: positive prompt text.
    :param negative_prompt: negative prompt text; ``None`` is treated as ``""``.
    :param device: device the token ids are moved to before encoding.
    :return: ``(prompt_embeds, negative_prompt_embeds)`` concatenated along dim 1.
    """
    max_length = pipeline.tokenizer.model_max_length
    if negative_prompt is None:
        negative_prompt = ""

    def _tokenize(text, pad_to=None):
        # Token ids without truncation, optionally padded up to ``pad_to``.
        kwargs = {"return_tensors": "pt", "truncation": False}
        if pad_to is not None:
            kwargs.update(padding="max_length", max_length=pad_to)
        return pipeline.tokenizer(text, **kwargs).input_ids.to(device)

    # Measure the real token counts instead of guessing from comma counts: the
    # guess could pick the shorter prompt as "longer", silently dropping the
    # tail of the other prompt and producing mismatched embedding lengths.
    input_ids = _tokenize(prompt)
    negative_ids = _tokenize(negative_prompt)
    shape_max_length = max(input_ids.shape[-1], negative_ids.shape[-1])
    if input_ids.shape[-1] < shape_max_length:
        input_ids = _tokenize(prompt, shape_max_length)
    if negative_ids.shape[-1] < shape_max_length:
        negative_ids = _tokenize(negative_prompt, shape_max_length)

    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i : i + max_length])[0])
        neg_embeds.append(pipeline.text_encoder(negative_ids[:, i : i + max_length])[0])

    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)


def _resolve_lora_target(root, layer_infos):
    """Walk from ``root`` to the module named by the '_'-split parts in ``layer_infos``.

    Module attribute names may themselves contain '_', so parts are greedily
    re-joined with '_' until an attribute lookup succeeds. ``layer_infos`` is
    consumed (mutated) in the process.

    :raises ValueError: if the parts cannot be resolved to a module path.
    """
    curr_layer = root
    temp_name = layer_infos.pop(0)
    while True:
        try:
            curr_layer = curr_layer.__getattr__(temp_name)
        except Exception as err:
            if not layer_infos:
                raise ValueError(
                    f"Cannot resolve LoRA target layer at '{temp_name}'"
                ) from err
            if len(temp_name) > 0:
                temp_name += "_" + layer_infos.pop(0)
            else:
                temp_name = layer_infos.pop(0)
            continue
        if not layer_infos:
            return curr_layer
        temp_name = layer_infos.pop(0)


def _merge_lora_state_dict(pipeline, state_dict, multiplier, dtype):
    """Add ``multiplier * alpha * up @ down`` from a LoRA state dict into the pipeline weights."""
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # Group tensors per layer. Keys usually look like
    # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split(".", 1)
        updates[layer][elem] = value

    # Directly update weights in the diffusers model.
    for layer, elems in updates.items():
        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            root = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            root = pipeline.unet
        curr_layer = _resolve_lora_target(root, layer_infos)

        weight_up = elems["lora_up.weight"].to(dtype)
        weight_down = elems["lora_down.weight"].to(dtype)
        # Some LoRA files omit "alpha"; a missing or zero alpha means scale 1.0.
        alpha = elems.get("alpha")
        alpha = alpha.item() / weight_up.shape[1] if alpha else 1.0

        if len(weight_up.shape) == 4:
            # Conv LoRA: squeeze the trailing spatial dims, multiply, restore them.
            delta = (
                torch.mm(
                    weight_up.squeeze(3).squeeze(2),
                    weight_down.squeeze(3).squeeze(2),
                )
                .unsqueeze(2)
                .unsqueeze(3)
            )
        else:
            delta = torch.mm(weight_up, weight_down)
        curr_layer.weight.data += multiplier * alpha * delta


def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    """Merge LoRA weights from ``.safetensors`` file(s) into ``pipeline`` in place.

    :param pipeline: diffusers pipeline exposing ``unet`` and ``text_encoder``.
    :param checkpoint_path: a single path (str) or an iterable of paths; the
        files are merged in order.
    :param multiplier: global strength applied to every LoRA delta.
    :param device: device passed to ``load_file`` when reading the tensors.
    :param dtype: dtype the LoRA up/down matrices are cast to before merging.
    :return: the same ``pipeline`` object, for chaining.
    :raises ValueError: if a LoRA key cannot be mapped to a module.
    """
    print('device: {}'.format(device))
    checkpoint_paths = (
        [checkpoint_path] if isinstance(checkpoint_path, str) else checkpoint_path
    )
    for ckptpath in checkpoint_paths:
        state_dict = load_file(ckptpath, device=device)
        _merge_lora_state_dict(pipeline, state_dict, multiplier, dtype)
    return pipeline


def make_inpaint_condition(image, image_mask):
    """Build the inpainting-ControlNet conditioning tensor.

    Pixel values are scaled by 1/255, and every position where ``image_mask``
    is greater than 128 is set to -1.0 (the "masked" marker).

    :param image: HxWxC numpy array with 0-255 values; it is not modified.
    :param image_mask: numpy array whose first two dimensions equal the
        image's (HxW or HxWxC); values > 128 mark the region to inpaint.
    :return: tensor of shape (1, C, H, W) (float64, as produced by numpy).
    :raises ValueError: if the image and mask heights/widths differ.
    """
    # Compare both height and width; slicing [0:1] only checked the height.
    if image.shape[:2] != image_mask.shape[:2]:
        raise ValueError("image and image_mask must have the same image size")
    image = image / 255.0  # new array, so the caller's image is untouched
    image[image_mask > 128] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(image)


def obtain_generation_model(
    base_model_path,
    lora_model_path,
    controlnet_path,
    generation_only=False,
    extra_inpaint=True,
    lora_weight=1.0,
):
    """Build the SAM-ControlNet diffusion pipeline.

    :param base_model_path: base Stable Diffusion checkpoint (path or Hub id).
    :param lora_model_path: optional LoRA ``.safetensors`` file merged into the
        pipeline (None to skip).
    :param controlnet_path: SAM-conditioned ControlNet checkpoint.
    :param generation_only: build a plain ControlNet text-to-image pipeline
        (when ``extra_inpaint`` is also True) instead of an inpainting one.
    :param extra_inpaint: when not ``generation_only``, add the
        ``lllyasviel/control_v11p_sd15_inpaint`` ControlNet as a second condition.
    :param lora_weight: LoRA merge multiplier.
    :return: the pipeline with UniPC scheduler and model CPU offload enabled.
    """
    controlnet = [
        ControlNetModel2.from_pretrained(controlnet_path, torch_dtype=torch.float16)
    ]  # sam control
    if (not generation_only) and extra_inpaint:  # inpainting control
        print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
        controlnet.append(
            ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
            )  # inpainting controlnet
        )

    pipeline_cls = (
        StableDiffusionControlNetPipeline
        if generation_only and extra_inpaint
        else StableDiffusionControlNetInpaintPipeline
    )
    pipe = pipeline_cls.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    if lora_model_path is not None:
        pipe = load_lora_weights(
            pipe, [lora_model_path], lora_weight, "cpu", torch.float32
        )
    # speed up diffusion process with faster scheduler and memory optimization
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # xformers is optional: diffusers raises ModuleNotFoundError when it is not
    # installed and ValueError when CUDA is unavailable; run without it then.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError) as err:
        print(f"xformers attention unavailable, continuing without it: {err}")

    pipe.enable_model_cpu_offload()
    return pipe


def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
    """Build the tile-ControlNet inpainting pipeline used for refinement.

    :param base_model_path: base Stable Diffusion checkpoint (path or Hub id).
        ``"runwayml/stable-diffusion-v1-5"`` and
        ``"stabilityai/stable-diffusion-2-inpainting"`` are both loaded as
        ``"runwayml/stable-diffusion-v1-5"``.
    :param lora_model_path: optional LoRA ``.safetensors`` file merged into the
        pipeline (None to skip).
    :param lora_weight: LoRA merge multiplier.
    :return: a ``StableDiffusionControlNetInpaintPipeline`` with UniPC
        scheduler and model CPU offload enabled.
    """
    controlnet = ControlNetModel2.from_pretrained(
        "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
    )  # tile controlnet
    tile_base_path = base_model_path
    if base_model_path in (
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-inpainting",
    ):
        print("base_model_path", base_model_path)
        # The tile ControlNet repo id names an SD1.5 model; presumably that is
        # why these bases are swapped for plain SD1.5 here.
        tile_base_path = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        tile_base_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    if lora_model_path is not None:
        pipe = load_lora_weights(
            pipe, [lora_model_path], lora_weight, "cpu", torch.float32
        )
    # speed up diffusion process with faster scheduler and memory optimization
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # xformers is optional: diffusers raises ModuleNotFoundError when it is not
    # installed and ValueError when CUDA is unavailable; run without it then.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError) as err:
        print(f"xformers attention unavailable, continuing without it: {err}")

    pipe.enable_model_cpu_offload()
    return pipe


def show_anns(anns):
    """Render SAM automatic-mask annotations as a color image and a label map.

    Masks are painted from largest to smallest ``area``, so smaller masks that
    overlap larger ones remain visible; labels are 1-based in that order.

    :param anns: list of dicts with at least ``"area"`` and ``"segmentation"``
        (a 2-D array; nonzero entries belong to the mask).
    :return: ``(full_img, res)`` — ``full_img`` is a PIL image with a random
        color per mask; ``res`` is an HxWx3 float32 array encoding the label
        map (channel 0 = label % 256, channel 1 = label // 256, channel 2 = 0,
        encoding from https://github.com/LUSSeg/ImageNet-S). Returns None when
        ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann["area"], reverse=True)
    height, width = sorted_anns[0]["segmentation"].shape[:2]
    full_img = np.zeros((height, width, 3))
    label_map = np.zeros((height, width), dtype=np.uint16)

    # Iterate the *sorted* list; indexing the unsorted input made the sort dead
    # and let large masks overwrite small ones.
    for i, ann in enumerate(sorted_anns):
        m = ann["segmentation"]
        label_map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255

    res = np.zeros((height, width, 3), dtype=np.float32)
    res[:, :, 0] = label_map % 256
    res[:, :, 1] = label_map // 256
    full_img = Image.fromarray(np.uint8(full_img))
    return full_img, res


class EditAnythingLoraModel:
    """Edit-Anything pipeline: SAM-conditioned ControlNet inpainting with LoRA.

    Holds the generation pipeline, the SAM mask generator/predictor, an optional
    BLIP2 captioner and a tile-ControlNet pipeline used for refinement.
    """

    def __init__(
        self,
        base_model_path="../chilloutmix_NiPrunedFp32Fix",
        lora_model_path="../40806/mix4",
        use_blip=True,
        blip_processor=None,
        blip_model=None,
        sam_generator=None,
        controlmodel_name="LAION Pretrained(v0-4)-SD15",
        # used when the base model is not an inpainting model.
        extra_inpaint=True,
        tile_model=None,
        lora_weight=1.0,
        alpha_mixing=None,
        mask_predictor=None,
    ):
        """Load (or reuse) every model component.

        Pre-built components (``blip_processor``, ``blip_model``,
        ``sam_generator``, ``mask_predictor``, ``tile_model``) are reused when
        given; otherwise they are created. ``controlmodel_name`` must be a key
        of ``config_dict``. ``alpha_mixing`` is accepted but unused here.
        """
        self.device = device
        self.use_blip = use_blip

        # Diffusion init using diffusers.
        self.default_controlnet_path = config_dict[controlmodel_name]
        self.base_model_path = base_model_path
        self.lora_model_path = lora_model_path
        # Remembered so pipelines rebuilt in ``process`` keep the same LoRA strength.
        self.lora_weight = lora_weight
        # ``generation_only`` mode the current ``self.pipe`` was built with.
        self.defalut_enable_all_generate = False
        self.extra_inpaint = extra_inpaint
        # Set after a reference-guided call; ``process`` rebuilds the pipe next time.
        self.last_ref_infer = False
        self.pipe = self._build_pipe(self.default_controlnet_path, False)
        # self.pipe.load_textual_inversion("textual_inversion_cat/learned_embeds.bin")
        # Segment-Anything init.
        self.sam_generator, self.mask_predictor = init_sam_model(
            sam_generator, mask_predictor
        )
        # BLIP2 init.
        if use_blip:
            if blip_processor is not None:
                self.blip_processor = blip_processor
            else:
                self.blip_processor = init_blip_processor()

            if blip_model is not None:
                self.blip_model = blip_model
            else:
                self.blip_model = init_blip_model()

        # tile model init.
        if tile_model is not None:
            self.tile_pipe = tile_model
        else:
            self.tile_pipe = obtain_tile_model(
                base_model_path, lora_model_path, lora_weight=lora_weight
            )

    def _build_pipe(self, controlnet_path, generation_only):
        """Create a generation pipeline using this instance's base/LoRA settings."""
        return obtain_generation_model(
            self.base_model_path,
            self.lora_model_path,
            controlnet_path,
            generation_only=generation_only,
            extra_inpaint=self.extra_inpaint,
            lora_weight=self.lora_weight,
        )

    def get_blip2_text(self, image):
        """Caption ``image`` with BLIP2 (up to 50 new tokens) and return the text."""
        inputs = self.blip_processor(image, return_tensors="pt").to(
            self.device, torch.float16
        )
        generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
        generated_text = self.blip_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return generated_text

    def get_sam_control(self, image):
        """Run SAM automatic segmentation and return ``show_anns``' (image, label map)."""
        masks = self.sam_generator.generate(image)
        full_img, res = show_anns(masks)
        return full_img, res

    def get_click_mask(self, image, clicked_points):
        """Predict a single SAM mask from ``(x, y, label)`` click tuples on ``image``."""
        self.mask_predictor.set_image(image)
        # Separate the points and labels
        points, labels = zip(*[(point[:2], point[2])
                             for point in clicked_points])

        # Convert the points and labels to numpy arrays
        input_point = np.array(points)
        input_label = np.array(labels)

        masks, _, _ = self.mask_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )

        return masks

    @torch.inference_mode()
    def process_image_click(
        self,
        original_image: gr.Image,
        point_prompt: gr.Radio,
        clicked_points: gr.State,
        image_resolution,
        evt: gr.SelectData,
    ):
        """Record a click and return ``(overlay, clicked_points, mask)`` for the UI.

        ``clicked_points`` is appended to in place with ``(x, y, label)``, where
        label is 1 for "Foreground Point" and 0 otherwise.
        """
        # Get the clicked coordinates
        x, y = evt.index
        lab = 1 if point_prompt == "Foreground Point" else 0
        clicked_points.append((x, y, lab))

        # HWC3 before reading the shape so 2-D (grayscale) inputs also work.
        input_image = HWC3(np.array(original_image, dtype=np.uint8))
        H, W = input_image.shape[:2]
        img = resize_image(input_image, image_resolution)

        # Update the clicked_points
        resized_points = resize_points(
            clicked_points, input_image.shape, image_resolution
        )
        mask_click_np = self.get_click_mask(img, resized_points)

        # Convert mask_click_np to HWC format
        mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0

        mask_image = HWC3(mask_click_np.astype(np.uint8))
        mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)

        # Draw circles for all clicked points
        edited_image = input_image
        for x, y, lab in clicked_points:
            # Set the circle color based on the label
            color = (255, 0, 0) if lab == 1 else (0, 0, 255)

            # Draw the circle
            edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)

        # Set the opacity for the mask_image and edited_image
        opacity_mask = 0.75
        opacity_edited = 1.0

        # Combine the edited_image and the mask_image using cv2.addWeighted()
        overlay_image = cv2.addWeighted(
            edited_image,
            opacity_edited,
            (mask_image *
             np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8),
            opacity_mask,
            0,
        )

        return (
            Image.fromarray(overlay_image),
            clicked_points,
            Image.fromarray(mask_image),
        )

    @torch.inference_mode()
    def process(
        self,
        source_image,
        enable_all_generate,
        mask_image,
        control_scale,
        enable_auto_prompt,
        a_prompt,
        n_prompt,
        num_samples,
        image_resolution,
        detect_resolution,
        ddim_steps,
        guess_mode,
        scale,
        seed,
        eta,
        enable_tile=True,
        refine_alignment_ratio=None,
        refine_image_resolution=None,
        alpha_weight=0.5,
        use_scale_map=False,
        condition_model=None,
        ref_image=None,
        attention_auto_machine_weight=1.0,
        gn_auto_machine_weight=1.0,
        style_fidelity=0.5,
        reference_attn=True,
        reference_adain=True,
        ref_prompt=None,
        ref_sam_scale=None,
        ref_inpaint_scale=None,
        ref_auto_prompt=False,
        ref_textinv=True,
        ref_textinv_path=None,
    ):
        """Generate edited images for ``source_image``.

        :return: ``(results_tile, results, [full_segmask, mask_image], prompt)``
            where ``results_tile`` is empty unless ``enable_tile``; returns None
            when a reference image is used together with all-generate mode.
        """
        if condition_model is None or condition_model == "EditAnything":
            this_controlnet_path = self.default_controlnet_path
        else:
            this_controlnet_path = condition_model
        input_image = (
            source_image["image"]
            if isinstance(source_image, dict)
            else np.array(source_image, dtype=np.uint8)
        )

        # Decide whether self.pipe must be rebuilt, then rebuild at most once
        # (previously up to three identical rebuilds could happen in a row).
        rebuild_pipe = False
        if mask_image is None:
            if enable_all_generate != self.defalut_enable_all_generate:
                rebuild_pipe = True
            if enable_all_generate:
                print(
                    "source_image",
                    source_image["mask"].shape,
                    input_image.shape,
                )
                mask_image = (
                    np.ones((input_image.shape[0],
                            input_image.shape[1], 3)) * 255
                )
            else:
                mask_image = source_image["mask"]
        else:
            mask_image = np.array(mask_image, dtype=np.uint8)
        if self.default_controlnet_path != this_controlnet_path:
            print(
                "To Use:",
                this_controlnet_path,
                "Current:",
                self.default_controlnet_path,
            )
            print("Change condition model to:", this_controlnet_path)
            rebuild_pipe = True
        if self.last_ref_infer:
            print("Redefine the model to overwrite the ref mode")
            rebuild_pipe = True
        if rebuild_pipe:
            self.pipe = self._build_pipe(this_controlnet_path, enable_all_generate)
            # Keep the bookkeeping in sync with the pipe that now exists.
            self.defalut_enable_all_generate = enable_all_generate
            self.default_controlnet_path = this_controlnet_path
            self.last_ref_infer = False
            torch.cuda.empty_cache()

        if ref_image is not None:
            ref_mask = ref_image["mask"]
            ref_image = ref_image["image"]
            if ref_auto_prompt or ref_textinv:
                bbox = get_bounding_box(
                    np.array(ref_mask) / 255
                )  # reverse the mask to make 1 the choosen region
                crop_box = (bbox[0], bbox[1], bbox[2], bbox[3])
                cropped_ref_mask = ref_mask.crop(crop_box)
                cropped_ref_image = ref_image.crop(crop_box)
                cropped_ref_image = np.array(cropped_ref_image) * (
                    np.array(cropped_ref_mask)[:, :, :3] / 255.0
                )
                cropped_ref_image = Image.fromarray(
                    cropped_ref_image.astype("uint8"))

            if ref_auto_prompt:
                generated_prompt = self.get_blip2_text(cropped_ref_image)
                # ref_prompt defaults to None; treat that as an empty prompt.
                ref_prompt = (ref_prompt or "") + generated_prompt
                a_prompt += generated_prompt
            print("Generated ref text:", ref_prompt)
            print("Generated input text:", a_prompt)
            self.last_ref_infer = True
            if ref_textinv:
                try:
                    self.pipe.load_textual_inversion(ref_textinv_path)
                    print("Load textinv embedding from:", ref_textinv_path)
                except Exception:
                    print("No textinvert embeddings found.")
                    ref_data_path = "./utils/tmp/textinv/img"
                    os.makedirs(ref_data_path, exist_ok=True)
                    cropped_ref_image.save(os.path.join(ref_data_path, 'ref.png'))
                    print("Ref image region is save to:", ref_data_path)
                    print("Plese finetune with run_texutal_inversion.sh in utils folder to get the textinvert embeddings.")

        else:
            ref_mask = None

        with torch.no_grad():
            if self.use_blip and enable_auto_prompt:
                print("Generating text:")
                blip2_prompt = self.get_blip2_text(input_image)
                print("Generated text:", blip2_prompt)
                if len(a_prompt) > 0:
                    a_prompt = blip2_prompt + "," + a_prompt
                else:
                    a_prompt = blip2_prompt

            input_image = HWC3(input_image)

            img = resize_image(input_image, image_resolution)
            H, W, C = img.shape

            print("Generating SAM seg:")
            # the default SAM model is trained with 1024 size.
            full_segmask, detected_map = self.get_sam_control(
                resize_image(input_image, detect_resolution)
            )

            detected_map = HWC3(detected_map.astype(np.uint8))
            detected_map = cv2.resize(
                detected_map, (W, H), interpolation=cv2.INTER_LINEAR
            )

            control = torch.from_numpy(detected_map.copy()).float().cuda()
            control = torch.stack([control for _ in range(num_samples)], dim=0)
            control = einops.rearrange(control, "b h w c -> b c h w").clone()

            mask_imag_ori = HWC3(mask_image.astype(np.uint8))
            mask_image_tmp = cv2.resize(
                mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR
            )
            mask_image = Image.fromarray(mask_image_tmp)

            if seed == -1:
                seed = random.randint(0, 65535)
            seed_everything(seed)
            generator = torch.manual_seed(seed)
            postive_prompt = a_prompt
            negative_prompt = n_prompt
            prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
                self.pipe, postive_prompt, negative_prompt, "cuda"
            )
            prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds] * num_samples, dim=0
            )

            # Only the scale-map branch below sets a real map; default to None
            # so the tile refinement never sees an undefined name.
            controlnet_conditioning_scale_map = None
            if enable_all_generate and self.extra_inpaint:
                self.pipe.safety_checker = lambda images, clip_input: (
                    images, False)
                if ref_image is not None:
                    print("Not support yet.")
                    return
                x_samples = self.pipe(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    num_images_per_prompt=num_samples,
                    num_inference_steps=ddim_steps,
                    generator=generator,
                    height=H,
                    width=W,
                    image=[control.type(torch.float16)],
                    controlnet_conditioning_scale=[float(control_scale)],
                    guidance_scale=scale,
                    guess_mode=guess_mode,
                ).images
            else:
                multi_condition_image = [control.type(torch.float16)]
                multi_condition_scale = [float(control_scale)]
                ref_multi_condition_scale = []
                if ref_image is not None:
                    ref_multi_condition_scale.append(float(ref_sam_scale))
                if self.extra_inpaint:
                    inpaint_image = make_inpaint_condition(img, mask_image_tmp)
                    multi_condition_image.append(
                        inpaint_image.type(torch.float16))
                    multi_condition_scale.append(1.0)
                    if ref_image is not None:
                        ref_multi_condition_scale.append(
                            float(ref_inpaint_scale))
                if use_scale_map:
                    scale_map_tmp = source_image["mask"]
                    tmp = HWC3(scale_map_tmp.astype(np.uint8))
                    scale_map_tmp = cv2.resize(
                        tmp, (W, H), interpolation=cv2.INTER_LINEAR)
                    scale_map_tmp = Image.fromarray(scale_map_tmp)
                    controlnet_conditioning_scale_map = 1.0 - \
                        prepare_mask_image(scale_map_tmp).float()
                    print('scale map:', controlnet_conditioning_scale_map.size())

                if isinstance(self.pipe, StableDiffusionControlNetInpaintMixingPipeline):
                    x_samples = self.pipe(
                        image=img,
                        mask_image=mask_image,
                        prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
                        num_images_per_prompt=num_samples,
                        num_inference_steps=ddim_steps,
                        generator=generator,
                        controlnet_conditioning_image=multi_condition_image,
                        height=H,
                        width=W,
                        controlnet_conditioning_scale=multi_condition_scale,
                        guidance_scale=scale,
                        alpha_weight=alpha_weight,
                        controlnet_conditioning_scale_map=controlnet_conditioning_scale_map
                    ).images
                else:
                    x_samples = self.pipe(
                        image=img,
                        mask_image=mask_image,
                        prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=negative_prompt_embeds,
                        num_images_per_prompt=num_samples,
                        num_inference_steps=ddim_steps,
                        generator=generator,
                        controlnet_conditioning_image=multi_condition_image,
                        height=H,
                        width=W,
                        controlnet_conditioning_scale=multi_condition_scale,
                        guidance_scale=scale,
                        ref_image=ref_image,
                        ref_mask=ref_mask,
                        ref_prompt=ref_prompt,
                        attention_auto_machine_weight=attention_auto_machine_weight,
                        gn_auto_machine_weight=gn_auto_machine_weight,
                        style_fidelity=style_fidelity,
                        reference_attn=reference_attn,
                        reference_adain=reference_adain,
                        ref_controlnet_conditioning_scale=ref_multi_condition_scale,
                        guess_mode=guess_mode,
                    ).images
            results = [x_samples[i] for i in range(num_samples)]

            results_tile = []
            if enable_tile:
                prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
                    self.tile_pipe, postive_prompt, negative_prompt, "cuda"
                )
                # Dispatch on the pipeline actually being called below.
                tile_is_mixing = isinstance(
                    self.tile_pipe, StableDiffusionControlNetInpaintMixingPipeline
                )
                for i in range(num_samples):
                    img_tile = PIL.Image.fromarray(
                        resize_image(
                            np.array(x_samples[i]), refine_image_resolution)
                    )
                    if i == 0:
                        # All samples share one size, so the mask is resized once.
                        mask_image_tile = cv2.resize(
                            mask_imag_ori,
                            (img_tile.size[0], img_tile.size[1]),
                            interpolation=cv2.INTER_LINEAR,
                        )
                        mask_image_tile = Image.fromarray(mask_image_tile)
                    if tile_is_mixing:
                        x_samples_tile = self.tile_pipe(
                            image=img_tile,
                            mask_image=mask_image_tile,
                            prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
                            num_images_per_prompt=1,
                            num_inference_steps=ddim_steps,
                            generator=generator,
                            controlnet_conditioning_image=img_tile,
                            height=img_tile.size[1],
                            width=img_tile.size[0],
                            controlnet_conditioning_scale=1.0,
                            alignment_ratio=refine_alignment_ratio,
                            guidance_scale=scale,
                            alpha_weight=alpha_weight,
                            controlnet_conditioning_scale_map=controlnet_conditioning_scale_map
                        ).images
                    else:
                        x_samples_tile = self.tile_pipe(
                            image=img_tile,
                            mask_image=mask_image_tile,
                            prompt_embeds=prompt_embeds,
                            negative_prompt_embeds=negative_prompt_embeds,
                            num_images_per_prompt=1,
                            num_inference_steps=ddim_steps,
                            generator=generator,
                            controlnet_conditioning_image=img_tile,
                            height=img_tile.size[1],
                            width=img_tile.size[0],
                            controlnet_conditioning_scale=1.0,
                            alignment_ratio=refine_alignment_ratio,
                            guidance_scale=scale,
                            guess_mode=guess_mode,
                        ).images
                    results_tile += x_samples_tile

        return results_tile, results, [full_segmask, mask_image], postive_prompt

    @staticmethod
    def download_image(url):
        """Fetch ``url`` over HTTP and return the content as an RGB PIL image.

        A staticmethod so it works both as ``EditAnythingLoraModel.download_image(url)``
        and ``instance.download_image(url)`` (previously the instance form passed
        ``self`` as the URL).
        """
        response = requests.get(url)
        return Image.open(BytesIO(response.content)).convert("RGB")