Update modeling_norbert.py
modeling_norbert.py  (+13 -13)
@@ -277,12 +277,12 @@ class NorbertPreTrainedModel(PreTrainedModel):


 class NorbertModel(NorbertPreTrainedModel):
-    def __init__(self, config, add_mlm_layer=False):
-        super().__init__(config)
+    def __init__(self, config, add_mlm_layer=False, gradient_checkpointing=False, **kwargs):
+        super().__init__(config, **kwargs)
         self.config = config

         self.embedding = Embedding(config)
-        self.transformer = Encoder(config, activation_checkpointing=False)
+        self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
         self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None

     def get_input_embeddings(self):

@@ -352,8 +352,8 @@ class NorbertModel(NorbertPreTrainedModel):
 class NorbertForMaskedLM(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["head"]

-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=True)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=True, **kwargs)

     def get_output_embeddings(self):
         return self.classifier.nonlinearity[-1].weight

@@ -432,8 +432,8 @@ class NorbertForSequenceClassification(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]

-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)

         self.num_labels = config.num_labels
         self.head = Classifier(config, self.num_labels)

@@ -498,8 +498,8 @@ class NorbertForTokenClassification(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]

-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)

         self.num_labels = config.num_labels
         self.head = Classifier(config, self.num_labels)

@@ -546,8 +546,8 @@ class NorbertForQuestionAnswering(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]

-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)

         self.num_labels = config.num_labels
         self.head = Classifier(config, self.num_labels)

@@ -614,8 +614,8 @@ class NorbertForMultipleChoice(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]

-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)

         self.num_labels = getattr(config, "num_labels", 2)
         self.head = Classifier(config, self.num_labels)
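Net effect of the change: every task-specific constructor now accepts **kwargs and forwards them to NorbertModel, whose new gradient_checkpointing flag (default False, so existing behavior is unchanged) is handed to the Encoder as activation_checkpointing. A minimal usage sketch, assuming this file is loaded alongside its checkpoint and using ltg/norbert3-base as an example repository:

    from transformers import AutoConfig
    from modeling_norbert import NorbertForSequenceClassification

    # Example checkpoint; any NorBERT-3 repository that ships this
    # modeling file should behave the same way (assumption for illustration).
    config = AutoConfig.from_pretrained("ltg/norbert3-base", trust_remote_code=True)

    # gradient_checkpointing travels through **kwargs down to
    # NorbertModel.__init__, which passes it to Encoder as
    # activation_checkpointing, trading extra compute for lower
    # activation memory during training.
    model = NorbertForSequenceClassification(config, gradient_checkpointing=True)

The **kwargs threading also keeps the constructors tolerant of callers (such as from_pretrained in some transformers versions) that pass extra keyword arguments through to the model's __init__.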