✨ gradient checkpointing
Browse files- modeling_mpt.py +4 -1
modeling_mpt.py
CHANGED
@@ -33,7 +33,10 @@ class MPTPreTrainedModel(PreTrainedModel):
|
|
33 |
base_model_prefix = "model"
|
34 |
supports_gradient_checkpointing = True
|
35 |
_no_split_modules = ["MPTBlock"]
|
36 |
-
|
|
|
|
|
|
|
37 |
|
38 |
class MPTModel(MPTPreTrainedModel):
|
39 |
def __init__(self, config: MPTConfig):
|
|
|
33 |
base_model_prefix = "model"
|
34 |
supports_gradient_checkpointing = True
|
35 |
_no_split_modules = ["MPTBlock"]
|
36 |
+
|
37 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
38 |
+
if isinstance(module, MPTModel):
|
39 |
+
module.gradient_checkpointing = value
|
40 |
|
41 |
class MPTModel(MPTPreTrainedModel):
|
42 |
def __init__(self, config: MPTConfig):
|