lhallee committed
Commit a24444b · verified · 1 Parent(s): 778b394

Update modeling_esm_plusplus.py

Files changed (1): modeling_esm_plusplus.py +31 -15
modeling_esm_plusplus.py CHANGED

@@ -249,7 +249,7 @@ class SwiGLU(nn.Module):
         return F.silu(x1) * x2
 
 
-def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
+def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
     """Create SwiGLU feedforward network with layer normalization."""
     return nn.Sequential(
         nn.LayerNorm(d_model),
@@ -257,6 +257,7 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
             d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
         ),
         SwiGLU(),
+        nn.Dropout(dropout),
         nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
     )
 
@@ -372,10 +373,11 @@ class UnifiedTransformerBlock(nn.Module):
         n_heads: int,
         residue_scaling_factor: float = 1,
         expansion_ratio: float = 8 / 3,
+        dropout: float = 0.0,
     ):
         super().__init__()
         self.attn = MultiHeadAttention(d_model, n_heads)
-        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
+        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, dropout)
         self.scaling_factor = residue_scaling_factor
 
     def forward(
@@ -435,6 +437,7 @@ class TransformerStack(nn.Module):
         d_model: int,
         n_heads: int,
         n_layers: int,
+        dropout: float = 0.0,
     ):
         super().__init__()
         self.blocks = nn.ModuleList(
@@ -443,6 +446,7 @@ class TransformerStack(nn.Module):
                     d_model,
                     n_heads,
                     residue_scaling_factor=math.sqrt(n_layers / 36),
+                    dropout=dropout,
                 )
                 for i in range(n_layers)
             ]
@@ -517,7 +521,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
-        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
+        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
         self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
         self.ce_loss = nn.CrossEntropyLoss()
         self.tokenizer = EsmSequenceTokenizer()
@@ -649,25 +653,22 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
 
         return embeddings_dict
 
-    """
-    TODO
-    - Add dropout (default 0.0)
-    - Class method for returning manually computed attention maps
-    """
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-        output_hidden_states: bool = False,
-        output_attentions: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for masked language modeling.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for masked tokens
             output_hidden_states: Whether to return all hidden states
             output_attentions: Whether to return attention weights
@@ -675,7 +676,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         Returns:
             ESMplusplusOutput containing loss, logits, hidden states and attention weights
         """
-        x = self.embed(input_ids)
+        if inputs_embeds is None:
+            x = self.embed(input_ids)
+        else:
+            x = inputs_embeds
         output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
         x = output.last_hidden_state
         logits = self.sequence_head(x)
@@ -710,15 +714,18 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
-        output_hidden_states: bool = False,
-        output_attentions: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for sequence classification.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for classification
             output_hidden_states: Whether to return all hidden states
             output_attentions: Whether to return attention weights
@@ -729,7 +736,9 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         output = super().forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             labels=None,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
@@ -783,16 +792,21 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
-        output_hidden_states: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for token classification.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for token classification
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
 
         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
@@ -800,7 +814,9 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
            labels=None,
+            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
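
For context, a minimal usage sketch of the interface after this commit (not part of the diff). The checkpoint id below is illustrative, and the dropout override assumes `dropout` is a field of the model's config class, consistent with `config.dropout` being read in `__init__` above; `inputs_embeds`, `output_hidden_states`, and `model.embed` come directly from the changes shown.

# Usage sketch after this commit. Assumptions: the checkpoint id is illustrative,
# and loading with trust_remote_code=True picks up this modeling file.
import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_small",  # illustrative checkpoint id
    trust_remote_code=True,
    dropout=0.1,  # config override; reaches every block's FFN via TransformerStack
)

tokenizer = model.tokenizer  # EsmSequenceTokenizer attached in __init__ (see diff context)
batch = tokenizer(["MPRTEIN", "MSEQWENCE"], padding=True, return_tensors="pt")

# Existing path: token ids are embedded inside forward().
out_ids = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# New path: pass precomputed embeddings via inputs_embeds instead of input_ids.
out_emb = model(
    inputs_embeds=model.embed(batch["input_ids"]),
    attention_mask=batch["attention_mask"],
    output_hidden_states=True,
)

# Dropout modules are inactive in eval mode, so both paths should yield matching logits.
assert torch.allclose(out_ids.logits, out_emb.logits)

Per the inline comment in the diff, `return_dict` is accepted only to play nice with HF-adjacent packages and is not otherwise consumed in the hunks shown here.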