# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List, Optional, Union
from collections import OrderedDict
import torch
from torch import nn

from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["CLIPTEXT", "build_text_encoder"]


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
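    """Approximate GELU activation (x * sigmoid(1.702 * x)) used in CLIP."""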
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
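    """Pre-LayerNorm transformer block: multi-head self-attention and a 4x-wide MLP, each wrapped in a residual connection."""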
    def __init__(self, d_model: int, n_head: int, attn_mask: Optional[torch.Tensor] = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
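    """A stack of ResidualAttentionBlocks sharing the same attention mask."""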
    def __init__(self, width: int, layers: int, heads: int, attn_mask: Optional[torch.Tensor] = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask)
              for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class CLIPTEXT(nn.Module):
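    """Standalone CLIP text encoder (token embedding, causal transformer, projection), without the visual tower."""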
    def __init__(self,
                 embed_dim=512,
                 # text
                 context_length=77,
                 vocab_size=49408,
                 transformer_width=512,
                 transformer_heads=8,
                 transformer_layers=12
                 ):
        super().__init__()
        
        self._tokenizer = _Tokenizer()
        self.context_length = context_length

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
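        # scaled-normal initialization, following the scheme used in OpenAI CLIP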
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # causal attention mask for the text transformer
        # pytorch uses an additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and below so each token attends to itself and earlier tokens
        return mask

    @property
    def device(self):
        return self.text_projection.device

    @property
    def dtype(self):
        return self.text_projection.dtype

    def tokenize(self,
                 texts: Union[str, List[str]],
                 context_length: int = 77) -> torch.LongTensor:
        """
        Tokenize a string or a list of strings into a LongTensor of shape
        [number of texts, context_length]. Texts longer than context_length
        are randomly cropped to fit rather than raising an error.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                st = torch.randint(
                    len(tokens) - context_length + 1, (1,))[0].item()
                tokens = tokens[st: st + context_length]
                # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def encode_text(self, text):
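        """Encode a [batch_size, n_ctx] tensor of token ids into [batch_size, embed_dim] text features."""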
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, captions):
        '''
        captions: list of strings
        '''
        text = self.tokenize(captions).to(self.device)  # B x L token ids
        features = self.encode_text(text) # B x D
        return features


def build_text_encoder(pretrain=True, visual_type="RN50"):
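    """Build a CLIPTEXT encoder whose hyperparameters match the text tower of the
    given CLIP visual backbone, optionally loading the pretrained CLIP text weights."""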
    clip_dict = {
        "visual_type": ["embed_dim", "context_length", "vocab_size",
                        "transformer_width", "transformer_heads", "transformer_layers"],
        "RN50":        [1024, 77, 49408, 512, 8, 12],
        "RN50x4":      [640, 77, 49408, 640, 10, 12],
        "RN50x16":     [768, 77, 49408, 768, 12, 12],
        "RN50x64":     [1024, 77, 49408, 1024, 16, 12],
    }
    if visual_type not in clip_dict:
        raise NotImplementedError(f"Unsupported visual_type: {visual_type}")
    text_encoder = CLIPTEXT(**{k: v for k, v in zip(clip_dict['visual_type'], clip_dict[visual_type])})
    if pretrain:
        import clip
        pretrained_model, _ = clip.load(visual_type, device='cpu')

        state_dict = pretrained_model.state_dict()
        to_delete_keys = (
            ["logit_scale", "input_resolution", "context_length", "vocab_size"]
            + [k for k in state_dict.keys() if k.startswith('visual.')]
        )
        for k in to_delete_keys:
            if k in state_dict:
                del state_dict[k]
        # print('Loading pretrained CLIP')
        text_encoder.load_state_dict(state_dict)
    return text_encoder
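

if __name__ == "__main__":
    # Minimal usage sketch (assumes the `clip` package and its pretrained RN50
    # weights are available; the captions below are placeholders).
    text_encoder = build_text_encoder(pretrain=True, visual_type="RN50")
    text_encoder.eval()
    with torch.no_grad():
        features = text_encoder(["a photo of a cat", "a photo of a dog"])
    print(features.shape)  # [2, embed_dim]; embed_dim is 1024 for RN50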