Update modelling_landmark_llama.py

#4
by viktoroo - opened
Files changed (1) hide show
  1. modelling_landmark_llama.py +4 -2
modelling_landmark_llama.py CHANGED
@@ -32,6 +32,8 @@ from transformers.modeling_utils import PreTrainedModel
32
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
  from transformers.models.llama.configuration_llama import LlamaConfig
34
 
 
 
35
 
36
  logger = logging.get_logger(__name__)
37
 
@@ -565,7 +567,7 @@ LLAMA_START_DOCSTRING = r"""
565
  LLAMA_START_DOCSTRING,
566
  )
567
  class LlamaPreTrainedModel(PreTrainedModel):
568
- config_class = LlamaConfig
569
  base_model_prefix = "model"
570
  supports_gradient_checkpointing = True
571
  _no_split_modules = ["LlamaDecoderLayer"]
@@ -873,7 +875,7 @@ class LlamaModel(LlamaPreTrainedModel):
873
 
874
 
875
  class LlamaForCausalLM(LlamaPreTrainedModel):
876
- def __init__(self, config):
877
  super().__init__(config)
878
  self.model = LlamaModel(config)
879
 
 
32
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
  from transformers.models.llama.configuration_llama import LlamaConfig
34
 
35
+ from .configuration_landmark_llama import LlamaConfig as LandmarkLlamaConfig
36
+
37
 
38
  logger = logging.get_logger(__name__)
39
 
 
567
  LLAMA_START_DOCSTRING,
568
  )
569
  class LlamaPreTrainedModel(PreTrainedModel):
570
+ config_class = LandmarkLlamaConfig
571
  base_model_prefix = "model"
572
  supports_gradient_checkpointing = True
573
  _no_split_modules = ["LlamaDecoderLayer"]
 
875
 
876
 
877
  class LlamaForCausalLM(LlamaPreTrainedModel):
878
+ def __init__(self, config: LandmarkLlamaConfig):
879
  super().__init__(config)
880
  self.model = LlamaModel(config)
881