lhallee commited on
Commit
3099b74
·
verified ·
1 Parent(s): 60cc6d7

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +19 -3
modeling_esm_plusplus.py CHANGED
@@ -339,9 +339,7 @@ class MultiHeadAttention(nn.Module):
339
 
340
 
341
  ### Regression Head
342
- def RegressionHead(
343
- d_model: int, output_dim: int, hidden_dim: Optional[int] = None
344
- ) -> nn.Module:
345
  """Create a regression head with optional hidden dimension.
346
 
347
  Args:
@@ -707,6 +705,12 @@ class ESMplusplusModel(PreTrainedESMplusplusModel):
707
  self.tokenizer = EsmSequenceTokenizer()
708
  self.init_weights()
709
 
 
 
 
 
 
 
710
  def forward(
711
  self,
712
  input_ids: Optional[torch.Tensor] = None,
@@ -752,6 +756,18 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
752
  self.tokenizer = EsmSequenceTokenizer()
753
  self.init_weights()
754
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  def forward(
756
  self,
757
  input_ids: Optional[torch.Tensor] = None,
 
339
 
340
 
341
  ### Regression Head
342
+ def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
 
 
343
  """Create a regression head with optional hidden dimension.
344
 
345
  Args:
 
705
  self.tokenizer = EsmSequenceTokenizer()
706
  self.init_weights()
707
 
708
+ def get_input_embeddings(self):
709
+ return self.embed
710
+
711
+ def set_input_embeddings(self, value):
712
+ self.embed = value
713
+
714
  def forward(
715
  self,
716
  input_ids: Optional[torch.Tensor] = None,
 
756
  self.tokenizer = EsmSequenceTokenizer()
757
  self.init_weights()
758
 
759
+ def get_input_embeddings(self):
760
+ return self.embed
761
+
762
+ def set_input_embeddings(self, value):
763
+ self.embed = value
764
+
765
+ def get_output_embeddings(self):
766
+ return self.sequence_head[-1]
767
+
768
+ def set_output_embeddings(self, new_embeddings):
769
+ self.sequence_head[-1] = new_embeddings
770
+
771
  def forward(
772
  self,
773
  input_ids: Optional[torch.Tensor] = None,