from packaging.version import Version
from types import MethodType
from typing import List, Optional, Tuple, Union
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from timm.models.registry import register_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from .forward_intermediates import forward_intermediates
from .input_conditioner import InputConditioner

_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')


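# Wraps the HF PaliGemma (SigLIP) vision tower so it exposes the interface the
# rest of this package expects: `patch_size`, `blocks`, `embed_dim`, and a
# forward() that returns a (summary, features) pair.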
class PaliGemmaWrapper(nn.Module):
    def __init__(self, vis_model: nn.Module, embed_dim: int):
        super().__init__()

        self.vis_model = vis_model
        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        return self.vis_model.encoder.layers

    def forward(self, x: torch.Tensor):
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )

        # Patch-token features from the vision tower, upcast to float32.
        features = outputs[0].to(torch.float32)

        # SigLIP has no CLS token, so mean-pool the patch tokens for the summary.
        summary = features.mean(dim=1)

        return summary, features

    def forward_features(self, x: torch.Tensor):
        return self(x)


def _get_paligemma_model(repo: str, embed_dim: Optional[int] = None, dtype: Optional[torch.dtype] = torch.bfloat16):
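    """Load the PaliGemma vision tower from the given HF hub repo and wrap it
    in :class:`PaliGemmaWrapper`.

    ``dtype`` selects both ``torch_dtype`` and the hub revision to download;
    pass ``None`` to use the repo defaults.
    """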
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version

    if Version(tx_version) > Version('4.44.2'):
        warnings.warn(f'Your transformers version "{tx_version}" is newer than 4.44.2; PaliGemma support may be broken.')

    extra_args = dict()

    if dtype is not None:
        extra_args['torch_dtype'] = dtype
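        # Select the hub revision named after the dtype; the PaliGemma repos
        # publish per-dtype revisions (e.g. "bfloat16").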
        rev = str(dtype).split('.')[-1]
        extra_args['revision'] = rev

    model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)

    vis_model = model.vision_tower.vision_model

    vis_model = PaliGemmaWrapper(vis_model, embed_dim)

    return vis_model


@register_model
def paligemma_896_student(**kwargs):
    model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)

    return model


def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
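    """SDPA-based replacement for the DINOv2 attention module's forward.

    Bound onto every ``.attn`` module in ``_load_dino_v2`` so attention runs
    through ``F.scaled_dot_product_attention`` when it is available.
    """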
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

    q, k, v = qkv[0], qkv[1], qkv[2]
    x = F.scaled_dot_product_attention(
        q, k, v,
        is_causal=False,
        dropout_p=self.attn_drop.p if self.training else 0.,
        scale=self.scale,
    )
    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

def _load_dino_v2(dino_v2_model: str, cache_dir: Optional[str] = None, pretrained: bool = True, **kwargs):
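    """Load a DINOv2 model from torch.hub, optionally into a custom cache dir,
    and patch its attention modules to use the SDPA forward above."""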
    if cache_dir:
        torch.hub.set_dir(cache_dir)
    model: nn.Module = torch.hub.load(
        'facebookresearch/dinov2',
        dino_v2_model,
        pretrained=pretrained,
        # **kwargs,
    )

    if _has_torch_sdpa:
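        # Rebind the forward of every attention module so it uses PyTorch's
        # fused scaled_dot_product_attention.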
        for n, m in model.named_modules():
            if n.endswith('.attn'):
                m.forward = MethodType(dv2_sdpa, m)

    return model

class DinoWrapper(nn.Module):
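    """Adapter that gives a DINOv2 ViT the interface expected by this package:
    ``embed_dim``/``patch_size``/token-count properties and a ``forward`` that
    returns ``(cls_token, patch_tokens)``.
    """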
    def __init__(self, dino_model: nn.Module):
        super().__init__()

        self.inner = dino_model
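        # DINOv2 keeps its blocks in an nn.ModuleList; re-wrap them as an
        # nn.Sequential so `forward_features` below can call them directly.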
        dino_model.blocks = nn.Sequential(*dino_model.blocks)

    @property
    def embed_dim(self):
        return self.inner.embed_dim

    @property
    def patch_size(self):
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        return self.num_cls_tokens + self.num_registers

    @property
    def blocks(self):
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        parts = self.inner.forward_features(*args, **kwargs)

        cls_token = parts['x_norm_clstoken']
        features = parts['x_norm_patchtokens']

        return cls_token, features

    def forward_features(self, x: torch.Tensor):
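        # Patchify, run the transformer blocks, and normalize, then split the
        # result into the CLS token and the patch tokens that follow the
        # summary (CLS + register) tokens.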
        x = self.inner.prepare_tokens_with_masks(x)
        x = self.inner.blocks(x)
        x_norm = self.inner.norm(x)

        return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner.prepare_tokens_with_masks(x)

    def forward_intermediates(self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
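        """Proxy to the shared ``forward_intermediates`` helper, supplying the
        DINOv2 patch extractor and the summary-token layout of this model."""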
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=self.inner.norm if norm else lambda y: y,
            x=x,
            **kwargs,
        )


def _dino_student(arch: str, **kwargs):
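    """Instantiate a DINOv2 architecture from ``dinov2_arch``, wrap it in
    ``DinoWrapper``, and attach an ImageNet-normalizing ``InputConditioner``."""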
    from . import dinov2_arch

    factory = getattr(dinov2_arch, arch)
    model = factory()

    model = DinoWrapper(model)

    conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=IMAGENET_DEFAULT_MEAN,
        norm_std=IMAGENET_DEFAULT_STD,
    )

    model.input_conditioner = conditioner

    return model


@register_model
def dino_v2_l_student(**kwargs):
    return _dino_student('dinov2_vitl14_reg', **kwargs)


@register_model
def dino_v2_g_student(**kwargs):
    return _dino_student('dinov2_vitg14_reg', **kwargs)