Upload 7 files
Browse files- modeling_indictrans.py +5 -0
modeling_indictrans.py
CHANGED
@@ -61,6 +61,11 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
|
|
61 |
|
62 |
|
63 |
def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
|
|
|
|
|
|
|
|
|
|
64 |
mask = (decoder_input_ids == eos_token_id)
|
65 |
decoder_input_ids[mask] = 1
|
66 |
decoder_attention_mask[mask] = 0
|
|
|
61 |
|
62 |
|
63 |
def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
64 |
+
labels = decoder_input_ids[:, 1:]
|
65 |
+
|
66 |
+
labels_mask = labels == 1
|
67 |
+
labels[labels_mask] = -100
|
68 |
+
|
69 |
mask = (decoder_input_ids == eos_token_id)
|
70 |
decoder_input_ids[mask] = 1
|
71 |
decoder_attention_mask[mask] = 0
|