Crystalcareai committed
Commit 2de5917 · verified · 1 Parent(s): 198cba7

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +38 -14
modeling_gemmoe.py CHANGED
@@ -65,6 +65,10 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "GemmoeConfig"
 
+class GemmoeDistributedDataParallel(nn.parallel.DistributedDataParallel):
+    def __init__(self, model, **kwargs):
+        super().__init__(model, find_unused_parameters=True, **kwargs)
+
 def approx_gelu(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
 
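The new `GemmoeDistributedDataParallel` is a thin DDP subclass that hard-codes `find_unused_parameters=True`, which DDP needs when token routing leaves some experts with no gradients in a given step. The `approx_gelu` shown as context is the standard tanh approximation of GELU; a quick sanity-check sketch (not part of the commit, and assuming PyTorch 1.11+ where `F.gelu` accepts `approximate="tanh"`) is:

```python
# Sanity-check sketch: the tanh approximation used above should match
# PyTorch's built-in approximate GELU. Assumes PyTorch >= 1.11.
import math
import torch
import torch.nn.functional as F

def approx_gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

x = torch.randn(8, 16)
assert torch.allclose(approx_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6)
```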
 
@@ -80,6 +84,7 @@ def _get_unpad_data(attention_mask):
     )
 
 
+
 class GemmoeRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
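The hunk above only inserts a blank line ahead of `GemmoeRMSNorm`; the class body itself is not part of this diff. For orientation, a sketch of the common Gemma-style RMSNorm forward pass follows; it is an assumption about the surrounding code, not a copy of it.

```python
# Sketch of a Gemma-style RMSNorm (assumed; the real GemmoeRMSNorm body is not in this diff).
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero init, effective scale starts at 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root mean square over the last dimension,
        # then scale by (1 + weight) as the Gemma family does.
        norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm_x * (1.0 + self.weight)
```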
@@ -1112,6 +1117,28 @@ class GemmoeModel(GemmoePreTrainedModel):
         return causal_mask
 
 class GemmoeForCausalLM(GemmoePreTrainedModel):
+    r"""
+    The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
+
+    Args:
+        config (GemmoeConfig): The configuration object for the Gemmoe model.
+
+    Example usage:
+    ```python
+    >>> from transformers import AutoTokenizer, GemmoeForCausalLM
+
+    >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```
+    """
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1126,6 +1153,13 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def parallelize(self, device_map=None):
+        self.model = GemmoeDistributedDataParallel(
+            self.model,
+            device_ids=[torch.cuda.current_device()],
+            output_device=torch.cuda.current_device(),
+        )
+
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
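The added `parallelize` wraps the inner `GemmoeModel` in the DDP subclass defined earlier; a distributed process group must already be initialized and the model placed on the current CUDA device before calling it, and the `device_map` argument is accepted but not used. A minimal launch sketch, assuming a torchrun-style environment (the checkpoint name is taken from the docstring example above):

```python
# Sketch of using parallelize() under torchrun; the process-group setup and
# LOCAL_RANK handling are assumptions about the launch environment, not part
# of the commit.
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b").to(local_rank)
model.parallelize()  # self.model is now wrapped in GemmoeDistributedDataParallel
```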
 
@@ -1215,11 +1249,10 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
 
         # Handle unused parameters
         if self.training:
-            for layer in self.model.layers:
-                for expert in layer.block_sparse_moe.experts:
-                    for param in expert.parameters():
-                        if param.requires_grad and param.grad is None:
-                            param.grad = torch.zeros_like(param)
+            for expert in self.model.layers[-1].block_sparse_moe.experts:
+                for param in expert.parameters():
+                    if param.requires_grad and param.grad is None:
+                        param.grad = torch.zeros_like(param)
 
         loss = None
         if labels is not None:
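This hunk narrows the loop from every decoder layer to only the last layer's experts; the intent of the grad zero-fill is unchanged. Experts that receive no tokens in a step never enter the autograd graph, so their `.grad` stays `None`, and filling it with zeros keeps gradient reduction and optimizer steps uniform. A toy illustration (the two-expert setup is hypothetical, not from the model file):

```python
# Toy illustration of why an unrouted expert ends up with param.grad is None,
# and how zero-filling makes it look like it participated in the step.
import torch
import torch.nn as nn

experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
x = torch.randn(3, 4)

# Route every token to expert 0; expert 1 never appears in the graph.
loss = experts[0](x).sum()
loss.backward()

print(experts[0].weight.grad is None)  # False: it received gradients
print(experts[1].weight.grad is None)  # True: it was unused this step

for param in experts[1].parameters():
    if param.requires_grad and param.grad is None:
        param.grad = torch.zeros_like(param)  # same trick as in the hunk above
```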
@@ -1328,15 +1361,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
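The removed block was a verbatim duplicate of the `_reorder_cache` staticmethod defined just above it. For reference, the method gathers each layer's cached key/value states along the beam dimension so the cache follows the surviving beams during beam search; a small sketch with illustrative (assumed) tensor shapes:

```python
# Sketch of what _reorder_cache does for one layer's cached key/value states.
import torch

past_key = torch.randn(4, 8, 10, 64)   # (num_beams, heads, seq_len, head_dim), illustrative shapes
beam_idx = torch.tensor([2, 2, 0, 1])  # beams chosen to survive this decoding step

reordered_key = past_key.index_select(0, beam_idx)
print(reordered_key.shape)  # torch.Size([4, 8, 10, 64])
```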
 