Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED  (+38 -14)
@@ -65,6 +65,10 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "GemmoeConfig"
 
+class GemmoeDistributedDataParallel(nn.parallel.DistributedDataParallel):
+    def __init__(self, model, **kwargs):
+        super().__init__(model, find_unused_parameters=True, **kwargs)
+
 def approx_gelu(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
 
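The new `GemmoeDistributedDataParallel` wrapper exists only to force `find_unused_parameters=True`, which DDP needs when some experts receive no tokens in a step and therefore produce no gradients. A minimal sketch of how the wrapper could be used directly, assuming a single-node `torchrun` launch and the `google/gemmoe-7b` checkpoint named in the docstring added later in this commit (the setup code is illustrative, not part of the file):

```python
import torch
import torch.distributed as dist

# Illustrative setup only: torchrun provides the environment that
# init_process_group reads on each worker.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# GemmoeForCausalLM and GemmoeDistributedDataParallel come from this module.
model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b").to(local_rank)

# find_unused_parameters=True is baked into the subclass, so the DDP reducer
# does not wait on experts that received no tokens during this step.
ddp_model = GemmoeDistributedDataParallel(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
)
```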
@@ -80,6 +84,7 @@ def _get_unpad_data(attention_mask):
     )
 
 
+
 class GemmoeRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -1112,6 +1117,28 @@ class GemmoeModel(GemmoePreTrainedModel):
         return causal_mask
 
 class GemmoeForCausalLM(GemmoePreTrainedModel):
+    r"""
+    The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
+
+    Args:
+        config (GemmoeConfig): The configuration object for the Gemmoe model.
+
+    Example usage:
+    ```python
+    >>> from transformers import AutoTokenizer, GemmoeForCausalLM
+
+    >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```
+    """
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
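The new docstring only covers generation. As a hedged companion sketch (not part of the commit), the same head also returns a loss when `labels` are supplied, which is the path exercised by the `if labels is not None:` branch shown in a later hunk:

```python
from transformers import AutoTokenizer

# Reuses the model and checkpoint name from the docstring example above.
tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
inputs = tokenizer("What is your favorite condiment?", return_tensors="pt")

# Passing labels makes forward() also compute a cross-entropy loss; for causal
# LM training the labels are simply the input ids (the shift happens inside).
outputs = model(**inputs, labels=inputs.input_ids)
print(outputs.loss)
```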
@@ -1126,6 +1153,13 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def parallelize(self, device_map=None):
+        self.model = GemmoeDistributedDataParallel(
+            self.model,
+            device_ids=[torch.cuda.current_device()],
+            output_device=torch.cuda.current_device(),
+        )
+
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
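`parallelize` simply swaps `self.model` for the DDP wrapper defined at the top of the file; in the lines shown here the `device_map` argument is accepted but not used. A rough training-step sketch under the same assumptions as before (illustrative only, with a dummy batch in place of real data):

```python
import torch

# Assumes torch.distributed is already initialized on this process.
model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b").cuda()
model.parallelize()  # self.model is now a GemmoeDistributedDataParallel

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Dummy token ids standing in for one tokenized mini-batch on the GPU.
dummy_ids = torch.randint(0, model.config.vocab_size, (2, 16), device="cuda")

outputs = model(input_ids=dummy_ids, labels=dummy_ids)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```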
@@ -1215,11 +1249,10 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
 
         # Handle unused parameters
         if self.training:
-            for
-            for
-
-
-            param.grad = torch.zeros_like(param)
+            for expert in self.model.layers[-1].block_sparse_moe.experts:
+                for param in expert.parameters():
+                    if param.requires_grad and param.grad is None:
+                        param.grad = torch.zeros_like(param)
 
         loss = None
         if labels is not None:
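The rewritten block gives experts of the last layer's `block_sparse_moe` an explicit zero gradient when routing left them untouched, so code that expects every trainable parameter to have a `.grad` keeps working. A generic sketch of the same pattern (looping over every layer is an illustrative extension; the commit itself only touches the last layer inside `forward`):

```python
import torch

def fill_missing_expert_grads(causal_lm):
    """Give experts that routed no tokens an explicit zero gradient."""
    for layer in causal_lm.model.layers:
        for expert in layer.block_sparse_moe.experts:
            for param in expert.parameters():
                # Experts that saw no tokens this step still have grad is None
                # after backward(); fill with zeros to keep shapes consistent.
                if param.requires_grad and param.grad is None:
                    param.grad = torch.zeros_like(param)
```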
@@ -1328,15 +1361,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
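This hunk drops a duplicated `_reorder_cache`; the surviving copy above is unchanged. A toy illustration (not from the file) of what that reordering does during beam search: each cached key/value tensor is batch-major, and `index_select` along dim 0 re-aligns the cache with the beams that survived the current step.

```python
import torch

# Made-up shapes for illustration: (num_beams, num_heads, seq_len, head_dim).
past_state = torch.randn(4, 2, 5, 8)
beam_idx = torch.tensor([2, 2, 0, 1])  # beams kept after this decoding step

reordered = past_state.index_select(0, beam_idx.to(past_state.device))
assert torch.equal(reordered[0], past_state[2])  # beam 0 now continues beam 2's cache
```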