Update modeling_esm_plusplus.py
modeling_esm_plusplus.py (+6 -6)
@@ -249,7 +249,7 @@ class SwiGLU(nn.Module):
         return F.silu(x1) * x2
 
 
-def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
+def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
     """Create SwiGLU feedforward network with layer normalization."""
     return nn.Sequential(
         nn.LayerNorm(d_model),
@@ -257,7 +257,6 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) ->
             d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
         ),
         SwiGLU(),
-        nn.Dropout(dropout),
         nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
     )
 
@@ -377,8 +376,9 @@ class UnifiedTransformerBlock(nn.Module):
     ):
         super().__init__()
         self.attn = MultiHeadAttention(d_model, n_heads)
-        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, dropout)
+        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
         self.scaling_factor = residue_scaling_factor
+        self.dropout = nn.Dropout(dropout)
 
     def forward(
         self,
@@ -396,9 +396,8 @@ class UnifiedTransformerBlock(nn.Module):
             Output tensor after transformer block, and optionally attention weights
         """
         attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
-        x = x + attn_output / self.scaling_factor
-
-        x = x + r3
+        x = x + self.dropout(attn_output) / self.scaling_factor
+        x = x + self.dropout(self.ffn(x)) / self.scaling_factor
         if output_attentions:
             return x, attn_weights
         return x
@@ -431,6 +430,7 @@ class TransformerStack(nn.Module):
         d_model: Model dimension
         n_heads: Number of attention heads
        n_layers: Number of transformer layers
+        dropout: Dropout rate
     """
     def __init__(
         self,
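For context, here is a minimal, self-contained sketch of the pattern the updated file lands on: dropout is applied to each residual branch of the transformer block rather than inside the SwiGLU feed-forward network. This is not the repository's code; BlockSketch, the attention stand-in (torch's nn.MultiheadAttention plus a pre-norm), the 8/3 expansion ratio default, and the round-to-a-multiple-of-256 behaviour of swiglu_correction_fn are all assumptions made for illustration.

# Minimal sketch of dropout on the residual branches (assumptions noted above).
import torch
import torch.nn as nn
import torch.nn.functional as F


def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    # Assumed behaviour: round the expanded hidden size up to a multiple of 256.
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the projection in half and gate one half with SiLU of the other.
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """SwiGLU feed-forward with a leading LayerNorm and no internal dropout."""
    hidden = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden * 2, bias=False),
        SwiGLU(),
        nn.Linear(hidden, d_model, bias=False),
    )


class BlockSketch(nn.Module):
    """Pre-norm block; dropout is applied to each residual branch."""
    def __init__(self, d_model: int, n_heads: int, expansion_ratio: float = 8 / 3,
                 dropout: float = 0.0, residue_scaling_factor: float = 1.0):
        super().__init__()
        # Stand-in for the repository's MultiHeadAttention (assumption).
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.attn_norm = nn.LayerNorm(d_model)
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.attn_norm(x)
        attn_output, _ = self.attn(h, h, h, need_weights=False)
        # Dropout sits on the residual branches, mirroring the pattern in the diff.
        x = x + self.dropout(attn_output) / self.scaling_factor
        x = x + self.dropout(self.ffn(x)) / self.scaling_factor
        return x


if __name__ == "__main__":
    block = BlockSketch(d_model=64, n_heads=4, dropout=0.1)
    out = block(torch.randn(2, 10, 64))
    print(out.shape)  # torch.Size([2, 10, 64])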