#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

"""Implements HF OpenELMConfig based on PretrainedConfig"""
from numbers import Number
from typing import List, Optional, Union

import numpy as np
from transformers import PretrainedConfig


def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by the divisor
    It can be seen at:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62

    Args:
        v: input value
        divisor: default to 8
        min_value: minimum divisor value
    Returns:
        new_v: new divisible value
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
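
# Illustrative examples (values follow directly from the rounding rule above):
#   make_divisible(100, divisor=8)   -> 104  # int(100 + 4) // 8 * 8 = 104
#   make_divisible(50, divisor=8)    -> 48   # 48 >= 0.9 * 50, so no +divisor correction is needed
#   make_divisible(640, divisor=256) -> 768  # e.g. a scaled QKV dim rounded to head_dim * num_gqa_groups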


def compute_heads(model_dim: int, head_dim: int) -> int:
    """Compute the number of heads.

    Args:
        model_dim: Model dimension.
        head_dim: Head dimension.

    Returns:
        An integer denoting the number of heads in multi-head attention.

    Raises:
        ValueError: if model dimension is not divisible by head dimension.
    """
    if model_dim % head_dim == 0:
        return model_dim // head_dim
    else:
        raise ValueError(
            f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
        )
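
# Illustrative examples:
#   compute_heads(model_dim=1280, head_dim=64) -> 20
#   compute_heads(model_dim=1280, head_dim=48) -> raises ValueError, since 1280 % 48 != 0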


OpenELM_CONFIGS = {
    "OpenELM-270M": dict(
        num_transformer_layers=16,
        model_dim=1280,
        head_dim=64,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
    "OpenELM-450M": dict(
        num_transformer_layers=20,
        model_dim=1536,
        head_dim=64,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
    "OpenELM-1_1B": dict(
        num_transformer_layers=28,
        model_dim=2048,
        head_dim=64,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
    "OpenELM-3B": dict(
        num_transformer_layers=36,
        model_dim=3072,
        head_dim=128,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
}
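
# The (low, high) multiplier tuples above are expanded into per-layer values by
# OpenELMConfig.__post_init__ via np.linspace; e.g. np.linspace(0.5, 1.0, num=4) rounded to
# two decimals gives [0.5, 0.67, 0.83, 1.0]. A preset can be unpacked directly, e.g.
# OpenELMConfig(**OpenELM_CONFIGS["OpenELM-270M"]); see the illustrative __main__ sketch at
# the end of this file.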


class OpenELMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the OpenELM model.
        max_context_length (`int`, *optional*, defaults to 2048):
            Maximum number of input tokens.
        num_transformer_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer decoder.
        model_dim (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
            If qkv_multipliers is a Number, then all attention layers have the same latent dimension,
            resulting in uniform allocation of parameters.
            If qkv_multipliers is a List of Numbers, then each attention layer has a different latent dimension
            (assuming qkv_multipliers[0] != qkv_multipliers[1]), resulting in variable allocation of parameters across attention layers.
            This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
        num_query_heads (`Union[int, None]`, *optional*, defaults to None):
            The number of query heads. If `None`, it is computed as `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
        num_gqa_groups (`int`, *optional*, defaults to 1):
            This variable allows switching between multi-head attention, group query attention, and multi-query attention.
            When num_gqa_groups == 1, it is multi-head attention.
            When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, it is group query attention.
            When num_gqa_groups == num_heads, it is multi-query attention.
        ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
            Feed-forward network (FFN) multipliers.
            If ffn_multipliers is a Number, then all FFN layers have the same latent dimension,
            resulting in uniform allocation of parameters.
            If ffn_multipliers is a List of Numbers, then each FFN layer has a different latent dimension
            (assuming ffn_multipliers[0] != ffn_multipliers[1]), resulting in variable allocation of parameters across FFN layers.
            This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
        ffn_with_glu (`bool`, *optional*, defaults to True):
            Whether to use an FFN with a Gated Linear Unit (GLU).
        ffn_dim_divisor (`int`, *optional*, defaults to 256):
            The FFN layer dimension divisor.
        activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
            The non-linear activation function (function or string) in the decoder.
        normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
            Type of normalization layer.
        normalize_qk_projections (`bool`, *optional*, defaults to False):
            Whether to normalize queries and keys after the projections.
        share_input_output_layers (`bool`, *optional*, defaults to False):
            Whether to share the embedding between the input and output linear layers.
        rope_freq_constant (`int`, *optional*, defaults to 10000):
            The base period of the RoPE embeddings.
        rope_max_length (`int`, *optional*, defaults to 4096):
            The rope_max_length is set to twice the max_context_length.
            This allows flexibility in token lengths during training or fine-tuning.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
    """

    model_type = "openelm"

    def __init__(
        self,
        vocab_size: int = 32000,
        max_context_length: int = 2048,
        num_transformer_layers: int = 12,
        model_dim: int = 2048,
        head_dim: int = 128,
        qkv_multipliers: Union[Number, List[Number]] = 1.0,
        num_query_heads: Union[int, None] = None,
        num_gqa_groups: int = 1,
        ffn_multipliers: Union[Number, List[Number]] = 4.0,
        ffn_with_glu: bool = True,
        ffn_dim_divisor: int = 256,
        activation_fn_name: str = "swish",
        normalization_layer_name: str = "rms_norm",
        normalize_qk_projections: bool = False,
        share_input_output_layers: bool = False,
        rope_freq_constant: int = 10000,
        rope_max_length: int = 4096,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        self.max_context_length = max_context_length
        self.num_transformer_layers = num_transformer_layers
        self.model_dim = model_dim
        self.head_dim = head_dim
        self.qkv_multipliers = qkv_multipliers
        self.num_query_heads = (
            compute_heads(model_dim=model_dim, head_dim=head_dim)
            if num_query_heads is None
            else num_query_heads
        )
        self.num_gqa_groups = num_gqa_groups
        self.ffn_multipliers = ffn_multipliers
        self.ffn_with_glu = ffn_with_glu
        self.ffn_dim_divisor = ffn_dim_divisor
        self.activation_fn_name = activation_fn_name
        self.normalization_layer_name = normalization_layer_name
        self.normalize_qk_projections = normalize_qk_projections
        self.share_input_output_layers = share_input_output_layers
        self.rope_freq_constant = rope_freq_constant
        self.rope_max_length = rope_max_length
        self.initializer_range = initializer_range

        self.__post_init__()
        super().__init__(
            use_cache=use_cache,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    def __post_init__(self) -> None:
        if self.num_gqa_groups is not None:
            head_multiple_of = self.num_gqa_groups
        else:
            head_multiple_of = 2

        if isinstance(self.qkv_multipliers, Number):
            # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
            qkv_dim = make_divisible(
                self.model_dim * self.qkv_multipliers,
                divisor=self.head_dim * head_multiple_of,
            )
            query_dims = [int(qkv_dim)] * self.num_transformer_layers

        elif (
            isinstance(self.qkv_multipliers, (tuple, list))
            and len(self.qkv_multipliers) == 2
        ):
            # Each attention layer has a different latent dimension, assuming qkv_multipliers[0] != qkv_multipliers[1].
            # This results in variable allocation of parameters across attention layers.
            # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
            qkv_multipliers = [
                round(v, 2)
                for v in np.linspace(
                    self.qkv_multipliers[0],
                    self.qkv_multipliers[1],
                    num=self.num_transformer_layers,
                    dtype=float,
                )
            ]
            # Make sure that scaled model dimension is divisible by scaled head dimension.
            query_dims = [
                int(
                    make_divisible(
                        self.model_dim * m, divisor=self.head_dim * head_multiple_of
                    )
                )
                for m in qkv_multipliers
            ]
        else:
            raise NotImplementedError(
                f"QKV multipliers should be a single number or a list/tuple containing exactly two numbers. Got: {self.qkv_multipliers}."
            )

        # Compute the number of query, key, and value heads.
        # For multi-head attention, the numbers of query, key, and value heads are the same.
        # For multi-query attention, there is a single key/value head.
        # For group query attention, the number of key and value heads is num_query_heads // num_gqa_groups.
        self.num_query_heads = [
            int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
        ]
        self.num_kv_heads = [
            q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
        ]
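        # Illustrative example: with model_dim=1280, head_dim=64, num_gqa_groups=4 and a QKV
        # multiplier of 0.5, make_divisible(640, divisor=256) yields a query dim of 768, i.e.
        # 12 query heads and 12 // 4 = 3 key/value heads for that layer.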

        # Feed-forward network (FFN) multipliers
        if isinstance(self.ffn_multipliers, Number):
            # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
            self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
        elif isinstance(self.ffn_multipliers, (tuple, list)):
            # Each FFN layer has a different latent dimension, assuming ffn_multipliers[0] != ffn_multipliers[1].
            # This results in variable allocation of parameters across FFN layers.
            # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
            if len(self.ffn_multipliers) == 2:
                self.ffn_multipliers = [
                    round(v, 2)
                    for v in np.linspace(
                        self.ffn_multipliers[0],
                        self.ffn_multipliers[1],
                        num=self.num_transformer_layers,
                        dtype=float,
                    )
                ]
            else:
                assert (
                    len(self.ffn_multipliers) == self.num_transformer_layers
                ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
        else:
            raise NotImplementedError(
                f"FFN multipliers should be a single number or a list/tuple containing exactly two numbers. Got: {self.ffn_multipliers}."
            )

        # Check that num_query_heads is divisible by num_kv_heads for every layer.
        for layer_idx in range(len(query_dims)):
            assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0, (
                f"Layer {layer_idx}: num_query_heads={self.num_query_heads[layer_idx]} must be "
                f"divisible by num_kv_heads={self.num_kv_heads[layer_idx]}."
            )
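

# Minimal usage sketch (illustrative): build a preset configuration and inspect the per-layer
# head and FFN allocation derived in __post_init__.
if __name__ == "__main__":
    config = OpenELMConfig(**OpenELM_CONFIGS["OpenELM-270M"])
    print(f"model_dim={config.model_dim}, num_transformer_layers={config.num_transformer_layers}")
    print(f"num_query_heads per layer: {config.num_query_heads}")
    print(f"num_kv_heads per layer:    {config.num_kv_heads}")
    print(f"ffn_multipliers per layer: {config.ffn_multipliers}")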