import sys
from pathlib import Path

import torch

from .. import DEVICE, MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# make the vendored GIM repo importable so the `networks.*` imports below resolve
gim_path = Path(__file__).parents[2] / "third_party/gim"
sys.path.append(str(gim_path))


def load_model(weight_name, checkpoints_path):
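    """Build the requested GIM matcher and load its checkpoint.

    `weight_name` selects the variant ("gim_dkm", "gim_loftr" or
    "gim_lightglue"); `checkpoints_path` points at the downloaded .ckpt file.
    For gim_lightglue a SuperPoint detector is built alongside the matcher.
    """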
    # load model
    model = None
    detector = None
    if weight_name == "gim_dkm":
        from networks.dkm.models.model_zoo.DKMv3 import DKMv3

        model = DKMv3(weights=None, h=672, w=896)
    elif weight_name == "gim_loftr":
        from networks.loftr.config import get_cfg_defaults
        from networks.loftr.loftr import LoFTR
        from networks.loftr.misc import lower_config

        model = LoFTR(lower_config(get_cfg_defaults())["loftr"])
    elif weight_name == "gim_lightglue":
        from networks.lightglue.models.matchers.lightglue import LightGlue
        from networks.lightglue.superpoint import SuperPoint

        detector = SuperPoint(
            {
                "max_num_keypoints": 2048,
                "force_num_keypoints": True,
                "detection_threshold": 0.0,
                "nms_radius": 3,
                "trainable": False,
            }
        )
        model = LightGlue(
            {
                "filter_threshold": 0.1,
                "flash": False,
                "checkpointed": True,
            }
        )

    # load state dict
    if weight_name == "gim_dkm":
        state_dict = torch.load(checkpoints_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # strip the "model." prefix first, then drop the unused ResNet fc head
        # in a second pass so that no key is popped twice
        for k in list(state_dict.keys()):
            if k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        for k in list(state_dict.keys()):
            if "encoder.net.fc" in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)

    elif weight_name == "gim_loftr":
        state_dict = torch.load(checkpoints_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        model.load_state_dict(state_dict)

    elif weight_name == "gim_lightglue":
        # the checkpoint holds both SuperPoint and LightGlue weights; split it
        # by key prefix (the loop below mutates the dict, so it is re-read
        # from disk for the second pass)
        state_dict = torch.load(checkpoints_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        for k in list(state_dict.keys()):
            if k.startswith("model."):
                state_dict.pop(k)
            if k.startswith("superpoint."):
                state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k)
        detector.load_state_dict(state_dict)

        state_dict = torch.load(checkpoints_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        for k in list(state_dict.keys()):
            if k.startswith("superpoint."):
                state_dict.pop(k)
            if k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        model.load_state_dict(state_dict)

    # eval mode
    if detector is not None:
        detector = detector.eval().to(DEVICE)
    model = model.eval().to(DEVICE)
    # NOTE: the SuperPoint detector built for gim_lightglue is not returned,
    # so only the gim_dkm path is wired end-to-end (see GIM._forward)
    return model


class GIM(BaseModel):
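    """GIM matcher wrapper (gim_dkm, gim_loftr or gim_lightglue weights).

    Only the gim_dkm variant is currently implemented in `_forward`.
    """
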
    default_conf = {
        "match_threshold": 0.2,
        "checkpoint_dir": gim_path / "weights",
        "weights": "gim_dkm",
        # read by _forward; 2000 is an assumed default cap on sampled matches
        "max_keypoints": 2000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    ckpt_name_dict = {
        "gim_dkm": "gim_dkm_100h.ckpt",
        "gim_loftr": "gim_loftr_50h.ckpt",
        "gim_lightglue": "gim_lightglue_100h.ckpt",
    }

    def _init(self, conf):
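        """Download the checkpoint for the configured weights and build the model."""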
        ckpt_name = self.ckpt_name_dict[conf["weights"]]
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, ckpt_name),
        )
        # target aspect ratio of the DKMv3 input resolution (w=896, h=672)
        self.aspect_ratio = 896 / 672
        model = load_model(conf["weights"], model_path)
        self.net = model
        logger.info("Loaded GIM model (%s)", conf["weights"])

    def pad_image(self, image, aspect_ratio):
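        """Zero-pad a (B, C, H, W) image symmetrically so that its width/height
        ratio matches `aspect_ratio` (up to rounding), never shrinking it."""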
        new_width = max(image.shape[3], int(image.shape[2] * aspect_ratio))
        new_height = max(image.shape[2], int(image.shape[3] / aspect_ratio))
        pad_width = new_width - image.shape[3]
        pad_height = new_height - image.shape[2]
        return torch.nn.functional.pad(
            image,
            (
                pad_width // 2,
                pad_width - pad_width // 2,
                pad_height // 2,
                pad_height - pad_height // 2,
            ),
        )

    def rescale_kpts(self, sparse_matches, shape0, shape1):
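        """Map normalized [-1, 1] match coordinates to pixel coordinates for
        images of shapes `shape0` and `shape1`, given as (H, W)."""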
        kpts0 = torch.stack(
            (
                shape0[1] * (sparse_matches[:, 0] + 1) / 2,
                shape0[0] * (sparse_matches[:, 1] + 1) / 2,
            ),
            dim=-1,
        )
        kpts1 = torch.stack(
            (
                shape1[1] * (sparse_matches[:, 2] + 1) / 2,
                shape1[0] * (sparse_matches[:, 3] + 1) / 2,
            ),
            dim=-1,
        )
        return kpts0, kpts1

    def compute_mask(self, kpts0, kpts1, orig_shape0, orig_shape1):
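        """Keep matches whose keypoints lie inside the original (unpadded)
        image bounds on both sides."""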
        mask = (
            (kpts0[:, 0] > 0)
            & (kpts0[:, 1] > 0)
            & (kpts1[:, 0] > 0)
            & (kpts1[:, 1] > 0)
        )
        mask &= (
            (kpts0[:, 0] <= (orig_shape0[1] - 1))
            & (kpts1[:, 0] <= (orig_shape1[1] - 1))
            & (kpts0[:, 1] <= (orig_shape0[0] - 1))
            & (kpts1[:, 1] <= (orig_shape1[0] - 1))
        )
        return mask

    def _forward(self, data):
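        """Densely match the padded image pair, sample sparse matches, and
        return keypoints filtered to the original image bounds."""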
        # TODO: only the gim_dkm pipeline is supported here
        image0, image1 = (
            self.pad_image(data["image0"], self.aspect_ratio),
            self.pad_image(data["image1"], self.aspect_ratio),
        )
        dense_matches, dense_certainty = self.net.match(image0, image1)
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, self.conf["max_keypoints"]
        )
        kpts0, kpts1 = self.rescale_kpts(
            sparse_matches, image0.shape[-2:], image1.shape[-2:]
        )
        mask = self.compute_mask(
            kpts0, kpts1, data["image0"].shape[-2:], data["image1"].shape[-2:]
        )
        # keep only matches with non-zero confidence that also fall inside the
        # original (unpadded) images
        valid = mask & (mconf > 0)
        out = {
            "keypoints0": kpts0[valid],
            "keypoints1": kpts1[valid],
        }
        return out
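

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the hloc-style BaseModel contract:
    # constructing with a conf dict runs _init, and calling the module with a
    # dict of the required inputs runs _forward. Image shapes are arbitrary.
    matcher = GIM({"weights": "gim_dkm", "max_keypoints": 2000})
    data = {
        "image0": torch.rand(1, 3, 480, 640, device=DEVICE),
        "image1": torch.rand(1, 3, 480, 640, device=DEVICE),
    }
    pred = matcher(data)
    print(pred["keypoints0"].shape, pred["keypoints1"].shape)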