Update modeling_esm_plusplus.py
modeling_esm_plusplus.py  +31 -15
modeling_esm_plusplus.py
CHANGED
@@ -249,7 +249,7 @@ class SwiGLU(nn.Module):
         return F.silu(x1) * x2
 
 
-def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
+def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
     """Create SwiGLU feedforward network with layer normalization."""
     return nn.Sequential(
         nn.LayerNorm(d_model),
@@ -257,6 +257,7 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
             d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
         ),
         SwiGLU(),
+        nn.Dropout(dropout),
         nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
     )
 
@@ -372,10 +373,11 @@ class UnifiedTransformerBlock(nn.Module):
         n_heads: int,
         residue_scaling_factor: float = 1,
         expansion_ratio: float = 8 / 3,
+        dropout: float = 0.0,
     ):
         super().__init__()
         self.attn = MultiHeadAttention(d_model, n_heads)
-        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
+        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, dropout)
         self.scaling_factor = residue_scaling_factor
 
     def forward(
@@ -435,6 +437,7 @@ class TransformerStack(nn.Module):
         d_model: int,
         n_heads: int,
         n_layers: int,
+        dropout: float = 0.0,
     ):
         super().__init__()
         self.blocks = nn.ModuleList(
@@ -443,6 +446,7 @@ class TransformerStack(nn.Module):
                     d_model,
                     n_heads,
                     residue_scaling_factor=math.sqrt(n_layers / 36),
+                    dropout=dropout,
                 )
                 for i in range(n_layers)
             ]
@@ -517,7 +521,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
-        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
+        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
         self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
         self.ce_loss = nn.CrossEntropyLoss()
         self.tokenizer = EsmSequenceTokenizer()
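The net effect of the dropout plumbing above, as a minimal standalone sketch. The chunk-based SwiGLU module and the _round_to_multiple helper below are simplified stand-ins for the file's SwiGLU class and swiglu_correction_fn, which this diff does not show; only the position of the new nn.Dropout is taken from the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    # Assumed chunking: split the fused projection in half and gate with SiLU,
    # matching the `F.silu(x1) * x2` line visible in the diff.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def _round_to_multiple(expansion_ratio: float, d_model: int, multiple: int = 8) -> int:
    # Hypothetical stand-in for swiglu_correction_fn: round the expanded width
    # up to a convenient multiple.
    return int(-(-expansion_ratio * d_model // multiple) * multiple)


def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
    hidden = _round_to_multiple(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden * 2, bias=False),  # fused gate + value projection
        SwiGLU(),                                    # -> hidden
        nn.Dropout(dropout),                         # the layer added by this commit
        nn.Linear(hidden, d_model, bias=False),
    )


ffn = swiglu_ln_ffn(d_model=960, expansion_ratio=8 / 3, dropout=0.1)
print(ffn(torch.randn(2, 12, 960)).shape)  # torch.Size([2, 12, 960])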
@@ -649,25 +653,22 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
 
         return embeddings_dict
 
-    """
-    TODO
-    - Add dropout (default 0.0)
-    - Class method for returning manually computed attention maps
-    """
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for masked language modeling.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for masked tokens
             output_hidden_states: Whether to return all hidden states
             output_attentions: Whether to return attention weights
@@ -675,7 +676,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         Returns:
             ESMplusplusOutput containing loss, logits, hidden states and attention weights
         """
-        x = self.embed(input_ids)
+        if inputs_embeds is None:
+            x = self.embed(input_ids)
+        else:
+            x = inputs_embeds
         output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
         x = output.last_hidden_state
         logits = self.sequence_head(x)
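A short usage sketch of the inputs_embeds path added above: when precomputed embeddings are supplied, forward() skips the self.embed lookup. The checkpoint name is an illustrative assumption; model.tokenizer and model.embed are the attributes visible in this file.

import torch
from transformers import AutoModelForMaskedLM

# Hypothetical checkpoint; any repo that ships this modeling file would behave the same.
model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
model.eval()  # dropout (if configured) is inactive in eval mode

batch = model.tokenizer(["MPRTEIN", "MSEQWENCE"], padding=True, return_tensors="pt")

# Path 1: token ids are embedded inside forward().
out_ids = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# Path 2: hand forward() precomputed embeddings; self.embed() is skipped.
embeds = model.embed(batch["input_ids"])
out_emb = model(inputs_embeds=embeds, attention_mask=batch["attention_mask"])

# Both paths should produce identical logits in eval mode.
torch.testing.assert_close(out_ids.logits, out_emb.logits)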
@@ -710,15 +714,18 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for sequence classification.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for classification
             output_hidden_states: Whether to return all hidden states
             output_attentions: Whether to return attention weights
@@ -729,7 +736,9 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         output = super().forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             labels=None,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
@@ -783,16 +792,21 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for token classification.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for token classification
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
 
         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
@@ -800,7 +814,9 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         output = super().forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             labels=None,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
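One knock-on requirement of the constructor change: TransformerStack is now built from config.dropout, so the paired configuration class must expose that field. A rough sketch, assuming the config class is named ESMplusplusConfig in this same module and accepts a dropout keyword (the config-side change is not part of this diff); the size values are placeholders.

# Assumed import path and config class name; only the dropout plumbing is shown in this diff.
from modeling_esm_plusplus import ESMplusplusConfig, ESMplusplusForMaskedLM

config = ESMplusplusConfig(
    vocab_size=64,
    hidden_size=960,
    num_attention_heads=15,
    num_hidden_layers=30,
    dropout=0.1,  # consumed as config.dropout; 0.0 reproduces the previous behaviour
)
model = ESMplusplusForMaskedLM(config)  # dropout now reaches every SwiGLU FFN block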