# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the NCSNPP checkpoints. """

import argparse
import json

import torch

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
    """
    Takes a state dict and the path to
    """
    new_model_architecture = UNet2DModel(**config)
    new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
    new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data
    new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data

    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
        new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
        new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T

        new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
        new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
        new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data

        new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T
        new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
        new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
        new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

        new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data
        new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
        new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
        new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data

        new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
        new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data

        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data

    for i, block in enumerate(new_model_architecture.downsample_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if hasattr(block, "downsamplers") and block.downsamplers is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
    module_index += 1

    for i, block in enumerate(new_model_architecture.up_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
        if has_attentions:
            set_attention_weights(
                block.attentions[0], checkpoint, module_index
            )  # why can there only be a single attention layer for up?
            module_index += 1

        if hasattr(block, "resnet_up") and block.resnet_up is not None:
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the checkpoint to convert.",
    )

    parser.add_argument(
        "--config_file",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument(
        "--dump_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
        type=str,
        required=False,
        help="Path to the output model.",
    )

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    with open(args.config_file) as f:
        config = json.loads(f.read())

    converted_checkpoint = convert_ncsnpp_checkpoint(
        checkpoint,
        config,
    )

    if "sde" in config:
        del config["sde"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)
    except:  # noqa: E722
        model.save_pretrained(args.dump_path)