lhallee commited on
Commit
4b18c15
·
verified ·
1 Parent(s): 3099b74

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +1 -11
modeling_esm_plusplus.py CHANGED
@@ -567,14 +567,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
567
  attention_mask = attention_mask.unsqueeze(-1)
568
  return (x * attention_mask).max(dim=1).values
569
 
570
- def min_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
571
- """Apply min pooling to sequence outputs."""
572
- if attention_mask is None:
573
- return x.min(dim=1).values
574
- else:
575
- attention_mask = attention_mask.unsqueeze(-1)
576
- return (x * attention_mask).min(dim=1).values
577
-
578
  def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
579
  """Apply cls pooling to sequence outputs."""
580
  return x[:, 0, :]
@@ -633,13 +625,11 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
633
 
634
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
635
  if full_embeddings:
636
- return residue_embeddings
637
  elif pooling_type == 'mean':
638
  return self.mean_pooling(residue_embeddings, attention_mask)
639
  elif pooling_type == 'max':
640
  return self.max_pooling(residue_embeddings, attention_mask)
641
- elif pooling_type == 'min':
642
- return self.min_pooling(residue_embeddings, attention_mask)
643
  elif pooling_type == 'cls':
644
  return self.cls_pooling(residue_embeddings, attention_mask)
645
  else:
 
567
  attention_mask = attention_mask.unsqueeze(-1)
568
  return (x * attention_mask).max(dim=1).values
569
 
 
 
 
 
 
 
 
 
570
  def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
571
  """Apply cls pooling to sequence outputs."""
572
  return x[:, 0, :]
 
625
 
626
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
627
  if full_embeddings:
628
+ return residue_embeddings, attention_mask
629
  elif pooling_type == 'mean':
630
  return self.mean_pooling(residue_embeddings, attention_mask)
631
  elif pooling_type == 'max':
632
  return self.max_pooling(residue_embeddings, attention_mask)
 
 
633
  elif pooling_type == 'cls':
634
  return self.cls_pooling(residue_embeddings, attention_mask)
635
  else: