lhallee committed
Commit 7bdd97f · verified · 1 Parent(s): a24444b

Update modeling_esm_plusplus.py

Files changed (1):
  1. modeling_esm_plusplus.py  +6 -6
modeling_esm_plusplus.py CHANGED
@@ -249,7 +249,7 @@ class SwiGLU(nn.Module):
         return F.silu(x1) * x2
 
 
-def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
+def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
     """Create SwiGLU feedforward network with layer normalization."""
     return nn.Sequential(
         nn.LayerNorm(d_model),
@@ -257,7 +257,6 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) ->
             d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
         ),
         SwiGLU(),
-        nn.Dropout(dropout),
         nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
     )
 
@@ -377,8 +376,9 @@ class UnifiedTransformerBlock(nn.Module):
     ):
         super().__init__()
         self.attn = MultiHeadAttention(d_model, n_heads)
-        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, dropout)
+        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
         self.scaling_factor = residue_scaling_factor
+        self.dropout = nn.Dropout(dropout)
 
     def forward(
         self,
@@ -396,9 +396,8 @@ class UnifiedTransformerBlock(nn.Module):
            Output tensor after transformer block, and optionally attention weights
        """
        attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
-       x = x + attn_output / self.scaling_factor
-       r3 = self.ffn(x) / self.scaling_factor
-       x = x + r3
+       x = x + self.dropout(attn_output) / self.scaling_factor
+       x = x + self.dropout(self.ffn(x)) / self.scaling_factor
        if output_attentions:
            return x, attn_weights
        return x
@@ -431,6 +430,7 @@ class TransformerStack(nn.Module):
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
+       dropout: Dropout rate
    """
    def __init__(
        self,
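
For context, the net effect of this change is that dropout no longer sits inside swiglu_ln_ffn; a single nn.Dropout module on the block is instead applied to both the attention output and the FFN output before each scaled residual addition. The sketch below illustrates only that dropout placement with stand-in sublayers (a plain nn.Linear and a LayerNorm + Linear stack); it is not the repository's MultiHeadAttention or SwiGLU FFN, and the class name, shapes, and defaults are assumptions for illustration.

import torch
import torch.nn as nn

class ResidualDropoutSketch(nn.Module):
    """Illustrative only: mirrors the new dropout placement, not the real ESM++ block."""
    def __init__(self, d_model: int, dropout: float = 0.0, residue_scaling_factor: float = 1.0):
        super().__init__()
        self.attn = nn.Linear(d_model, d_model)      # stand-in for MultiHeadAttention
        self.ffn = nn.Sequential(                    # stand-in for swiglu_ln_ffn (no internal Dropout anymore)
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model, bias=False),
        )
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)           # one module, shared by both residual branches

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # previously: x = x + attn_output / self.scaling_factor
        x = x + self.dropout(self.attn(x)) / self.scaling_factor
        # previously: dropout lived between SwiGLU() and the output Linear inside the FFN
        x = x + self.dropout(self.ffn(x)) / self.scaling_factor
        return x

block = ResidualDropoutSketch(d_model=64, dropout=0.1)
out = block(torch.randn(2, 16, 64))   # -> torch.Size([2, 16, 64])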