Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +19 -3
modeling_esm_plusplus.py
CHANGED
@@ -339,9 +339,7 @@ class MultiHeadAttention(nn.Module):
|
|
339 |
|
340 |
|
341 |
### Regression Head
|
342 |
-
def RegressionHead(
|
343 |
-
d_model: int, output_dim: int, hidden_dim: Optional[int] = None
|
344 |
-
) -> nn.Module:
|
345 |
"""Create a regression head with optional hidden dimension.
|
346 |
|
347 |
Args:
|
@@ -707,6 +705,12 @@ class ESMplusplusModel(PreTrainedESMplusplusModel):
|
|
707 |
self.tokenizer = EsmSequenceTokenizer()
|
708 |
self.init_weights()
|
709 |
|
|
|
|
|
|
|
|
|
|
|
|
|
710 |
def forward(
|
711 |
self,
|
712 |
input_ids: Optional[torch.Tensor] = None,
|
@@ -752,6 +756,18 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
|
|
752 |
self.tokenizer = EsmSequenceTokenizer()
|
753 |
self.init_weights()
|
754 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
755 |
def forward(
|
756 |
self,
|
757 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
339 |
|
340 |
|
341 |
### Regression Head
|
342 |
+
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
|
|
|
|
|
343 |
"""Create a regression head with optional hidden dimension.
|
344 |
|
345 |
Args:
|
|
|
705 |
self.tokenizer = EsmSequenceTokenizer()
|
706 |
self.init_weights()
|
707 |
|
708 |
+
def get_input_embeddings(self):
|
709 |
+
return self.embed
|
710 |
+
|
711 |
+
def set_input_embeddings(self, value):
|
712 |
+
self.embed = value
|
713 |
+
|
714 |
def forward(
|
715 |
self,
|
716 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
756 |
self.tokenizer = EsmSequenceTokenizer()
|
757 |
self.init_weights()
|
758 |
|
759 |
+
def get_input_embeddings(self):
|
760 |
+
return self.embed
|
761 |
+
|
762 |
+
def set_input_embeddings(self, value):
|
763 |
+
self.embed = value
|
764 |
+
|
765 |
+
def get_output_embeddings(self):
|
766 |
+
return self.sequence_head[-1]
|
767 |
+
|
768 |
+
def set_output_embeddings(self, new_embeddings):
|
769 |
+
self.sequence_head[-1] = new_embeddings
|
770 |
+
|
771 |
def forward(
|
772 |
self,
|
773 |
input_ids: Optional[torch.Tensor] = None,
|