Update modeling_custom.py
Browse files- modeling_custom.py +5 -0
modeling_custom.py
CHANGED
@@ -148,6 +148,11 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
|
|
148 |
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
149 |
"""
|
150 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
|
151 |
transformer_outputs = self.model(
|
152 |
input_ids,
|
153 |
attention_mask=attention_mask,
|
|
|
148 |
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
149 |
"""
|
150 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
151 |
+
if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,0] == 2:
|
152 |
+
input_ids = input_ids[:, 1:]
|
153 |
+
if attention_mask is not None:
|
154 |
+
attention_mask = attention_mask[:, 1:]
|
155 |
+
|
156 |
transformer_outputs = self.model(
|
157 |
input_ids,
|
158 |
attention_mask=attention_mask,
|