Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +4 -2
modeling_esm_plusplus.py
CHANGED
@@ -625,7 +625,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
625 |
|
626 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
627 |
if full_embeddings:
|
628 |
-
return residue_embeddings
|
629 |
elif pooling_type == 'mean':
|
630 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
631 |
elif pooling_type == 'max':
|
@@ -653,7 +653,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
654 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
655 |
|
656 |
-
for seq, emb in zip(seqs, embeddings):
|
|
|
|
|
657 |
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
658 |
(seq, emb.cpu().numpy().tobytes()))
|
659 |
|
|
|
625 |
|
626 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
627 |
if full_embeddings:
|
628 |
+
return residue_embeddings
|
629 |
elif pooling_type == 'mean':
|
630 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
631 |
elif pooling_type == 'max':
|
|
|
653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
654 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
655 |
|
656 |
+
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
657 |
+
if full_embeddings:
|
658 |
+
emb = emb[mask.bool()]
|
659 |
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
660 |
(seq, emb.cpu().numpy().tobytes()))
|
661 |
|