lhallee committed
Commit a878f42 · verified · 1 parent: e13d7d8

Update modeling_esm_plusplus.py

Files changed (1): modeling_esm_plusplus.py (+73 −17)
modeling_esm_plusplus.py CHANGED
@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from functools import cache, partial
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 from einops import rearrange, repeat
 from huggingface_hub import snapshot_download
 from tokenizers import Tokenizer
@@ -48,6 +48,7 @@ class ESMplusplusConfig(PretrainedConfig):
         num_hidden_layers: int = 30,
         num_labels: int = 2,
         problem_type: str | None = None,
+        dropout: float = 0.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -57,6 +58,7 @@ class ESMplusplusConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_labels = num_labels
         self.problem_type = problem_type
+        self.dropout = dropout


 ### Rotary Embeddings
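Aside from the attention changes below, the only config-level change is the new `dropout` field, which is stored on ESMplusplusConfig but, per the TODO added later in this diff, not yet consumed by any layer. A minimal sketch of constructing the updated config, assuming the file is importable as a local module and that constructor arguments not shown in this hunk keep their defaults:

# Minimal sketch (not part of the commit). Only keyword arguments visible in the
# hunk above are used; everything else is assumed to keep its default.
from modeling_esm_plusplus import ESMplusplusConfig

config = ESMplusplusConfig(
    num_hidden_layers=30,
    num_labels=2,
    problem_type=None,
    dropout=0.1,  # new field in this commit; defaults to 0.0 and is not yet wired into any layer
)
print(config.dropout)  # 0.1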
@@ -290,15 +292,17 @@ class MultiHeadAttention(nn.Module):
         k = k.flatten(-2, -1)
         return q, k

-    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights

         Returns:
-            Output tensor after self attention
+            Output tensor after self attention, and optionally attention weights
         """
+        attn_weights = None
         qkv_BLD3 = self.layernorm_qkv(x)
         query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
         query_BLD, key_BLD = (
@@ -307,11 +311,29 @@
         )
         query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
+
+        if output_attentions:  # Manual attention computation
+            L, S = query_BLD.size(-2), key_BLD.size(-2)
+            scale = 1 / math.sqrt(query_BLD.size(-1))
+            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
+            if attention_mask is not None:
+                if attention_mask.dtype == torch.bool:
+                    attention_mask.masked_fill_(attention_mask.logical_not(), float('-inf'))
+                else:
+                    attn_bias += attention_mask
+
+            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
+            attn_weights += attn_bias
+            attn_weights = F.softmax(attn_weights, dim=-1)
+            context_BHLD = torch.matmul(attn_weights, value_BHLD)
+        else:
+            context_BHLD = F.scaled_dot_product_attention(
+                query_BHLD, key_BHLD, value_BHLD, attention_mask
+            )
+
         context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
-        return self.out_proj(context_BLD)
+        output = self.out_proj(context_BLD)
+        return output, attn_weights


 ### Regression Head
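The new `output_attentions` branch above trades `F.scaled_dot_product_attention` for an explicit softmax(QK^T · scale + bias) V computation so the per-head weights can be returned alongside the projected output. Note that the branch derives its scale from `query_BLD.size(-1)` (the full hidden width before the head reshape), whereas `F.scaled_dot_product_attention` scales by the per-head dimension by default. The standalone sketch below, which is not the model code, checks the explicit recipe against SDPA using the per-head scale and an additive float mask:

# Standalone sketch: explicit attention vs. F.scaled_dot_product_attention.
# Tensors follow the (batch, heads, length, head_dim) layout used in the file;
# the scale here is 1/sqrt(head_dim), SDPA's default.
import math
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

bias = torch.zeros(L, L)
bias[:, -2:] = float("-inf")  # additive float mask: hide the last two key positions

scale = 1 / math.sqrt(D)
weights = F.softmax(q @ k.transpose(-2, -1) * scale + bias, dim=-1)  # (B, H, L, L)
manual = weights @ v

reference = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(torch.allclose(manual, reference, atol=1e-5))  # expected: True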
@@ -360,19 +382,23 @@ class UnifiedTransformerBlock(nn.Module):
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights

         Returns:
-            Output tensor after transformer block
+            Output tensor after transformer block, and optionally attention weights
         """
-        r1 = self.attn(x, attention_mask)
-        x = x + r1 / self.scaling_factor
+        attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
+        x = x + attn_output / self.scaling_factor
         r3 = self.ffn(x) / self.scaling_factor
         x = x + r3
+        if output_attentions:
+            return x, attn_weights
         return x


@@ -382,6 +408,7 @@ class TransformerOutput(ModelOutput):
     """Output type for transformer encoder."""
     last_hidden_state: Optional[torch.Tensor] = None
     hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None


 @dataclass
@@ -391,6 +418,7 @@ class ESMplusplusOutput(ModelOutput):
     logits: Optional[torch.Tensor] = None
     last_hidden_state: Optional[torch.Tensor] = None
     hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None


 ### Transformer Stack
@@ -426,25 +454,42 @@ class TransformerStack(nn.Module):
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> TransformerOutput:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights

         Returns:
-            TransformerOutput containing last hidden state and optionally all hidden states
+            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
         """
         batch_size, seq_len, _ = x.shape
-        hidden_states = ()
+        hidden_states = () if output_hidden_states else None
+        attentions = () if output_attentions else None
+
         if attention_mask is not None:
             attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
+
         for block in self.blocks:
-            x = block(x, attention_mask)
+            if output_attentions:
+                x, attn_weights = block(x, attention_mask, output_attentions)
+                if attentions is not None:
+                    attentions += (attn_weights,)
+            else:
+                x = block(x, attention_mask, output_attentions)
+
             if output_hidden_states:
+                assert hidden_states is not None
                 hidden_states += (x,)
-        return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
+
+        return TransformerOutput(
+            last_hidden_state=self.norm(x),
+            hidden_states=hidden_states,
+            attentions=attentions
+        )


 ### Dataset for Embedding
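With `output_attentions=True`, `TransformerOutput.attentions` is a tuple holding one tensor per block; given the matmul in the manual branch, each entry has shape (batch, heads, seq_len, seq_len). A small, hypothetical helper for collapsing that tuple into a single tensor for downstream analysis (the `output` argument is assumed to be a TransformerOutput or ESMplusplusOutput produced with `output_attentions=True`):

# Hypothetical helper, not part of the commit: stack per-layer attention maps.
import torch

def stack_attentions(output) -> torch.Tensor:
    """Return attention maps as one (layers, batch, heads, seq_len, seq_len) tensor."""
    if output.attentions is None:
        raise ValueError("Run the forward pass with output_attentions=True.")
    return torch.stack(output.attentions, dim=0)

# e.g. one (seq_len, seq_len) map per sequence, averaged over layers and heads:
# mean_maps = stack_attentions(output).mean(dim=(0, 2))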
@@ -604,12 +649,19 @@ class ESMplusplusForMaskedLM(PreTrainedModel):

         return embeddings_dict

+    """
+    TODO
+    - Add dropout (default 0.0)
+    - Class method for returning manually computed attention maps
+    """
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> ESMplusplusOutput:
         """Forward pass for masked language modeling.

@@ -618,12 +670,13 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             attention_mask: Attention mask
             labels: Optional labels for masked tokens
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights

         Returns:
-            ESMplusplusOutput containing loss, logits, and hidden states
+            ESMplusplusOutput containing loss, logits, hidden states and attention weights
         """
         x = self.embed(input_ids)
-        output = self.transformer(x, attention_mask, output_hidden_states)
+        output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
         x = output.last_hidden_state
         logits = self.sequence_head(x)
         loss = None
@@ -634,6 +687,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             logits=logits,
             last_hidden_state=x,
             hidden_states=output.hidden_states,
+            attentions=output.attentions,
         )


@@ -658,6 +712,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> ESMplusplusOutput:
         """Forward pass for sequence classification.

@@ -666,6 +721,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
             attention_mask: Attention mask
             labels: Optional labels for classification
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights

         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
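Taken together, the commit threads `output_attentions` from `ESMplusplusForMaskedLM.forward` (and the sequence-classification subclass) through `TransformerStack` down to each `MultiHeadAttention`, and surfaces the result as `attentions` on the output dataclasses. A usage sketch follows; the checkpoint id and the `model.tokenizer` attribute are assumptions taken from the public ESM++ model cards rather than from this diff, so substitute whatever checkpoint and tokenizer you actually use:

# Usage sketch under the assumptions stated above.
import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_small",  # assumed checkpoint id
    trust_remote_code=True,
)
tokenizer = model.tokenizer  # assumed attribute; swap in your own tokenizer if absent

inputs = tokenizer("MSEQVENCE", return_tensors="pt")
with torch.no_grad():
    out = model(
        **inputs,
        output_hidden_states=True,
        output_attentions=True,  # new flag added in this commit
    )

# One attention map per transformer block, each (batch, heads, seq_len, seq_len).
print(len(out.attentions), out.attentions[0].shape)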