Crystalcareai
commited on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +8 -0
modeling_gemmoe.py
CHANGED
@@ -1205,6 +1205,14 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
|
|
1205 |
)
|
1206 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1208 |
outputs = self.model(
|
1209 |
input_ids=input_ids,
|
1210 |
attention_mask=attention_mask,
|
|
|
1205 |
)
|
1206 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1207 |
|
1208 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1209 |
+
attention_mask = attention_mask.to(device) if attention_mask is not None else None
|
1210 |
+
position_ids = position_ids.to(device) if position_ids is not None else None
|
1211 |
+
past_key_values = [t.to(device) for t in past_key_values] if past_key_values is not None else None
|
1212 |
+
inputs_embeds = inputs_embeds.to(device) if inputs_embeds is not None else None
|
1213 |
+
labels = labels.to(device) if labels is not None else None
|
1214 |
+
cache_position = cache_position.to(device) if cache_position is not None else None
|
1215 |
+
|
1216 |
outputs = self.model(
|
1217 |
input_ids=input_ids,
|
1218 |
attention_mask=attention_mask,
|