nicolinho commited on
Commit
8332a70
·
verified ·
1 Parent(s): c1ec305

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +5 -0
modeling_custom.py CHANGED
@@ -148,6 +148,11 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
148
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
149
  """
150
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
 
 
 
151
  transformer_outputs = self.model(
152
  input_ids,
153
  attention_mask=attention_mask,
 
148
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
149
  """
150
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
151
+ if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,0] == 2:
152
+ input_ids = input_ids[:, 1:]
153
+ if attention_mask is not None:
154
+ attention_mask = attention_mask[:, 1:]
155
+
156
  transformer_outputs = self.model(
157
  input_ids,
158
  attention_mask=attention_mask,