Crystalcareai commited on
Commit
875b2bf
·
verified ·
1 Parent(s): 42bb1d0

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +8 -0
modeling_gemmoe.py CHANGED
@@ -1205,6 +1205,14 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1205
  )
1206
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1207
 
 
 
 
 
 
 
 
 
1208
  outputs = self.model(
1209
  input_ids=input_ids,
1210
  attention_mask=attention_mask,
 
1205
  )
1206
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1207
 
1208
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1209
+ attention_mask = attention_mask.to(device) if attention_mask is not None else None
1210
+ position_ids = position_ids.to(device) if position_ids is not None else None
1211
+ past_key_values = [t.to(device) for t in past_key_values] if past_key_values is not None else None
1212
+ inputs_embeds = inputs_embeds.to(device) if inputs_embeds is not None else None
1213
+ labels = labels.to(device) if labels is not None else None
1214
+ cache_position = cache_position.to(device) if cache_position is not None else None
1215
+
1216
  outputs = self.model(
1217
  input_ids=input_ids,
1218
  attention_mask=attention_mask,