"""Configuration classes for the Codify model and its ONNX export."""

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

from packaging import version
from transformers import AutoConfig, is_torch_available

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, TensorType

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Hub locations of the reference configs. ``/resolve/`` serves the raw JSON
# file; ``/blob/`` would serve the HTML file-viewer page instead.
CODIFY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "smallcloudai/codify_medium_multi": "https://huggingface.co/smallcloudai/codify_medium_multi/resolve/main/config.json",
    "smallcloudai/codify_3b_multi": "https://huggingface.co/smallcloudai/codify_3b_multi/resolve/main/config.json",
}


class CodifyConfig(PretrainedConfig):
    """Configuration for Codify models.

    The model dimensions are stored under short attribute names and exposed
    under the standard Hugging Face names via ``attribute_map``:
    ``num_hidden_layers`` -> ``L``, ``num_attention_heads`` -> ``attn_heads``,
    ``hidden_size`` -> ``E``. Those short attributes are not parameters of
    ``__init__``; they must be supplied through ``**kwargs`` (e.g. from a
    ``config.json``), which are forwarded to ``PretrainedConfig``.

    Args:
        vocab_size: Vocabulary size.
        layer_norm_epsilon: Epsilon for the layer-norm layers.
        initializer_range: Weight-initialization range (presumably a std-dev
            as in other HF configs -- confirm against the model code).
        use_cache: Whether the model should use/return past key-values.
        bos_token_id: Beginning-of-sequence token id.
        eos_token_id: End-of-sequence token id.
        mlp_mult: MLP size multiplier (presumably relative to ``hidden_size``
            -- confirm against the model code).
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "codify"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "L",
        "num_attention_heads": "attn_heads",
        "hidden_size": "E",
    }

    def __init__(
        self,
        vocab_size=51305,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        mlp_mult=4,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.mlp_mult = mlp_mult
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class CodifyOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for :class:`CodifyConfig`.

    Uses a BLOOM-style past key/value layout for the dummy inputs: keys are
    ``(batch * num_heads, head_dim, past_len)`` and values are
    ``(batch * num_heads, past_len, head_dim)``.
    """

    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        # Ensure the wrapped config defines a padding id, defaulting to 0.
        # TODO: how to do that better?
        if getattr(self._config, "pad_token_id", None) is None:
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic-axis names of the model inputs, in forward() order."""
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        """Number of transformer layers (``CodifyConfig.L``)."""
        return self._config.num_hidden_layers

    @property
    def num_attention_heads(self) -> int:
        """Number of attention heads (``CodifyConfig.attn_heads``).

        Read through the standard name so ``attribute_map`` resolves it;
        ``CodifyConfig`` defines no ``n_head`` attribute.
        """
        return self._config.num_attention_heads

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance used when validating the exported model."""
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs for export, ordered as in the model's forward().

        Returns an ``OrderedDict`` with ``input_ids``, then (only when
        ``use_past``) ``past_key_values`` as one zero-filled ``(key, value)``
        pair per layer, then ``attention_mask``. With ``use_past``, the mask is
        extended with ones to cover the past positions as well.

        Raises:
            ValueError: if ``use_past`` is set and PyTorch is not installed.
        """
        # Deliberately bypass OnnxConfigWithPast's implementation and call the
        # next class in the MRO.
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs the way they appear in forward().
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
        attention_mask = common_inputs["attention_mask"]

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            import torch

            batch, seqlen = common_inputs["input_ids"].shape
            # Deliberately not using the same length for past_key_values as
            # for the current sequence.
            past_key_values_length = seqlen + 2
            head_dim = self._config.hidden_size // self.num_attention_heads
            past_key_shape = (
                batch * self.num_attention_heads,
                head_dim,
                past_key_values_length,
            )
            past_value_shape = (
                batch * self.num_attention_heads,
                past_key_values_length,
                head_dim,
            )
            ordered_inputs["past_key_values"] = [
                (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
            ]
            # The mask must also cover the past positions.
            attention_mask = torch.cat(
                [attention_mask, torch.ones(batch, past_key_values_length, dtype=attention_mask.dtype)],
                dim=1,
            )

        ordered_inputs["attention_mask"] = attention_mask
        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        """ONNX opset used by default for export."""
        return 13


# Make ``AutoConfig`` resolve model_type "codify" to CodifyConfig.
AutoConfig.register(CodifyConfig.model_type, CodifyConfig)