Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code

add full support for inputs_embeds

#9
by jxm - opened
Files changed (1) hide show
  1. configuration_hf_nomic_bert.py +2110 -45
configuration_hf_nomic_bert.py CHANGED
@@ -1,56 +1,2121 @@
1
- from transformers import GPT2Config
 
 
 
2
 
 
 
3
 
4
- class NomicBertConfig(GPT2Config):
5
- model_type = "nomic_bert"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def __init__(
8
  self,
9
- prenorm=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  parallel_block=False,
11
  parallel_block_tied_norm=False,
12
- rotary_emb_fraction=0.0,
13
- fused_dropout_add_ln=False,
14
- fused_bias_fc=False,
15
- use_flash_attn=False,
16
- use_xentropy=False,
 
17
  qkv_proj_bias=True,
18
- rotary_emb_base=10_000,
19
- rotary_emb_scale_base=None,
20
- rotary_emb_interleaved=False,
21
- mlp_fc1_bias=True,
22
- mlp_fc2_bias=True,
23
  use_rms_norm=False,
24
  causal=False,
25
- type_vocab_size=2,
26
- dense_seq_output=True,
27
- pad_vocab_size_multiple=1,
28
- tie_word_embeddings=True,
29
- rotary_scaling_factor=None,
30
- max_trained_positions=2048,
31
- **kwargs,
32
- ):
33
- self.prenorm = prenorm
34
- self.parallel_block = parallel_block
35
- self.parallel_block_tied_norm = parallel_block_tied_norm
36
- self.rotary_emb_fraction = rotary_emb_fraction
37
- self.tie_word_embeddings = tie_word_embeddings
38
- self.fused_dropout_add_ln = fused_dropout_add_ln
39
- self.fused_bias_fc = fused_bias_fc
40
- self.use_flash_attn = use_flash_attn
41
- self.use_xentropy = use_xentropy
42
- self.qkv_proj_bias = qkv_proj_bias
43
- self.rotary_emb_base = rotary_emb_base
44
- self.rotary_emb_scale_base = rotary_emb_scale_base
45
- self.rotary_emb_interleaved = rotary_emb_interleaved
46
- self.mlp_fc1_bias = mlp_fc1_bias
47
- self.mlp_fc2_bias = mlp_fc2_bias
48
- self.use_rms_norm = use_rms_norm
49
- self.causal = causal
50
- self.type_vocab_size = type_vocab_size
51
- self.dense_seq_output = dense_seq_output
52
- self.pad_vocab_size_multiple = pad_vocab_size_multiple
53
- self.rotary_scaling_factor = rotary_scaling_factor
54
- self.max_trained_positions = max_trained_positions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- super().__init__(**kwargs)
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
6
+ import collections
7
+ import logging
8
 
9
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
10
+ import math
11
+ import os
12
+ import re
13
+ from collections import OrderedDict
14
+ from functools import partial
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange, repeat
22
+ from safetensors.torch import load_file as safe_load_file
23
+ from torch.nn.modules.utils import _pair
24
+ from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel
25
+ from transformers.modeling_outputs import BaseModelOutputWithPast
26
+ from transformers.models.bert.modeling_bert import (
27
+ BaseModelOutputWithPoolingAndCrossAttentions,
28
+ MaskedLMOutput,
29
+ SequenceClassifierOutput,
30
+ )
31
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
32
+ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
33
+
34
+ from .configuration_hf_nomic_bert import NomicBertConfig
35
+
36
+ try:
37
+ from torch.nn.functional import scaled_dot_product_attention
38
+ except ImportError:
39
+ scaled_dot_product_attention = None
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ # adapted from flash attention, added safe serialization option for hf models
45
+ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
46
+ # If not fp32, then we don't want to load directly to the GPU
47
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
48
+ is_sharded = False
49
+ load_safe = False
50
+ resolved_archive_file = None
51
+
52
+ weights_path = os.path.join(model_name, WEIGHTS_NAME)
53
+ weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
54
+ safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
55
+ safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
56
+
57
+ if os.path.isfile(weights_path):
58
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
59
+ elif os.path.isfile(weights_index_path):
60
+ resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
61
+ is_sharded = True
62
+ elif os.path.isfile(safe_weights_path):
63
+ resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
64
+ load_safe = True
65
+ elif os.path.isfile(safe_weights_index_path):
66
+ resolved_archive_file = cached_file(
67
+ model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
68
+ )
69
+ is_sharded = True
70
+ load_safe = True
71
+ else: # Try loading from HF hub instead of from local files
72
+ resolved_archive_file = None
73
+ for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
74
+ resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
75
+ if resolved_archive_file is not None:
76
+ if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
77
+ load_safe = True
78
+ if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
79
+ is_sharded = True
80
+ break
81
+
82
+ if resolved_archive_file is None:
83
+ raise EnvironmentError(f"Model name {model_name} was not found.")
84
+
85
+ if load_safe:
86
+ loader = partial(safe_load_file, device=mapped_device)
87
+ else:
88
+ loader = partial(torch.load, map_location=mapped_device)
89
+
90
+ if is_sharded:
91
+ # resolved_archive_file becomes a list of files that point to the different
92
+ # checkpoint shards in this case.
93
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
94
+ state_dict = {}
95
+ for sharded_file in resolved_archive_file:
96
+ state_dict.update(loader(sharded_file))
97
+ else:
98
+ state_dict = loader(resolved_archive_file)
99
+ # Convert dtype before moving to GPU to save memory
100
+ if dtype is not None:
101
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
102
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
103
+ return state_dict
104
+
105
+
106
+ def filter_shapes(state_dict, model):
107
+ """
108
+ Filters the state dict to match the current model shape.
109
+ """
110
+ filtered_state_dict = {}
111
+ for key, value in state_dict.items():
112
+ if key in model.state_dict():
113
+ if value.shape == model.state_dict()[key].shape:
114
+ filtered_state_dict[key] = value
115
+ return filtered_state_dict
116
+
117
+
118
+ def remap_bert_state_dict(
119
+ state_dict,
120
+ config,
121
+ remove_bert=False,
122
+ remove_cls_weights=False,
123
+ add_pooling_layer=False,
124
+ ):
125
+ """
126
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
127
+ """
128
+
129
+ def add_bert_prefix(key):
130
+ # prepend bert. to the key
131
+ if key.startswith("bert.") or key.startswith("cls."):
132
+ return key
133
+ return f"bert.{key}"
134
+
135
+ state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
136
+
137
+ # LayerNorm
138
+ def key_mapping_ln_gamma_beta(key):
139
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
140
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
141
+ return key
142
+
143
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
144
+
145
+ # Layers
146
+ def key_mapping_layers(key):
147
+ return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
148
+
149
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
150
+
151
+ # LayerNorm
152
+ def key_mapping_ln(key):
153
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
154
+ key = re.sub(
155
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
156
+ r"bert.encoder.layers.\1.norm1.\2",
157
+ key,
158
+ )
159
+ key = re.sub(
160
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
161
+ r"bert.encoder.layers.\1.norm2.\2",
162
+ key,
163
+ )
164
+ key = re.sub(
165
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
166
+ r"cls.predictions.transform.layer_norm.\1",
167
+ key,
168
+ )
169
+ return key
170
+
171
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
172
+
173
+ # MLP
174
+ def key_mapping_mlp(key):
175
+ key = re.sub(
176
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
177
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
178
+ key,
179
+ )
180
+ key = re.sub(
181
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
182
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
183
+ key,
184
+ )
185
+ return key
186
+
187
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
188
+
189
+ # Attention
190
+ last_layer_subset = getattr(config, "last_layer_subset", False)
191
+ for d in range(config.num_hidden_layers):
192
+ if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
193
+ continue
194
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
195
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
196
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
197
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
198
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
199
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
200
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
201
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
202
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
203
+ else:
204
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
205
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
206
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
207
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
208
+
209
+ def key_mapping_attn(key):
210
+ return re.sub(
211
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
212
+ r"bert.encoder.layers.\1.attn.out_proj.\2",
213
+ key,
214
+ )
215
+
216
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
217
+
218
+ def key_mapping_decoder_bias(key):
219
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
220
+
221
+ # remove nsp weights, we don't use
222
+ state_dict.pop("cls.seq_relationship.weight", None)
223
+ state_dict.pop("cls.seq_relationship.bias", None)
224
+ state_dict.pop("bert.embeddings.position_ids", None)
225
+
226
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
227
+
228
+ if remove_cls_weights:
229
+ cls_weights = [
230
+ "cls.predictions.decoder.bias",
231
+ "cls.predictions.transform.dense.weight",
232
+ "cls.predictions.transform.dense.bias",
233
+ "cls.predictions.transform.layer_norm.weight",
234
+ "cls.predictions.transform.layer_norm.bias",
235
+ "cls.predictions.decoder.weight",
236
+ ]
237
+ for weight in cls_weights:
238
+ state_dict.pop(weight, None)
239
+
240
+ # Word embedding
241
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
242
+ if pad_vocab_size_multiple > 1:
243
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
244
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
245
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
246
+ )
247
+ if not remove_cls_weights:
248
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
249
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
250
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
251
+ )
252
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
253
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
254
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
255
+ if "cls.predictions.decoder.bias" in state_dict:
256
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
257
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
258
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
259
+ )
260
+
261
+ if add_pooling_layer is False:
262
+ pooler_weights = [
263
+ "bert.pooler.dense.weight",
264
+ "bert.pooler.dense.bias",
265
+ ]
266
+ for key in pooler_weights:
267
+ state_dict.pop(key, None)
268
+
269
+ if remove_bert:
270
+
271
+ def remove_bert_prefix(key):
272
+ key = re.sub(r"^bert.", "", key)
273
+ return key
274
+
275
+ state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
276
+
277
+ return state_dict
278
+
279
+
280
+ def _trunc_normal_(tensor, mean, std, a, b):
281
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
282
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
283
+ def norm_cdf(x):
284
+ # Computes standard normal cumulative distribution function
285
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
286
+
287
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
288
+ print(
289
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
290
+ "The distribution of values may be incorrect.",
291
+ stacklevel=2,
292
+ )
293
+
294
+ # Values are generated by using a truncated uniform distribution and
295
+ # then using the inverse CDF for the normal distribution.
296
+ # Get upper and lower cdf values
297
+ l = norm_cdf((a - mean) / std)
298
+ u = norm_cdf((b - mean) / std)
299
+
300
+ # Uniformly fill tensor with values from [l, u], then translate to
301
+ # [2l-1, 2u-1].
302
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
303
+
304
+ # Use inverse cdf transform for normal distribution to get truncated
305
+ # standard normal
306
+ tensor.erfinv_()
307
+
308
+ # Transform to proper mean, std
309
+ tensor.mul_(std * math.sqrt(2.0))
310
+ tensor.add_(mean)
311
+
312
+ # Clamp to ensure it's in the proper range
313
+ tensor.clamp_(min=a, max=b)
314
+ return tensor
315
+
316
+
317
+ def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
318
+ r"""Fills the input Tensor with values drawn from a truncated
319
+ normal distribution. The values are effectively drawn from the
320
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
321
+ with values outside :math:`[a, b]` redrawn until they are within
322
+ the bounds. The method used for generating the random values works
323
+ best when :math:`a \leq \text{mean} \leq b`.
324
+
325
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
326
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
327
+ and the result is subsquently scaled and shifted by the mean and std args.
328
+
329
+ Args:
330
+ tensor: an n-dimensional `torch.Tensor`
331
+ mean: the mean of the normal distribution
332
+ std: the standard deviation of the normal distribution
333
+ a: the minimum cutoff value
334
+ b: the maximum cutoff value
335
+ Examples:
336
+ >>> w = torch.empty(3, 5)
337
+ >>> nn.init.trunc_normal_(w)
338
+ """
339
+ with torch.no_grad():
340
+ _trunc_normal_(tensor, 0, 1.0, a, b)
341
+ tensor.mul_(std).add_(mean)
342
+ return tensor
343
+
344
+
345
+ class NomicBertPreTrainedModel(PreTrainedModel):
346
+ """An abstract class to handle weights initialization and
347
+ a simple interface for dowloading and loading pretrained models.
348
+ """
349
+
350
+ config_class = NomicBertConfig
351
+ base_model_prefix = "model"
352
+ supports_gradient_checkpointing = True
353
+ _no_split_modules = ["Block"]
354
+ _skip_keys_device_placement = "past_key_values"
355
+
356
+ def __init__(self, config, *inputs, **kwargs):
357
+ super().__init__(config)
358
+ if not isinstance(config, GPT2Config):
359
+ raise ValueError(
360
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
361
+ "To create a model from a Google pretrained model use "
362
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
363
+ self.__class__.__name__, self.__class__.__name__
364
+ )
365
+ )
366
+ self.config = config
367
+
368
+ @classmethod
369
+ def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
370
+ """
371
+ Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
372
+ Download and cache the pre-trained model file if needed.
373
+
374
+ Params:
375
+ pretrained_model_name_or_path: either:
376
+ - a path or url to a pretrained model archive containing:
377
+ . `bert_config.json` a configuration file for the model
378
+ . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
379
+ - a path or url to a pretrained model archive containing:
380
+ . `bert_config.json` a configuration file for the model
381
+ . `model.chkpt` a TensorFlow checkpoint
382
+ *inputs, **kwargs: additional input for the specific NomicBert class
383
+ (ex: num_labels for NomicBertForSequenceClassification)
384
+ """
385
+ # Instantiate model.
386
+ if config is None:
387
+ config = cls.config_class.from_pretrained(model_name)
388
+ remove_cls = cls != NomicBertForPreTraining
389
+ remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
390
+ ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
391
+ num_labels = kwargs.pop("num_labels", None)
392
+ rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
393
+ strict = kwargs.pop("strict", True)
394
+ dtype = kwargs.pop("torch_dtype", None)
395
+ if rotary_scaling_factor:
396
+ config.rotary_scaling_factor = rotary_scaling_factor
397
+
398
+ if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
399
+ config.n_positions = 2048
400
+ if num_labels:
401
+ config.num_labels = num_labels
402
+
403
+ if "add_pooling_layer" in kwargs:
404
+ model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
405
+ else:
406
+ if cls == NomicBertModel:
407
+ model = cls(config, *inputs, add_pooling_layer=False)
408
+ else:
409
+ model = cls(config, *inputs)
410
+
411
+ if dtype is not None:
412
+ model = model.to(dtype=dtype)
413
+ # TODO: fix this
414
+ # Assuming we know what we're doing when loading from disk
415
+ # Prob a bad assumption but i'm tired and want to train this asap
416
+ if os.path.exists(model_name):
417
+ model_path = f"{model_name}/pytorch_model.bin"
418
+ if os.path.exists(model_path):
419
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
420
+ else:
421
+ model_path = f"{model_name}/model.safetensors"
422
+ if not os.path.exists(model_path):
423
+ raise ValueError(f"Model path {model_path} not found")
424
+ state_dict = safe_load_file(model_path)
425
+
426
+ if ignore_mismatched_shapes:
427
+ state_dict = filter_shapes(state_dict, model)
428
+ load_return = model.load_state_dict(state_dict, strict=False)
429
+ else:
430
+ # TODO: can probably check config class and see if we need to remap from a bert model
431
+ state_dict = state_dict_from_pretrained(model_name, dtype=dtype)
432
+ state_dict = remap_bert_state_dict(
433
+ state_dict,
434
+ config,
435
+ remove_bert=remove_bert_prefix,
436
+ remove_cls_weights=remove_cls,
437
+ add_pooling_layer=getattr(config, "add_pooling_layer", False),
438
+ )
439
+ if ignore_mismatched_shapes:
440
+ state_dict = filter_shapes(state_dict, model)
441
+
442
+ load_return = model.load_state_dict(state_dict, strict=strict)
443
+ logger.warning(load_return)
444
+ return model
445
+
446
+ def _set_gradient_checkpointing(self, module, value=False):
447
+ if isinstance(module, NomicBertEncoder):
448
+ module.gradient_checkpointing = value
449
+
450
+
451
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
452
+ def _init_weights(module, initializer_range=0.02):
453
+ if isinstance(module, nn.Linear):
454
+ nn.init.normal_(module.weight, std=initializer_range)
455
+ if module.bias is not None:
456
+ nn.init.zeros_(module.bias)
457
+ elif isinstance(module, nn.Embedding):
458
+ nn.init.normal_(module.weight, std=initializer_range)
459
+ if module.padding_idx is not None:
460
+ nn.init.zeros_(module.weight[module.padding_idx])
461
+
462
+
463
+ def _ntuple(n):
464
+ def parse(x):
465
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
466
+ return tuple(x)
467
+ return tuple(repeat(x, n))
468
+
469
+ return parse
470
+
471
+
472
+ to_1tuple = _ntuple(1)
473
+ to_2tuple = _ntuple(2)
474
+ to_3tuple = _ntuple(3)
475
+ to_4tuple = _ntuple(4)
476
+ to_ntuple = _ntuple
477
+
478
+
479
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
480
+ """
481
+ Create 2D sin/cos positional embeddings.
482
+
483
+ Args:
484
+ embed_dim (`int`):
485
+ Embedding dimension.
486
+ grid_size (`int`):
487
+ The grid height and width.
488
+ add_cls_token (`bool`, *optional*, defaults to `False`):
489
+ Whether or not to add a classification (CLS) token.
490
+
491
+ Returns:
492
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
493
+ position embeddings (with or without classification token)
494
+ """
495
+ grid_h = np.arange(grid_size, dtype=np.float32)
496
+
497
+ grid_w = np.arange(grid_size, dtype=np.float32)
498
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
499
+ grid = np.stack(grid, axis=0)
500
+
501
+ grid = grid.reshape([2, 1, grid_size, grid_size])
502
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
503
+ if add_cls_token:
504
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
505
+ return pos_embed
506
+
507
+
508
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
509
+ if embed_dim % 2 != 0:
510
+ raise ValueError("embed_dim must be even")
511
+
512
+ # use half of dimensions to encode grid_h
513
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
514
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
515
+
516
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
517
+ return emb
518
+
519
+
520
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
521
+ """
522
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
523
+ """
524
+ if embed_dim % 2 != 0:
525
+ raise ValueError("embed_dim must be even")
526
+
527
+ omega = np.arange(embed_dim // 2, dtype=float)
528
+ omega /= embed_dim / 2.0
529
+ omega = 1.0 / 10000**omega # (D/2,)
530
+
531
+ pos = pos.reshape(-1) # (M,)
532
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
533
+
534
+ emb_sin = np.sin(out) # (M, D/2)
535
+ emb_cos = np.cos(out) # (M, D/2)
536
+
537
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
538
+ return emb
539
+
540
+
541
+ def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
542
+ """generate N-D grid in dimension order.
543
+
544
+ The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
545
+
546
+ That is, the statement
547
+ [X1,X2,X3] = ndgrid(x1,x2,x3)
548
+
549
+ produces the same result as
550
+
551
+ [X2,X1,X3] = meshgrid(x2,x1,x3)
552
+
553
+ This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
554
+ torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
555
+
556
+ """
557
+ try:
558
+ return torch.meshgrid(*tensors, indexing='ij')
559
+ except TypeError:
560
+ # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
561
+ # the old behaviour of meshgrid was 'ij'
562
+ return torch.meshgrid(*tensors)
563
+
564
+
565
+ def build_fourier_pos_embed(
566
+ feat_shape: List[int],
567
+ bands: Optional[torch.Tensor] = None,
568
+ num_bands: int = 64,
569
+ max_res: int = 224,
570
+ temperature: float = 10000.0,
571
+ linear_bands: bool = False,
572
+ include_grid: bool = False,
573
+ in_pixels: bool = True,
574
+ ref_feat_shape: Optional[List[int]] = None,
575
+ dtype: torch.dtype = torch.float32,
576
+ device: Optional[torch.device] = None,
577
+ ) -> List[torch.Tensor]:
578
+ """
579
+
580
+ Args:
581
+ feat_shape: Feature shape for embedding.
582
+ bands: Pre-calculated frequency bands.
583
+ num_bands: Number of frequency bands (determines output dim).
584
+ max_res: Maximum resolution for pixel based freq.
585
+ temperature: Temperature for non-pixel freq.
586
+ linear_bands: Linear band spacing for pixel based freq.
587
+ include_grid: Include the spatial grid in output.
588
+ in_pixels: Output in pixel freq.
589
+ ref_feat_shape: Reference feature shape for resize / fine-tune.
590
+ dtype: Output dtype.
591
+ device: Output device.
592
+
593
+ Returns:
594
+
595
+ """
596
+ if bands is None:
597
+ if in_pixels:
598
+ bands = pixel_freq_bands(
599
+ num_bands,
600
+ float(max_res),
601
+ linear_bands=linear_bands,
602
+ device=device,
603
+ )
604
+ else:
605
+ bands = freq_bands(
606
+ num_bands,
607
+ temperature=temperature,
608
+ step=1,
609
+ device=device,
610
+ )
611
+ else:
612
+ if device is None:
613
+ device = bands.device
614
+ if dtype is None:
615
+ dtype = bands.dtype
616
+
617
+ if in_pixels:
618
+ t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape]
619
+ else:
620
+ t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
621
+
622
+ if ref_feat_shape is not None:
623
+ # eva's scheme for resizing rope embeddings (ref shape = pretrain)
624
+ t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
625
+
626
+ grid = torch.stack(ndgrid(t), dim=-1)
627
+ grid = grid.unsqueeze(-1)
628
+ pos = grid * bands
629
+
630
+ pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
631
+ out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
632
+ return out
633
+
634
+
635
+ def build_rotary_pos_embed(
636
+ feat_shape: List[int],
637
+ bands: Optional[torch.Tensor] = None,
638
+ dim: int = 64,
639
+ max_res: int = 224,
640
+ temperature: float = 10000.0,
641
+ linear_bands: bool = False,
642
+ in_pixels: bool = True,
643
+ ref_feat_shape: Optional[List[int]] = None,
644
+ dtype: torch.dtype = torch.float32,
645
+ device: Optional[torch.device] = None,
646
+ ):
647
+ """
648
+
649
+ Args:
650
+ feat_shape: Spatial shape of the target tensor for embedding.
651
+ bands: Optional pre-generated frequency bands
652
+ dim: Output dimension of embedding tensor.
653
+ max_res: Maximum resolution for pixel mode.
654
+ temperature: Temperature (inv freq) for non-pixel mode
655
+ linear_bands: Linearly (instead of log) spaced bands for pixel mode
656
+ in_pixels: Pixel vs language (inv freq) mode.
657
+ dtype: Output dtype.
658
+ device: Output device.
659
+
660
+ Returns:
661
+
662
+ """
663
+ sin_emb, cos_emb = build_fourier_pos_embed(
664
+ feat_shape,
665
+ bands=bands,
666
+ num_bands=dim // 4,
667
+ max_res=max_res,
668
+ temperature=temperature,
669
+ linear_bands=linear_bands,
670
+ in_pixels=in_pixels,
671
+ ref_feat_shape=ref_feat_shape,
672
+ device=device,
673
+ dtype=dtype,
674
+ )
675
+ num_spatial_dim = 1
676
+ # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
677
+ for x in feat_shape:
678
+ num_spatial_dim *= x
679
+ sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
680
+ cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
681
+ return sin_emb, cos_emb
682
+
683
+
684
+ def freq_bands(
685
+ num_bands: int,
686
+ temperature: float = 10000.0,
687
+ step: int = 2,
688
+ device: Optional[torch.device] = None,
689
+ ) -> torch.Tensor:
690
+ exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
691
+ bands = 1.0 / (temperature**exp)
692
+ return bands
693
+
694
+
695
+ def pixel_freq_bands(
696
+ num_bands: int,
697
+ max_freq: float = 224.0,
698
+ linear_bands: bool = True,
699
+ device: Optional[torch.device] = None,
700
+ ):
701
+ if linear_bands:
702
+ bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
703
+ else:
704
+ bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
705
+ return bands * torch.pi
706
+
707
+
708
+ def rot(x):
709
+ return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
710
+
711
+
712
+ def apply_rot_embed_cat(x: torch.Tensor, emb):
713
+ sin_emb, cos_emb = emb.tensor_split(2, -1)
714
+ if sin_emb.ndim == 3:
715
+ return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
716
+ return x * cos_emb + rot(x) * sin_emb
717
+
718
+
719
+ # taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
720
+ class NomicVisionRotaryEmbeddingCat(nn.Module):
721
+ """Rotary position embedding w/ concatenatd sin & cos
722
+
723
+ The following impl/resources were referenced for this impl:
724
+ * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
725
+ * https://blog.eleuther.ai/rotary-embeddings/
726
+ """
727
 
728
  def __init__(
729
  self,
730
+ dim,
731
+ max_res=224,
732
+ temperature=10000,
733
+ in_pixels=True,
734
+ linear_bands: bool = False,
735
+ feat_shape: Optional[List[int]] = None,
736
+ ref_feat_shape: Optional[List[int]] = None,
737
+ ):
738
+ super().__init__()
739
+ self.dim = dim
740
+ self.max_res = max_res
741
+ self.temperature = temperature
742
+ self.in_pixels = in_pixels
743
+ self.feat_shape = feat_shape
744
+ self.ref_feat_shape = ref_feat_shape
745
+
746
+ if feat_shape is None:
747
+ # only cache bands
748
+ if in_pixels:
749
+ bands = pixel_freq_bands(
750
+ dim // 4,
751
+ float(max_res),
752
+ linear_bands=linear_bands,
753
+ )
754
+ else:
755
+ bands = freq_bands(
756
+ dim // 4,
757
+ temperature=temperature,
758
+ step=1,
759
+ )
760
+ self.register_buffer(
761
+ 'bands',
762
+ bands,
763
+ persistent=False,
764
+ )
765
+ self.pos_embed = None
766
+ else:
767
+ # cache full sin/cos embeddings if shape provided up front
768
+ embeds = build_rotary_pos_embed(
769
+ feat_shape=feat_shape,
770
+ dim=dim,
771
+ max_res=max_res,
772
+ linear_bands=linear_bands,
773
+ in_pixels=in_pixels,
774
+ ref_feat_shape=self.ref_feat_shape,
775
+ )
776
+ self.bands = None
777
+ self.register_buffer(
778
+ 'pos_embed',
779
+ torch.cat(embeds, -1),
780
+ persistent=False,
781
+ )
782
+
783
+ def get_embed(self, shape: Optional[List[int]] = None):
784
+ if self.bands is not None and shape is not None:
785
+ # rebuild embeddings every call, use if target shape changes
786
+ embeds = build_rotary_pos_embed(
787
+ shape,
788
+ self.bands,
789
+ in_pixels=self.in_pixels,
790
+ ref_feat_shape=self.ref_feat_shape,
791
+ )
792
+ return torch.cat(embeds, -1)
793
+ elif self.pos_embed is not None:
794
+ return self.pos_embed
795
+ else:
796
+ assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
797
+
798
+ def forward(self, x):
799
+ # assuming channel-first tensor where spatial dim are >= 2
800
+ pos_embed = self.get_embed(x.shape[2:])
801
+ return apply_rot_embed_cat(x, pos_embed)
802
+
803
+
804
+ class NomicVisionPatchEmbeddings(nn.Module):
805
+ def __init__(
806
+ self,
807
+ config,
808
+ ):
809
+ super().__init__()
810
+ img_size = _pair(config.img_size)
811
+ patch_size = _pair(config.patch_size)
812
+ self.img_size = img_size
813
+ self.patch_size = patch_size
814
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
815
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
816
+
817
+ self.proj = nn.Linear(
818
+ config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias
819
+ )
820
+
821
+ self.learned_pos_embedding = False
822
+ self.sinusoidal_pos_embedding = False
823
+ self.no_embed_class = getattr(config, "no_embed_class", False)
824
+
825
+ self.cls_token = (
826
+ nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
827
+ )
828
+ if config.learned_pos_embedding:
829
+ # this is the default in DINO
830
+ self.learned_pos_embedding = True
831
+ # hack for timm dinov2 with registers
832
+ num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
833
+ self.pos_embed = (
834
+ nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02)
835
+ if getattr(config, "use_pos_embed", True)
836
+ else None
837
+ )
838
+ elif getattr(config, "sinusoidal_pos_embedding", False):
839
+ self.sinusoidal_pos_embedding = True
840
+ if getattr(config, "use_pos_embed", True):
841
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False)
842
+ pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True)
843
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed))
844
+ else:
845
+ self.pos_embed = None
846
+ else:
847
+ self.pos_embed = (
848
+ nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02)
849
+ if getattr(config, "use_pos_embed", True)
850
+ else None
851
+ )
852
+
853
+ if getattr(config, "register_tokens", 0) > 0:
854
+ self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
855
+ else:
856
+ self.reg_token = None
857
+
858
+ if config.mask_token:
859
+ self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
860
+
861
+ self.patch_dropout = nn.Identity()
862
+
863
+ if getattr(config, "use_rotary_pos_emb", False):
864
+ ref_feat_shape = getattr(config, "ref_feat_shape", None)
865
+ ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
866
+ self.rope = NomicVisionRotaryEmbeddingCat(
867
+ config.n_embd // config.n_head,
868
+ in_pixels=False,
869
+ feat_shape=self.grid_size,
870
+ ref_feat_shape=ref_feat_shape,
871
+ )
872
+ else:
873
+ self.rope = None
874
+
875
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
876
+ """
877
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
878
+ resolution images.
879
+
880
+ Source:
881
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
882
+ """
883
+ num_patches = embeddings.shape[1] - 1
884
+ num_positions = self.pos_embed.shape[1] - 1
885
+ if num_patches == num_positions and height == width:
886
+ return self.pos_embed
887
+ class_pos_embed = self.pos_embed[:, 0]
888
+ patch_pos_embed = self.pos_embed[:, 1:]
889
+ dim = embeddings.shape[-1]
890
+ height = height // self.patch_size[0]
891
+ width = width // self.patch_size[1]
892
+ # we add a small number to avoid floating point error in the interpolation
893
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
894
+ height, width = height + 0.1, width + 0.1
895
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
896
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
897
+ patch_pos_embed = nn.functional.interpolate(
898
+ patch_pos_embed,
899
+ scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
900
+ mode="bicubic",
901
+ align_corners=False,
902
+ )
903
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
904
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
905
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
906
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
907
+
908
+ def forward(self, x):
909
+ # deepspeed case where the input is in fp32
910
+ if x.dtype != self.proj.weight.dtype:
911
+ x = x.to(dtype=self.proj.weight.dtype)
912
+
913
+ _, _, height, width = x.shape
914
+ x = self.proj(
915
+ rearrange(
916
+ x,
917
+ "b c (h p1) (w p2) -> b h w (c p1 p2)",
918
+ p1=self.patch_size[0],
919
+ p2=self.patch_size[1],
920
+ )
921
+ )
922
+ embeddings = rearrange(x, "b h w c -> b (h w) c")
923
+
924
+ to_cat = []
925
+ if self.cls_token is not None:
926
+ if self.sinusoidal_pos_embedding:
927
+ cls_token = self.cls_token + self.pos_embed[:, 0]
928
+ cls_token = cls_token.expand(embeddings.shape[0], -1, -1)
929
+ to_cat += [cls_token]
930
+ else:
931
+ cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1)
932
+ to_cat += [cls_token]
933
+
934
+ if self.reg_token is not None:
935
+ to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)]
936
+
937
+ rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
938
+
939
+ if self.no_embed_class:
940
+ if self.learned_pos_embedding:
941
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
942
+ else:
943
+ if self.pos_embed is not None:
944
+ embeddings = embeddings + self.pos_embed
945
+ if to_cat:
946
+ embeddings = torch.cat(to_cat + [embeddings], dim=1)
947
+ else:
948
+ if to_cat:
949
+ embeddings = torch.cat(to_cat + [embeddings], dim=1)
950
+ if self.learned_pos_embedding:
951
+ if self.pos_embed is not None:
952
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
953
+ else:
954
+ if self.pos_embed is not None:
955
+ embeddings = embeddings + self.pos_embed
956
+
957
+ embeddings = self.patch_dropout(embeddings)
958
+
959
+ return embeddings, rot_pos_embed
960
+
961
+
962
+ class NomicBertEmbeddings(nn.Module):
963
+ def __init__(self, config):
964
+ """
965
+ If max_position_embeddings <= 0, there's no position embeddings
966
+ If type_vocab_size <= 0, there's no token type embeddings
967
+ """
968
+ super().__init__()
969
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
970
+ self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
971
+ self.type_vocab_size = config.type_vocab_size
972
+ if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
973
+ self.position_embeddings = nn.Embedding(
974
+ config.max_position_embeddings,
975
+ config.hidden_size,
976
+ )
977
+ if self.type_vocab_size > 0:
978
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
979
+
980
+ def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
981
+ """
982
+ input_ids: (batch, seqlen)
983
+ position_ids: (batch, seqlen)
984
+ token_type_ids: (batch, seqlen)
985
+ """
986
+ if inputs_embeds is None:
987
+ embeddings = self.word_embeddings(input_ids)
988
+ else:
989
+ embeddings = inputs_embeds
990
+ batch_size, seqlen, _ = embeddings.shape
991
+
992
+ if self.type_vocab_size > 0:
993
+ if token_type_ids is None:
994
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=embeddings.device)
995
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
996
+ embeddings = embeddings + token_type_embeddings
997
+
998
+ if self.max_position_embeddings > 0:
999
+ if position_ids is None:
1000
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=embeddings.device)
1001
+ position_embeddings = self.position_embeddings(position_ids)
1002
+ embeddings = embeddings + position_embeddings
1003
+ return embeddings
1004
+
1005
+
1006
+ class NomicBertMLP(nn.Module):
1007
+ def __init__(
1008
+ self,
1009
+ in_features,
1010
+ hidden_features=None,
1011
+ out_features=None,
1012
+ activation=F.gelu,
1013
+ bias1=True,
1014
+ bias2=True,
1015
+ return_residual=False,
1016
+ fused_bias_fc=False,
1017
+ ):
1018
+ super().__init__()
1019
+ out_features = out_features if out_features is not None else in_features
1020
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
1021
+ self.return_residual = return_residual
1022
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
1023
+ approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
1024
+ self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
1025
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
1026
+
1027
+ def forward(self, x):
1028
+ y = self.fc1(x)
1029
+ y = self.activation(y)
1030
+ y = self.fc2(y)
1031
+ return y if not self.return_residual else (y, x)
1032
+
1033
+
1034
+ class NomciBertGatedMLP(nn.Module):
1035
+ def __init__(
1036
+ self,
1037
+ in_features,
1038
+ hidden_features=None,
1039
+ out_features=None,
1040
+ activation=F.sigmoid,
1041
+ bias1=True,
1042
+ bias2=True,
1043
+ multiple_of=256,
1044
+ return_residual=False,
1045
+ fused_bias_fc=True,
1046
+ device=None,
1047
+ dtype=None,
1048
+ norm_layer=False,
1049
+ ):
1050
+ super().__init__()
1051
+ out_features = out_features if out_features is not None else in_features
1052
+ hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
1053
+ hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of)
1054
+ self.return_residual = return_residual
1055
+
1056
+ self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
1057
+ self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
1058
+ self.activation = activation
1059
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
1060
+ self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity()
1061
+
1062
+ def forward(self, x):
1063
+ y = self.fc11(x)
1064
+ gate = self.fc12(x)
1065
+ if self.activation == F.sigmoid: # Special case for GLU
1066
+ y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
1067
+ else:
1068
+ y = y * self.activation(gate)
1069
+
1070
+ # eva uses layer norm after the activation
1071
+ y = self.norm(y)
1072
+
1073
+ y = self.fc2(y)
1074
+ return y if not self.return_residual else (y, x)
1075
+
1076
+
1077
+ def rotate_half(x, interleaved=False):
1078
+ if not interleaved:
1079
+ x1, x2 = x.chunk(2, dim=-1)
1080
+ return torch.cat((-x2, x1), dim=-1)
1081
+ else:
1082
+ x1, x2 = x[..., ::2], x[..., 1::2]
1083
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
1084
+
1085
+
1086
+ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
1087
+ """
1088
+ x: (batch_size, seqlen, nheads, headdim)
1089
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
1090
+ """
1091
+ ro_dim = cos.shape[-1] * 2
1092
+ assert ro_dim <= x.shape[-1]
1093
+ cos, sin = (
1094
+ cos[offset : offset + x.shape[1]],
1095
+ sin[offset : offset + x.shape[1]],
1096
+ )
1097
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
1098
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
1099
+ return torch.cat(
1100
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
1101
+ dim=-1,
1102
+ )
1103
+
1104
+
1105
+ class NomicBertRotaryEmbedding(nn.Module):
1106
+ def __init__(
1107
+ self,
1108
+ dim: int,
1109
+ base=10000.0,
1110
+ interleaved=False,
1111
+ scale_base=None,
1112
+ pos_idx_in_fp32=True,
1113
+ device=None,
1114
+ ):
1115
+ """
1116
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
1117
+ of 1st half and 2nd half (GPT-NeoX style).
1118
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
1119
+ otherwise they might be in lower precision.
1120
+ This option was added because previously (before 2023-07-02), when we construct
1121
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
1122
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
1123
+ self.inv_freq would be bf16, and the position indices are also in bf16.
1124
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
1125
+ embeddings for some positions will coincide.
1126
+ To maintain compatibility with models previously trained in pure bf16,
1127
+ we add this option.
1128
+ """
1129
+ super().__init__()
1130
+ self.dim = dim
1131
+ self.base = float(base)
1132
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
1133
+ # Generate and save the inverse frequency buffer (non trainable)
1134
+ inv_freq = self._compute_inv_freq(device)
1135
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1136
+ self.interleaved = interleaved
1137
+ self.scale_base = scale_base
1138
+ scale = (
1139
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
1140
+ if scale_base is not None
1141
+ else None
1142
+ )
1143
+ self.register_buffer("scale", scale, persistent=False)
1144
+
1145
+ self._seq_len_cached = 0
1146
+ self._cos_cached = None
1147
+ self._sin_cached = None
1148
+ self._cos_k_cached = None
1149
+ self._sin_k_cached = None
1150
+
1151
+ def _compute_inv_freq(self, device=None):
1152
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
1153
+
1154
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
1155
+ # Reset the tables if the sequence length has changed,
1156
+ # if we're on a new device (possibly due to tracing for instance),
1157
+ # or if we're switching from inference mode to training
1158
+ if (
1159
+ seqlen > self._seq_len_cached
1160
+ or self._cos_cached is None
1161
+ or self._cos_cached.device != device
1162
+ or self._cos_cached.dtype != dtype
1163
+ or (self.training and self._cos_cached.is_inference())
1164
+ ):
1165
+ self._seq_len_cached = seqlen
1166
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
1167
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
1168
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
1169
+ if self.pos_idx_in_fp32:
1170
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
1171
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
1172
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
1173
+ # cos & sin output to change significantly.
1174
+ # We want to recompute self.inv_freq if it was not loaded in fp32
1175
+ if self.inv_freq.dtype != torch.float32:
1176
+ inv_freq = self._compute_inv_freq(device=device)
1177
+ else:
1178
+ inv_freq = self.inv_freq
1179
+ else:
1180
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
1181
+ inv_freq = self.inv_freq
1182
+ # Don't do einsum, it converts fp32 to fp16 under AMP
1183
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1184
+ freqs = torch.outer(t, inv_freq)
1185
+ self._cos_cached = torch.cos(freqs).to(dtype)
1186
+ self._sin_cached = torch.sin(freqs).to(dtype)
1187
+
1188
+ def forward(
1189
+ self,
1190
+ qkv: torch.Tensor,
1191
+ kv: Optional[torch.Tensor] = None,
1192
+ seqlen_offset: Union[int, torch.Tensor] = 0,
1193
+ max_seqlen: Optional[int] = None,
1194
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1195
+ """
1196
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
1197
+ else it's just q of shape (batch, seqlen, nheads, headdim)
1198
+ kv: (batch, seqlen, 2, nheads, headdim)
1199
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
1200
+ Most commonly used in inference when we have KV cache.
1201
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
1202
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
1203
+ Apply rotary embedding *inplace* to qkv and / or kv.
1204
+ """
1205
+ seqlen = qkv.shape[1]
1206
+ if seqlen > self._seq_len_cached:
1207
+ self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
1208
+ elif max_seqlen is not None:
1209
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
1210
+ elif isinstance(seqlen_offset, int):
1211
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
1212
+
1213
+ q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
1214
+ k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
1215
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
1216
+
1217
+
1218
+ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
1219
+ def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
1220
+ super().__init__(**kwargs)
1221
+ self.rotary_scaling_factor = rotary_scaling_factor
1222
+ self.max_position_embeddings = max_position_embeddings
1223
+
1224
+ def _compute_inv_freq(self, base=None, device=None):
1225
+ if base is None:
1226
+ base = self.base
1227
+ return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
1228
+
1229
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
1230
+ # Reset the tables if the sequence length has changed,
1231
+ # if we're on a new device (possibly due to tracing for instance),
1232
+ # or if we're switching from inference mode to training
1233
+ if seqlen > self.max_position_embeddings:
1234
+ base = self.base * (
1235
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
1236
+ ) ** (self.dim / (self.dim - 2))
1237
+ inv_freq = self._compute_inv_freq(base=base, device=device)
1238
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1239
+
1240
+ if (
1241
+ seqlen > self._seq_len_cached
1242
+ or self._cos_cached is None
1243
+ or self._cos_cached.device != device
1244
+ or self._cos_cached.dtype != dtype
1245
+ or (self.training and self._cos_cached.is_inference())
1246
+ ):
1247
+ self._seq_len_cached = seqlen
1248
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
1249
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
1250
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
1251
+ if self.pos_idx_in_fp32:
1252
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
1253
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
1254
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
1255
+ # cos & sin output to change significantly.
1256
+ # We want to recompute self.inv_freq if it was not loaded in fp32
1257
+ if self.inv_freq.dtype != torch.float32:
1258
+ if seqlen > self.max_position_embeddings:
1259
+ base = self.base * (
1260
+ (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
1261
+ ) ** (self.dim / (self.dim - 2))
1262
+ else:
1263
+ base = self.base
1264
+ inv_freq = self._compute_inv_freq(device=device, base=base)
1265
+ else:
1266
+ inv_freq = self.inv_freq
1267
+ else:
1268
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
1269
+ inv_freq = self.inv_freq
1270
+ # Don't do einsum, it converts fp32 to fp16 under AMP
1271
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1272
+ freqs = torch.outer(t, inv_freq)
1273
+ if self.scale is None:
1274
+ self._cos_cached = torch.cos(freqs).to(dtype)
1275
+ self._sin_cached = torch.sin(freqs).to(dtype)
1276
+ else:
1277
+ power = (
1278
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
1279
+ ) / self.scale_base
1280
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
1281
+ # We want the multiplication by scale to happen in fp32
1282
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
1283
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
1284
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
1285
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
1286
+
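# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# Shows the effect of the dynamic NTK rescaling used above: once seqlen exceeds
# max_position_embeddings, the rotary base grows so the inverse frequencies
# stretch to cover the longer context. The function name and the example
# numbers are hypothetical, chosen only to make the formula concrete.
def _example_ntk_base(base=10000.0, dim=64, scaling_factor=2.0,
                      max_position_embeddings=2048, seqlen=8192):
    if seqlen <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
# With the defaults above the multiplier is 7 ** (64 / 62) ≈ 7.45, so inv_freq
# is recomputed from a base of roughly 7.5e4 for 8k-token inputs.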
1287
+
1288
+ class NomicBertAttention(nn.Module):
1289
+ """Multi-head self-attention and cross-attention"""
1290
+
1291
+ def __init__(
1292
+ self,
1293
+ config,
1294
+ ) -> None:
1295
+ """
1296
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
1297
+ return_residual: whether to return the input x along with the output. This is for
1298
+ performance reason: for post-norm architecture, returning the input allows us
1299
+ to fuse the backward of nn.Linear with the residual connection.
1300
+ """
1301
+ super().__init__()
1302
+ self.embed_dim = config.n_embd
1303
+ self.use_flash_attn = config.use_flash_attn
1304
+ self.fused_bias_fc = config.fused_bias_fc
1305
+
1306
+ self.num_heads = config.n_head
1307
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
1308
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
1309
+ self.head_dim = self.embed_dim // self.num_heads
1310
+ # we don't really support mqa / gqa for now
1311
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
1312
+
1313
+ self.register_buffer(
1314
+ "norm_factor",
1315
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
1316
+ persistent=False,
1317
+ )
1318
+
1319
+ self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
1320
+ if self.rotary_emb_dim > 0:
1321
+ if getattr(config, "rotary_scaling_factor", None):
1322
+ self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
1323
+ dim=self.rotary_emb_dim,
1324
+ base=config.rotary_emb_base,
1325
+ scale_base=config.rotary_emb_scale_base,
1326
+ interleaved=config.rotary_emb_interleaved,
1327
+ rotary_scaling_factor=config.rotary_scaling_factor,
1328
+ max_position_embeddings=config.max_trained_positions,
1329
+ )
1330
+ else:
1331
+ self.rotary_emb = NomicBertRotaryEmbedding(
1332
+ dim=self.rotary_emb_dim,
1333
+ base=config.rotary_emb_base,
1334
+ scale_base=config.rotary_emb_scale_base,
1335
+ interleaved=config.rotary_emb_interleaved,
1336
+ )
1337
+ # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
1338
+ # uses the head dimension instead of the sequence dimension
1339
+ self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
1340
+
1341
+ self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
1342
+
1343
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1344
+ self.causal = config.causal
1345
+ self.drop = nn.Dropout(config.attn_pdrop)
1346
+ self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1)
1347
+
1348
+ def forward(
1349
+ self,
1350
+ hidden_states: torch.Tensor,
1351
+ attention_mask: Optional[torch.Tensor] = None,
1352
+ position_ids: Optional[torch.LongTensor] = None,
1353
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1354
+ output_attentions: bool = False,
1355
+ use_cache: bool = False,
1356
+ is_padded_inputs: Optional[bool] = True,
1357
+ cu_seqlens: Optional[torch.Tensor] = None,
1358
+ max_seq_len: Optional[int] = None,
1359
+ rope: Optional[torch.Tensor] = None,
1360
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1361
+
1362
+ has_layer_past = past_key_value is not None
1363
+
1364
+ if has_layer_past:
1365
+ past_key_value = past_key_value[0]
1366
+ past_len = past_key_value[1]
1367
+ else:
1368
+ past_len = 0
1369
+
1370
+ qkv = self.Wqkv(hidden_states)
1371
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
1372
+
1373
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
1374
+
1375
+ if self.rotary_emb_dim > 0:
1376
+ if self.rotary_head_dim:
1377
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
1378
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
1379
+
1380
+ if self.rotary_head_dim:
1381
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
1382
+ elif rope is not None:
1383
+ q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
1384
+ q = torch.cat(
1385
+ [q[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens :], rope)], dim=2
1386
+ ).type_as(q)
1387
+ k = torch.cat(
1388
+ [k[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens :], rope)], dim=2
1389
+ ).type_as(q)
1390
+
1391
+ qkv = torch.stack([q, k, v], dim=-2)
1392
+ qkv = rearrange(qkv, "b h s three d -> b s three h d")
1393
+
1394
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1395
+
1396
+ query = query.permute(0, 2, 1, 3)
1397
+ key = key.permute(0, 2, 1, 3)
1398
+ value = value.permute(0, 2, 1, 3)
1399
+ if scaled_dot_product_attention is not None:
1400
+ attn_output = F.scaled_dot_product_attention(
1401
+ query, key, value, attn_mask=attention_mask, dropout_p=self.drop.p, is_causal=False
1402
+ )
1403
+ else:
1404
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1405
+ if attention_mask is not None:
1406
+ attention_scores = attention_scores + attention_mask
1407
+
1408
+ attentions_probs = F.softmax(attention_scores, dim=-1)
1409
+ attentions_probs = self.drop(attentions_probs)
1410
+
1411
+ attn_output = torch.matmul(attentions_probs, value)
1412
+
1413
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1414
+
1415
+ attn_output = self.out_proj(attn_output)
1416
+
1417
+ return attn_output
1418
+
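# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# A minimal check of the two code paths in NomicBertAttention.forward above:
# F.scaled_dot_product_attention and the explicit matmul/softmax fallback give
# the same result (up to numerical tolerance) for the same additive mask.
# Shapes are (batch, nheads, seqlen, headdim); the helper name is hypothetical.
def _example_sdpa_matches_eager():
    torch.manual_seed(0)
    b, h, s, d = 2, 4, 8, 16
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
    mask = torch.zeros(b, 1, 1, s)      # additive mask: 0 = keep, -inf = drop
    mask[:, :, :, -2:] = float("-inf")  # pretend the last two tokens are padding
    out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    scores = q @ k.transpose(-1, -2) / (d ** 0.5) + mask
    out_eager = torch.softmax(scores, dim=-1) @ v
    assert torch.allclose(out_sdpa, out_eager, atol=1e-5)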
1419
+
1420
+ class NomicBertBlock(NomicBertPreTrainedModel):
1421
+ def __init__(
1422
+ self,
1423
+ config,
1424
+ ):
1425
+ super().__init__(config=config)
1426
+ self.prenorm = config.prenorm
1427
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
1428
+
1429
+ self.attn = NomicBertAttention(config)
1430
+ activation = (
1431
+ F.sigmoid
1432
+ if config.activation_function == "glu"
1433
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
1434
+ )
1435
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
1436
+ self.mlp = NomciBertGatedMLP(
1437
+ config.n_embd,
1438
+ hidden_features=config.n_inner,
1439
+ bias1=config.mlp_fc1_bias,
1440
+ bias2=config.mlp_fc2_bias,
1441
+ activation=activation,
1442
+ fused_bias_fc=config.fused_bias_fc,
1443
+ norm_layer=getattr(config, "norm_mlp", False),
1444
+ )
1445
+ else:
1446
+ self.mlp = NomicBertMLP(
1447
+ config.n_embd,
1448
+ hidden_features=config.n_inner,
1449
+ bias1=config.mlp_fc1_bias,
1450
+ bias2=config.mlp_fc2_bias,
1451
+ activation=activation,
1452
+ fused_bias_fc=config.fused_bias_fc,
1453
+ )
1454
+
1455
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
1456
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1457
+ self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1458
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
1459
+
1460
+ def forward(
1461
+ self,
1462
+ hidden_states: torch.Tensor,
1463
+ hidden_states2: torch.Tensor,
1464
+ residual: Optional[torch.Tensor] = None,
1465
+ attention_mask: Optional[torch.Tensor] = None,
1466
+ position_ids: Optional[torch.LongTensor] = None,
1467
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1468
+ is_padded_inputs: Optional[bool] = True,
1469
+ output_attentions: Optional[bool] = False,
1470
+ use_cache: Optional[bool] = False,
1471
+ cu_seqlens: Optional[torch.Tensor] = None,
1472
+ max_seq_len: Optional[int] = None,
1473
+ rope: Optional[torch.Tensor] = None,
1474
+ ):
1475
+ r"""Pass the input through the encoder layer.
1476
+
1477
+ Args:
1478
+ hidden_states: the sequence to the encoder layer (required).
1479
+ residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual))
1480
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
1481
+ before applying the query projection. Useful for e.g., ViT where we only care
1482
+ about the CLS token in the last layer.
1483
+ """
1484
+ if self.prenorm:
1485
+ dropped = self.dropout1(hidden_states)
1486
+ residual = (dropped + residual) if residual is not None else dropped
1487
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
1488
+ hidden_states = self.attn(
1489
+ hidden_states,
1490
+ attention_mask=attention_mask,
1491
+ is_padded_inputs=is_padded_inputs,
1492
+ cu_seqlens=cu_seqlens,
1493
+ max_seq_len=max_seq_len,
1494
+ rope=rope,
1495
+ )
1496
+
1497
+ dropped = self.dropout2(hidden_states)
1498
+ residual = (dropped + residual) if residual is not None else dropped
1499
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
1500
+ hidden_states = self.mlp(hidden_states)
1501
+
1502
+ return hidden_states, None, residual
1503
+ else:
1504
+ assert residual is None
1505
+ attn_outputs = self.attn(
1506
+ hidden_states,
1507
+ attention_mask=attention_mask,
1508
+ is_padded_inputs=is_padded_inputs,
1509
+ cu_seqlens=cu_seqlens,
1510
+ max_seq_len=max_seq_len,
1511
+ rope=rope,
1512
+ )
1513
+ hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
1514
+ mlp_out = self.mlp(hidden_states)
1515
+
1516
+ hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
1517
+ return hidden_states, None, None
1518
+
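# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# The two branches of NomicBertBlock.forward above differ only in where the
# LayerNorm sits relative to the residual stream. `sublayer` stands in for
# either the attention or the MLP; these helper names are hypothetical.
def _postnorm_step(x, sublayer, norm, dropout):
    # prenorm=False (classic BERT layout): apply the sublayer, add, then normalize
    return norm(x + dropout(sublayer(x)))

def _prenorm_step(x, residual, sublayer, norm, dropout):
    # prenorm=True: an un-normalized residual is threaded through the encoder,
    # and each sublayer only ever sees a normalized copy of it
    residual = residual + dropout(x) if residual is not None else dropout(x)
    return sublayer(norm(residual)), residual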
1519
+
1520
+ class NomicBertEncoder(nn.Module):
1521
+ def __init__(self, config: GPT2Config):
1522
+ super().__init__()
1523
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
1524
+ self.gradient_checkpointing = False
1525
+ self.config = config
1526
+
1527
+ def forward(
1528
+ self,
1529
+ hidden_states: torch.LongTensor = None,
1530
+ attention_mask: Optional[torch.Tensor] = None,
1531
+ position_ids: Optional[torch.LongTensor] = None,
1532
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1533
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1534
+ use_cache: Optional[bool] = None,
1535
+ output_attentions: Optional[bool] = None,
1536
+ output_hidden_states: Optional[bool] = None,
1537
+ return_dict: Optional[bool] = None,
1538
+ is_padded_inputs: Optional[bool] = True,
1539
+ rope: Optional[torch.Tensor] = None,
1540
+ ):
1541
+ """If subset_mask is not None, we only want output for the subset of the sequence.
1542
+ This means that we only compute the last layer output for these tokens.
1543
+ subset_mask: (batch, seqlen), dtype=torch.bool
1544
+ """
1545
+ hidden_states2 = None
1546
+ residual = None
1547
+
1548
+ for _, layer in enumerate(self.layers):
1549
+ if self.gradient_checkpointing and self.training:
1550
+
1551
+ def create_custom_forward(module):
1552
+ def custom_forward(*inputs):
1553
+ # None for past_key_value
1554
+ return module(*inputs)
1555
+
1556
+ return custom_forward
1557
+
1558
+ hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
1559
+ create_custom_forward(layer),
1560
+ hidden_states,
1561
+ hidden_states2,
1562
+ residual,
1563
+ attention_mask,
1564
+ position_ids,
1565
+ past_key_values,
1566
+ is_padded_inputs,
1567
+ output_attentions,
1568
+ use_cache,
1569
+ None,
1570
+ None,
1571
+ rope,
1572
+ # if you freeze ANY layers, you need `use_reentrant=False`
1573
+ # https://github.com/huggingface/transformers/issues/21381
1574
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
1575
+ use_reentrant=False,
1576
+ )
1577
+
1578
+ else:
1579
+ hidden_states, hidden_states2, residual = layer(
1580
+ hidden_states,
1581
+ hidden_states2,
1582
+ residual,
1583
+ attention_mask,
1584
+ position_ids,
1585
+ None,
1586
+ is_padded_inputs,
1587
+ output_attentions,
1588
+ use_cache,
1589
+ rope=rope,
1590
+ )
1591
+ return hidden_states
1592
+
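# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# Why use_reentrant=False is passed above: with the default reentrant
# checkpointing, a checkpointed layer whose inputs do not require grad (for
# example because earlier layers are frozen) can end up with no gradients at
# all, which is the failure mode the linked issues describe. A minimal
# reproduction of the working configuration, with hypothetical modules:
def _example_checkpoint_with_frozen_prefix():
    frozen = nn.Linear(8, 8).requires_grad_(False)
    trainable = nn.Linear(8, 8)
    x = torch.randn(2, 8)
    h = frozen(x)  # h.requires_grad is False because everything upstream is frozen
    out = torch.utils.checkpoint.checkpoint(trainable, h, use_reentrant=False)
    out.sum().backward()  # gradients still reach `trainable`'s parameters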
1593
+
1594
+ class NomicBertPooler(nn.Module):
1595
+ def __init__(self, config):
1596
+ super().__init__()
1597
+ self.dense = nn.Linear(config.n_embd, config.n_embd)
1598
+ self.activation = nn.Tanh()
1599
+
1600
+ def forward(self, hidden_states, pool=True):
1601
+ # We "pool" the model by simply taking the hidden state corresponding
1602
+ # to the first token.
1603
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
1604
+ pooled_output = self.dense(first_token_tensor)
1605
+ pooled_output = self.activation(pooled_output)
1606
+ return pooled_output
1607
+
1608
+
1609
+ class NomicBertPredictionHeadTransform(nn.Module):
1610
+ def __init__(self, config):
1611
+ super().__init__()
1612
+ self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
1613
+ approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
1614
+ if config.activation_function == "swiglu":
1615
+ self.transform_act_fn = F.silu
1616
+ else:
1617
+ self.transform_act_fn = nn.GELU(approximate=approximate)
1618
+
1619
+ self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1620
+
1621
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1622
+ hidden_states = self.dense(hidden_states)
1623
+ hidden_states = self.transform_act_fn(hidden_states)
1624
+ hidden_states = self.layer_norm(hidden_states)
1625
+
1626
+ return hidden_states
1627
+
1628
+
1629
+ class NomicBertLMPredictionHead(nn.Module):
1630
+ def __init__(self, config):
1631
+ super().__init__()
1632
+
1633
+ self.transform = NomicBertPredictionHeadTransform(config)
1634
+
1635
+ self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
1636
+
1637
+ def forward(self, hidden_states):
1638
+ hidden_states = self.transform(hidden_states)
1639
+ hidden_states = self.decoder(hidden_states)
1640
+ return hidden_states
1641
+
1642
+
1643
+ class NomicBertPreTrainingHeads(nn.Module):
1644
+ def __init__(self, config):
1645
+ super().__init__()
1646
+ self.predictions = NomicBertLMPredictionHead(config)
1647
+
1648
+ def forward(self, sequence_output):
1649
+ prediction_scores = self.predictions(sequence_output)
1650
+ return prediction_scores
1651
+
1652
+
1653
+ class NomicBertModel(NomicBertPreTrainedModel):
1654
+ def __init__(self, config: GPT2Config, add_pooling_layer=True):
1655
+ super().__init__(config)
1656
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1657
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
1658
+ config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
1659
+
1660
+ assert config.activation_function in [
1661
+ "gelu",
1662
+ "gelu_new",
1663
+ "gelu_fast",
1664
+ "gelu_pytorch_tanh",
1665
+ "swiglu",
1666
+ "geglu",
1667
+ "glu",
1668
+ ]
1669
+
1670
+ self.embeddings = NomicBertEmbeddings(config)
1671
+ self.emb_drop = nn.Dropout(config.resid_pdrop)
1672
+ self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1673
+ self.encoder = NomicBertEncoder(config)
1674
+ self.pooler = NomicBertPooler(config) if add_pooling_layer else None
1675
+
1676
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1677
+
1678
+ def forward(
1679
+ self,
1680
+ input_ids=None,
1681
+ attention_mask=None,
1682
+ position_ids=None,
1683
+ token_type_ids=None,
1684
+ return_dict=None,
1685
+ matryoshka_dim=None,
1686
+ inputs_embeds=None,
1687
+ ):
1688
+ if input_ids is not None and inputs_embeds is not None:
1689
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1690
+ hidden_states = self.embeddings(
1691
+ input_ids=input_ids,
1692
+ position_ids=position_ids,
1693
+ token_type_ids=token_type_ids,
1694
+ inputs_embeds=inputs_embeds,
1695
+ )
1696
+ hidden_states = self.emb_ln(hidden_states)
1697
+ hidden_states = self.emb_drop(hidden_states)
1698
+
1699
+ attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1])
1700
+ sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1701
+
1702
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1703
+
1704
+ if matryoshka_dim:
1705
+ sequence_output = sequence_output[..., :matryoshka_dim]  # truncate the embedding (last) dimension
1706
+
1707
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1708
+ last_hidden_state=sequence_output,
1709
+ pooler_output=pooled_output,
1710
+ )
1711
+
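# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# What the inputs_embeds support added by this PR enables: calling the model
# with pre-computed token embeddings instead of input_ids, e.g. to perturb or
# interpolate embeddings. `model` and `tokenizer` are assumed to be a loaded
# NomicBertModel and its tokenizer; the helper name is hypothetical.
def _example_inputs_embeds(model, tokenizer, texts):
    enc = tokenizer(texts, padding=True, return_tensors="pt")
    embeds = model.embeddings.word_embeddings(enc["input_ids"])
    out = model(inputs_embeds=embeds, attention_mask=enc["attention_mask"])
    return out.last_hidden_state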
1712
+
1713
+ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1714
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1715
+
1716
+ def __init__(self, config: GPT2Config):
1717
+ super().__init__(config)
1718
+
1719
+ self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
1720
+ self.cls = NomicBertPreTrainingHeads(config)
1721
+ self.mlm_loss = nn.CrossEntropyLoss()
1722
+
1723
+ # Initialize weights and apply final processing
1724
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1725
+ self.tie_weights()
1726
+
1727
+ def tie_weights(self):
1728
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
1729
+
1730
+ def forward(
1731
+ self,
1732
+ input_ids,
1733
+ position_ids=None,
1734
+ token_type_ids=None,
1735
+ attention_mask=None,
1736
+ labels=None,
1737
+ ):
1738
+ """
1739
+ If `labels` is provided, positions that should not contribute to the loss must be set to -100
1740
+ (the default ignore_index of nn.CrossEntropyLoss); for MLM these are all non-masked tokens.
1741
+ Outputs:
1742
+ a MaskedLMOutput whose `loss` is the masked language modeling loss (when `labels` is given)
1743
+ and whose `logits` have shape [batch_size, sequence_length, vocab_size].
1749
+
1750
+ """
1751
+ outputs = self.bert(
1752
+ input_ids,
1753
+ position_ids=position_ids,
1754
+ token_type_ids=token_type_ids,
1755
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1756
+ )
1757
+ sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
1758
+
1759
+ prediction_scores = self.cls(sequence_output)
1760
+
1761
+ total_loss = None
1762
+ if labels is not None:
1763
+ masked_lm_loss = self.mlm_loss(
1764
+ rearrange(prediction_scores, "... v -> (...) v"),
1765
+ rearrange(labels, "... -> (...)"),
1766
+ )
1767
+ total_loss = masked_lm_loss.float()
1768
+
1769
+ return MaskedLMOutput(
1770
+ loss=total_loss,
1771
+ logits=prediction_scores,
1772
+ hidden_states=outputs.hidden_states,
1773
+ attentions=None,
1774
+ )
1775
+
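# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# How the labels consumed by NomicBertForPreTraining.forward above are usually
# built: clone the input ids, corrupt a random subset of positions, and set
# every position that should not be scored to -100 (the default ignore_index
# of nn.CrossEntropyLoss). This is a simplified 100%-[MASK] scheme rather than
# the usual 80/10/10 recipe; mask_token_id is assumed to come from the tokenizer.
def _example_mlm_labels(input_ids, mask_token_id, mlm_probability=0.15):
    labels = input_ids.clone()
    masked = torch.rand_like(input_ids, dtype=torch.float) < mlm_probability
    labels[~masked] = -100              # only masked positions contribute to the loss
    corrupted = input_ids.clone()
    corrupted[masked] = mask_token_id   # replace masked positions with [MASK]
    return corrupted, labels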
1776
+
1777
+ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1778
+ def __init__(self, config):
1779
+ super().__init__(config)
1780
+ self.num_labels = config.num_labels
1781
+ self.config = config
1782
+
1783
+ self.bert = NomicBertModel(config)
1784
+ classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1785
+ self.dropout = nn.Dropout(classifier_dropout)
1786
+ self.classifier = nn.Linear(config.n_embd, config.num_labels)
1787
+
1788
+ # Initialize weights and apply final processing
1789
+ self.post_init()
1790
+
1791
+ def forward(
1792
+ self,
1793
+ input_ids: Optional[torch.Tensor] = None,
1794
+ attention_mask: Optional[torch.Tensor] = None,
1795
+ token_type_ids: Optional[torch.Tensor] = None,
1796
+ position_ids: Optional[torch.Tensor] = None,
1797
+ head_mask: Optional[torch.Tensor] = None,
1798
+ inputs_embeds: Optional[torch.Tensor] = None,
1799
+ labels: Optional[torch.Tensor] = None,
1800
+ output_attentions: Optional[bool] = None,
1801
+ output_hidden_states: Optional[bool] = None,
1802
+ return_dict: Optional[bool] = None,
1803
+ ):
1804
+ r"""
1805
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1806
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1807
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1808
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1809
+ """
1810
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1811
+ outputs = self.bert(
1812
+ input_ids,
1813
+ position_ids=position_ids,
1814
+ token_type_ids=token_type_ids,
1815
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1816
+ )
1817
+
1818
+ pooled_output = outputs[1]
1819
+
1820
+ pooled_output = self.dropout(pooled_output)
1821
+ logits = self.classifier(pooled_output)
1822
+
1823
+ loss = None
1824
+ if labels is not None:
1825
+ if self.config.problem_type is None:
1826
+ if self.num_labels == 1:
1827
+ self.config.problem_type = "regression"
1828
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1829
+ self.config.problem_type = "single_label_classification"
1830
+ else:
1831
+ self.config.problem_type = "multi_label_classification"
1832
+
1833
+ if self.config.problem_type == "regression":
1834
+ loss_fct = nn.MSELoss()
1835
+ if self.num_labels == 1:
1836
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1837
+ else:
1838
+ loss = loss_fct(logits, labels)
1839
+ elif self.config.problem_type == "single_label_classification":
1840
+ loss_fct = nn.CrossEntropyLoss()
1841
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1842
+ elif self.config.problem_type == "multi_label_classification":
1843
+ loss_fct = nn.BCEWithLogitsLoss()
1844
+ loss = loss_fct(logits, labels)
1845
+ if not return_dict:
1846
+ output = (logits,) + outputs[2:]
1847
+ return ((loss,) + output) if loss is not None else output
1848
+
1849
+ return SequenceClassifierOutput(
1850
+ loss=loss,
1851
+ logits=logits,
1852
+ hidden_states=outputs.hidden_states,
1853
+ attentions=outputs.attentions,
1854
+ )
1855
+
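# --- Editor's note: illustrative summary, not part of the upstream diff. ---
# How the classification head above picks its loss when config.problem_type is
# not set explicitly:
#
#   num_labels == 1                     -> "regression"                  (MSELoss)
#   num_labels > 1 and integer labels   -> "single_label_classification" (CrossEntropyLoss)
#   num_labels > 1 and float labels     -> "multi_label_classification"  (BCEWithLogitsLoss)
#
# Setting problem_type in the config up front avoids depending on the dtype of
# the labels tensor at runtime.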
1856
+
1857
+ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1858
+ return GPT2Config(
1859
+ n_embd=vit_config.hidden_size,
1860
+ n_layer=vit_config.num_hidden_layers,
1861
+ n_head=vit_config.num_attention_heads,
1862
+ n_inner=vit_config.intermediate_size,
1863
+ activation_function=vit_config.hidden_act,
1864
+ vocab_size=0, # no vocab since using patches
1865
+ n_positions=0, # No absolute position embedding
1866
+ resid_pdrop=0.0, # No dropout
1867
+ embd_pdrop=getattr(vit_config, "dropout", 0.0),
1868
+ attn_pdrop=vit_config.attention_probs_dropout_prob,
1869
+ layer_norm_epsilon=vit_config.layer_norm_eps,
1870
+ initializer_range=vit_config.initializer_range,
1871
+ bos_token_id=None,
1872
+ eos_token_id=None,
1873
+ # These are new arguments not in the original GPT2Config
1874
+ drop_path_rate=0.0,
1875
+ # Why is there double layer norm??
1876
+ prepre_layernom=False,
1877
+ layer_scale=False,
1878
+ layer_scale_init=None,
1879
+ img_size=vit_config.image_size,
1880
+ patch_size=vit_config.patch_size,
1881
+ num_channels=vit_config.num_channels,
1882
+ prenorm=True,
1883
  parallel_block=False,
1884
  parallel_block_tied_norm=False,
1885
+ rotary_emb_fraction=0,
1886
+ tie_word_embeddings=False,
1887
+ fused_dropout_add_ln=True,
1888
+ fused_bias_fc=True,
1889
+ patch_embed_bias=True,
1890
+ use_flash_attn=True,
1891
  qkv_proj_bias=True,
1892
+ mlp_fc1_bias=getattr(vit_config, "mlp_fc1_bias", True),
1893
+ mlp_fc2_bias=getattr(vit_config, "mlp_fc2_bias", True),
 
 
 
1894
  use_rms_norm=False,
1895
  causal=False,
1896
+ hidden_features_scaling_factor=1.0,
1897
+ mask_token=False,
1898
+ learned_pos_embedding=False,
1899
+ patch_dropout=0,
1900
+ sinusoidal_pos_embedding=vit_config.model_type == "vit_mae",
1901
+ )
1902
+
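# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# Typical use of the converter above: start from a Hugging Face ViTConfig and
# obtain the GPT2Config-shaped config the Nomic vision blocks expect. The
# ViT-Base sizes below are only an example.
def _example_vit_conversion():
    vit_cfg = ViTConfig(hidden_size=768, num_hidden_layers=12,
                        num_attention_heads=12, intermediate_size=3072)
    nomic_cfg = hf_vit_config_to_vit_config(vit_cfg)
    assert nomic_cfg.n_embd == vit_cfg.hidden_size
    return nomic_cfg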
1903
+
1904
+ class NomicAttentionPooling(nn.Module):
1905
+ def __init__(self, config):
1906
+ super().__init__()
1907
+ self.embed_dim = config.n_embd
1908
+ self.use_flash_attn = config.use_flash_attn
1909
+ self.fused_bias_fc = config.fused_bias_fc
1910
+
1911
+ self.num_heads = config.n_head
1912
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
1913
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
1914
+ self.head_dim = self.embed_dim // self.num_heads
1915
+ # we don't really support mqa / gqa for now
1916
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
1917
+
1918
+ self.register_buffer(
1919
+ "norm_factor",
1920
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
1921
+ persistent=False,
1922
+ )
1923
+
1924
+ self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1925
+ self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
1926
+
1927
+ self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
1928
+
1929
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1930
+ self.causal = config.causal
1931
+ self.drop = nn.Dropout(config.attn_pdrop)
1932
+
1933
+ def init_weights(self):
1934
+ trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5)
1935
+
1936
+ def forward(
1937
+ self,
1938
+ kv,
1939
+ attention_mask=None,
1940
+ cu_seqlens_k=None,
1941
+ max_seqlen_k=None,
1942
+ is_padded_inputs: Optional[bool] = True,
1943
+ output_attentions: bool = False,
1944
+ ):
1945
+ """Implements the multihead softmax attention.
1946
+ Arguments
1947
+ ---------
1948
+ q: The tensor containing the query. (B, Sq, H, D)
1949
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
1950
+ causal: if passed, will override self.causal
1951
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1952
+ of the sequences in the batch, used to index into q.
1953
+ max_seqlen: int. Maximum sequence length in the batch of q.
1954
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1955
+ of the sequences in the batch, used to index into kv.
1956
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
1957
+ """
1958
+ q_latent = self.latent.expand(kv.size(0), -1, -1)
1959
+ q = self.Wq(q_latent)
1960
+ bsz, q_len, h_size = q.shape
1961
+ kv = self.Wkv(kv)
1962
+ query = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
1963
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
1964
+
1965
+ key, value = kv[:, :, 0], kv[:, :, 1]
1966
+
1967
+ query = query.permute(0, 2, 1, 3)
1968
+ key = key.permute(0, 2, 1, 3)
1969
+ value = value.permute(0, 2, 1, 3)
1970
+
1971
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1972
+ if attention_mask is not None:
1973
+ attention_scores = attention_scores + attention_mask
1974
+
1975
+ attentions_probs = F.softmax(attention_scores, dim=-1)
1976
+ attentions_probs = self.drop(attentions_probs)
1977
+
1978
+ attn_output = torch.matmul(attentions_probs, value)
1979
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1980
+
1981
+ attn_output = self.out_proj(attn_output)
1982
+
1983
+ return attn_output
1984
+
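# --- Editor's note: illustrative shape walk-through, not part of the upstream diff. ---
# NomicAttentionPooling above uses a single learned latent vector as the only
# query, so for a (batch, seqlen, embed_dim) input it pools to one "token":
#
#   latent             : (1, 1, embed_dim) -> expand -> (batch, 1, embed_dim)
#   query = Wq(latent) : (batch, 1, embed_dim) -> (batch, nheads, 1, head_dim)
#   key, value         : Wkv(kv) -> (batch, num_heads_kv, seqlen, head_dim) each
#   scores             : (batch, nheads, 1, seqlen)
#   out_proj output    : (batch, 1, embed_dim)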
1985
+
1986
+ class NomicMultiHeadAttentionPooling(nn.Module):
1987
+ def __init__(
1988
+ self,
1989
+ config,
1990
+ ):
1991
+ super().__init__()
1992
+ self.prenorm = config.prenorm
1993
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
1994
+
1995
+ self.attn = NomicAttentionPooling(config)
1996
+ activation = (
1997
+ F.sigmoid
1998
+ if config.activation_function == "glu"
1999
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
2000
+ )
2001
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
2002
+ self.mlp = NomciBertGatedMLP(
2003
+ config.n_embd,
2004
+ hidden_features=config.n_inner,
2005
+ bias1=config.mlp_fc1_bias,
2006
+ bias2=config.mlp_fc2_bias,
2007
+ activation=activation,
2008
+ fused_bias_fc=config.fused_bias_fc,
2009
+ )
2010
+ else:
2011
+ self.mlp = NomicBertMLP(
2012
+ config.n_embd,
2013
+ hidden_features=config.n_inner,
2014
+ bias1=config.mlp_fc1_bias,
2015
+ bias2=config.mlp_fc2_bias,
2016
+ activation=activation,
2017
+ fused_bias_fc=config.fused_bias_fc,
2018
+ )
2019
+
2020
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
2021
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
2022
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
2023
+
2024
+ def forward(
2025
+ self,
2026
+ hidden_states: torch.Tensor,
2027
+ attention_mask: Optional[torch.Tensor] = None,
2028
+ ):
2029
+ r"""Pass the input through the encoder layer.
2030
+
2031
+ Args:
2032
+ hidden_states: the sequence to the encoder layer (required).
2033
+ residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual))
2034
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
2035
+ before applying the query projection. Useful for e.g., ViT where we only care
2036
+ about the CLS token in the last layer.
2037
+ """
2038
+
2039
+ attn_outputs = self.attn(
2040
+ hidden_states,
2041
+ attention_mask=attention_mask,
2042
+ )
2043
+
2044
+ normed = self.norm1(attn_outputs)
2045
+ hidden_states = hidden_states + self.mlp(normed)
2046
+
2047
+ return hidden_states
2048
+
2049
+
2050
+ class NomicVisionPreTrainedModel(PreTrainedModel):
2051
+ """An abstract class to handle weights initialization and
2052
+ a simple interface for dowloading and loading pretrained models.
2053
+ """
2054
+
2055
+ config_class = NomicBertConfig
2056
+ base_model_prefix = "model"
2057
+ supports_gradient_checkpointing = True
2058
+ _no_split_modules = ["Block"]
2059
+ _skip_keys_device_placement = "past_key_values"
2060
+
2061
+ def __init__(self, config, *inputs, **kwargs):
2062
+ super().__init__(config)
2063
+ if not isinstance(config, GPT2Config):
2064
+ raise ValueError(
2065
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
2066
+ "To create a model from a Google pretrained model use "
2067
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
2068
+ self.__class__.__name__, self.__class__.__name__
2069
+ )
2070
+ )
2071
+ self.config = config
2072
+
2073
+
2074
+ class NomicVisionModel(NomicVisionPreTrainedModel):
2075
+ def __init__(self, config):
2076
+ super().__init__(config)
2077
+
2078
+ self.embeddings = NomicVisionPatchEmbeddings(config)
2079
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
2080
+
2081
+ self.selector = NomicMultiHeadAttentionPooling(config)
2082
+
2083
+ self.global_pool = getattr(config, "global_pool", None)
2084
+ self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(
2085
+ config, "register_tokens", 0
2086
+ )
2087
+
2088
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
2089
+
2090
+ def forward(
2091
+ self,
2092
+ pixel_values,
2093
+ attention_mask=None,
2094
+ position_ids=None,
2095
+ token_type_ids=None,
2096
+ return_dict=None,
2097
+ matryoshka_dim=None,
2098
+ ):
2099
+ embeddings, rope = self.embeddings(pixel_values)
2100
+
2101
+ original_dtype = embeddings.dtype
2102
+
2103
+ hidden_states = embeddings
2104
+ # unused, but kept so the layer call signature matches what gradient checkpointing expects
2105
+ residual = None
2106
+ for layer in self.layers:
2107
+ # need to pass None for backwards compatibility
2108
+ hidden_states, _, residual = layer(
2109
+ hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope
2110
+ )
2111
+
2112
+ hidden_states = hidden_states + residual
2113
+ if self.global_pool == "avg":
2114
+ hidden_states = hidden_states[:, self.num_prefix_tokens :].mean(dim=1)
2115
+
2116
+ pooled_output = self.selector(hidden_states)
2117
 
2118
+ return BaseModelOutputWithPast(
2119
+ last_hidden_state=pooled_output,
2120
+ hidden_states=hidden_states,
2121
+ )
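# --- Editor's note: illustrative sketch, not part of the upstream diff. ---
# Minimal end-to-end use of NomicVisionModel, assuming `model` was built from a
# config whose img_size/patch_size match the 224x224 input below; the numbers
# and helper name are only an example.
def _example_vision_forward(model):
    pixel_values = torch.randn(2, 3, 224, 224)
    out = model(pixel_values)
    # last_hidden_state holds the attention-pooled representation from the
    # selector; hidden_states holds the per-patch features it was pooled from.
    return out.last_hidden_state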