Update modeling_esm_plusplus.py

modeling_esm_plusplus.py (+73 -17)
@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from functools import cache, partial
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 from einops import rearrange, repeat
 from huggingface_hub import snapshot_download
 from tokenizers import Tokenizer
@@ -48,6 +48,7 @@ class ESMplusplusConfig(PretrainedConfig):
         num_hidden_layers: int = 30,
         num_labels: int = 2,
         problem_type: str | None = None,
+        dropout: float = 0.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -57,6 +58,7 @@ class ESMplusplusConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_labels = num_labels
         self.problem_type = problem_type
+        self.dropout = dropout
 
 
 ### Rotary Embeddings
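For context, a minimal sketch of how the new config field surfaces to callers; the import path and values below are illustrative, and per the TODO added later in this diff the field is stored on the config but not yet consumed by any layer.

# Illustrative sketch only: the new `dropout` hyperparameter rides through the config
# like any other keyword argument and is stored on the instance.
from modeling_esm_plusplus import ESMplusplusConfig  # assumes the module is importable

config = ESMplusplusConfig(num_hidden_layers=30, num_labels=2, dropout=0.1)
print(config.dropout)  # 0.1 (defaults to 0.0 when omitted)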
@@ -290,15 +292,17 @@ class MultiHeadAttention(nn.Module):
         k = k.flatten(-2, -1)
         return q, k
 
-    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
 
         Returns:
-            Output tensor after self attention
+            Output tensor after self attention, and optionally attention weights
         """
+        attn_weights = None
         qkv_BLD3 = self.layernorm_qkv(x)
         query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
         query_BLD, key_BLD = (
@@ -307,11 +311,29 @@ class MultiHeadAttention(nn.Module):
         )
         query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
-        context_BHLD = F.scaled_dot_product_attention(
-            query_BHLD, key_BHLD, value_BHLD, attention_mask
-        )
+
+        if output_attentions: # Manual attention computation
+            L, S = query_BLD.size(-2), key_BLD.size(-2)
+            scale = 1 / math.sqrt(query_BLD.size(-1))
+            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
+            if attention_mask is not None:
+                if attention_mask.dtype == torch.bool:
+                    attention_mask.masked_fill_(attention_mask.logical_not(), float('-inf'))
+                else:
+                    attn_bias += attention_mask
+
+            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
+            attn_weights += attn_bias
+            attn_weights = F.softmax(attn_weights, dim=-1)
+            context_BHLD = torch.matmul(attn_weights, value_BHLD)
+        else:
+            context_BHLD = F.scaled_dot_product_attention(
+                query_BHLD, key_BHLD, value_BHLD, attention_mask
+            )
+
         context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
-        return self.out_proj(context_BLD)
+        output = self.out_proj(context_BLD)
+        return output, attn_weights
 
 
 ### Regression Head
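As a standalone sanity check (not part of the diff), the manual branch above mirrors what F.scaled_dot_product_attention computes when the weights are needed explicitly; note that the fused call scales by the per-head dimension by default, so the sketch below derives the scale from the reshaped (batch, heads, seq, head_dim) query.

# Standalone sketch: manual softmax(QK^T * scale) @ V versus the fused kernel.
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq, head_dim) -- illustrative shapes
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

scale = 1 / math.sqrt(q.size(-1))  # per-head scaling, matching SDPA's default
attn_weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
manual = torch.matmul(attn_weights, v)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))  # True up to numerical tolerance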
@@ -360,19 +382,23 @@ class UnifiedTransformerBlock(nn.Module):
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
 
         Returns:
-            Output tensor after transformer block
+            Output tensor after transformer block, and optionally attention weights
         """
-        r1 = self.attn(x, attention_mask)
-        x = x + r1 / self.scaling_factor
+        attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
+        x = x + attn_output / self.scaling_factor
         r3 = self.ffn(x) / self.scaling_factor
         x = x + r3
+        if output_attentions:
+            return x, attn_weights
         return x
 
 
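The block now follows the common conditional-return convention: a plain tensor by default, a (tensor, attention weights) tuple when output_attentions is set. A generic, self-contained sketch of that pattern (no ESM++ classes involved):

# Generic illustration of the Union[Tensor, Tuple[Tensor, Tensor]] return convention.
from typing import Tuple, Union
import torch
import torch.nn.functional as F

def toy_layer(x: torch.Tensor, output_attentions: bool = False
              ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    attn_weights = F.softmax(x @ x.transpose(-2, -1), dim=-1)  # stand-in attention
    y = attn_weights @ x
    if output_attentions:
        return y, attn_weights
    return y

x = torch.randn(2, 4, 8)
y = toy_layer(x)                             # tensor only
y, w = toy_layer(x, output_attentions=True)  # tensor plus weights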
@@ -382,6 +408,7 @@ class TransformerOutput(ModelOutput):
     """Output type for transformer encoder."""
     last_hidden_state: Optional[torch.Tensor] = None
     hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None
 
 
 @dataclass
@@ -391,6 +418,7 @@ class ESMplusplusOutput(ModelOutput):
     logits: Optional[torch.Tensor] = None
     last_hidden_state: Optional[torch.Tensor] = None
     hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None
 
 
 ### Transformer Stack
@@ -426,25 +454,42 @@ class TransformerStack(nn.Module):
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> TransformerOutput:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
 
         Returns:
-            TransformerOutput containing last hidden state and optionally all hidden states
+            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
         """
         batch_size, seq_len, _ = x.shape
-        hidden_states = ()
+        hidden_states = () if output_hidden_states else None
+        attentions = () if output_attentions else None
+
         if attention_mask is not None:
             attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
+
         for block in self.blocks:
-            x = block(x, attention_mask)
+            if output_attentions:
+                x, attn_weights = block(x, attention_mask, output_attentions)
+                if attentions is not None:
+                    attentions += (attn_weights,)
+            else:
+                x = block(x, attention_mask, output_attentions)
+
             if output_hidden_states:
+                assert hidden_states is not None
                 hidden_states += (x,)
-        return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
+
+        return TransformerOutput(
+            last_hidden_state=self.norm(x),
+            hidden_states=hidden_states,
+            attentions=attentions
+        )
 
 
 ### Dataset for Embedding
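Downstream of this change, the per-layer attention maps arrive as a tuple of (batch, heads, seq, seq) tensors, one per block. A small illustrative snippet (random stand-in data, not ESM++ outputs) showing how such a tuple is typically stacked for inspection:

# Illustrative post-processing of an `attentions` tuple like the one returned above.
import torch

# Stand-in for output.attentions: one (batch, heads, seq, seq) map per transformer block.
attentions = tuple(torch.softmax(torch.randn(2, 4, 10, 10), dim=-1) for _ in range(6))

stacked = torch.stack(attentions, dim=1)  # (batch, layers, heads, seq, seq)
avg_over_heads = stacked.mean(dim=2)      # (batch, layers, seq, seq), e.g. for contact-style maps
print(stacked.shape, avg_over_heads.shape)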
@@ -604,12 +649,19 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
 
         return embeddings_dict
 
+    """
+    TODO
+    - Add dropout (default 0.0)
+    - Class method for returning manually computed attention maps
+    """
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> ESMplusplusOutput:
         """Forward pass for masked language modeling.
 
@@ -618,12 +670,13 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             attention_mask: Attention mask
             labels: Optional labels for masked tokens
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
 
         Returns:
-            ESMplusplusOutput containing loss, logits, and hidden states
+            ESMplusplusOutput containing loss, logits, hidden states and attention weights
         """
         x = self.embed(input_ids)
-        output = self.transformer(x, attention_mask, output_hidden_states)
+        output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
         x = output.last_hidden_state
         logits = self.sequence_head(x)
         loss = None
@@ -634,6 +687,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             logits=logits,
             last_hidden_state=x,
             hidden_states=output.hidden_states,
+            attentions=output.attentions,
         )
 
 
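Putting the threaded flag together, a hedged usage sketch: it assumes `model` is an already-loaded ESMplusplusForMaskedLM and that `input_ids`/`attention_mask` are prepared (batch, seq_len) tensors, since checkpoint names and tokenization are outside this diff.

# Usage sketch under the assumptions above -- `model`, `input_ids`, `attention_mask`
# are presumed to exist; only the new keyword arguments come from this diff.
import torch

with torch.no_grad():
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        output_attentions=True,
    )

print(out.logits.shape)         # (batch, seq_len, vocab_size)
print(len(out.hidden_states))   # one entry per transformer block
print(out.attentions[0].shape)  # (batch, heads, seq_len, seq_len)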
@@ -658,6 +712,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> ESMplusplusOutput:
         """Forward pass for sequence classification.
 
@@ -666,6 +721,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
             attention_mask: Attention mask
             labels: Optional labels for classification
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
 
         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states