Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +1 -11
modeling_esm_plusplus.py
CHANGED
@@ -567,14 +567,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
567 |
attention_mask = attention_mask.unsqueeze(-1)
|
568 |
return (x * attention_mask).max(dim=1).values
|
569 |
|
570 |
-
def min_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
571 |
-
"""Apply min pooling to sequence outputs."""
|
572 |
-
if attention_mask is None:
|
573 |
-
return x.min(dim=1).values
|
574 |
-
else:
|
575 |
-
attention_mask = attention_mask.unsqueeze(-1)
|
576 |
-
return (x * attention_mask).min(dim=1).values
|
577 |
-
|
578 |
def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
579 |
"""Apply cls pooling to sequence outputs."""
|
580 |
return x[:, 0, :]
|
@@ -633,13 +625,11 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
633 |
|
634 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
635 |
if full_embeddings:
|
636 |
-
return residue_embeddings
|
637 |
elif pooling_type == 'mean':
|
638 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
639 |
elif pooling_type == 'max':
|
640 |
return self.max_pooling(residue_embeddings, attention_mask)
|
641 |
-
elif pooling_type == 'min':
|
642 |
-
return self.min_pooling(residue_embeddings, attention_mask)
|
643 |
elif pooling_type == 'cls':
|
644 |
return self.cls_pooling(residue_embeddings, attention_mask)
|
645 |
else:
|
|
|
567 |
attention_mask = attention_mask.unsqueeze(-1)
|
568 |
return (x * attention_mask).max(dim=1).values
|
569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
571 |
"""Apply cls pooling to sequence outputs."""
|
572 |
return x[:, 0, :]
|
|
|
625 |
|
626 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
627 |
if full_embeddings:
|
628 |
+
return residue_embeddings, attention_mask
|
629 |
elif pooling_type == 'mean':
|
630 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
631 |
elif pooling_type == 'max':
|
632 |
return self.max_pooling(residue_embeddings, attention_mask)
|
|
|
|
|
633 |
elif pooling_type == 'cls':
|
634 |
return self.cls_pooling(residue_embeddings, attention_mask)
|
635 |
else:
|