Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +3 -3
modeling_esm_plusplus.py
CHANGED
@@ -575,7 +575,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
575 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
576 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
577 |
x = self.embed(input_ids)
|
578 |
-
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
|
579 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
580 |
|
581 |
for seq, emb in zip(seqs, embeddings):
|
@@ -595,10 +595,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
595 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
596 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
597 |
x = self.embed(input_ids)
|
598 |
-
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
|
599 |
if full_precision:
|
600 |
residue_embeddings = residue_embeddings.float()
|
601 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
602 |
for seq, emb in zip(seqs, embeddings):
|
603 |
embeddings_dict[seq] = emb
|
604 |
|
|
|
575 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
576 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
577 |
x = self.embed(input_ids)
|
578 |
+
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
579 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
580 |
|
581 |
for seq, emb in zip(seqs, embeddings):
|
|
|
595 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
596 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
597 |
x = self.embed(input_ids)
|
598 |
+
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
|
599 |
if full_precision:
|
600 |
residue_embeddings = residue_embeddings.float()
|
601 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
|
602 |
for seq, emb in zip(seqs, embeddings):
|
603 |
embeddings_dict[seq] = emb
|
604 |
|