lhallee committed on
Commit
7f0b439
·
verified ·
1 Parent(s): 4b18c15

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +4 -2
modeling_esm_plusplus.py CHANGED
@@ -625,7 +625,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
625
 
626
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
627
  if full_embeddings:
628
- return residue_embeddings, attention_mask
629
  elif pooling_type == 'mean':
630
  return self.mean_pooling(residue_embeddings, attention_mask)
631
  elif pooling_type == 'max':
@@ -653,7 +653,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
653
  residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
654
  embeddings = get_embeddings(residue_embeddings, attention_mask)
655
 
656
- for seq, emb in zip(seqs, embeddings):
 
 
657
  c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
658
  (seq, emb.cpu().numpy().tobytes()))
659
 
 
625
 
626
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
627
  if full_embeddings:
628
+ return residue_embeddings
629
  elif pooling_type == 'mean':
630
  return self.mean_pooling(residue_embeddings, attention_mask)
631
  elif pooling_type == 'max':
 
653
  residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
654
  embeddings = get_embeddings(residue_embeddings, attention_mask)
655
 
656
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
657
+ if full_embeddings:
658
+ emb = emb[mask.bool()]
659
  c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
660
  (seq, emb.cpu().numpy().tobytes()))
661