Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +2 -2
modeling_esm_plusplus.py
CHANGED
@@ -647,7 +647,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
647 |
if len(to_embed) > 0:
|
648 |
with torch.no_grad():
|
649 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
650 |
-
seqs =
|
651 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
652 |
x = self.embed(input_ids)
|
653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
@@ -665,7 +665,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
665 |
conn.commit()
|
666 |
conn.close()
|
667 |
return None
|
668 |
-
|
669 |
embeddings_dict = {}
|
670 |
with torch.no_grad():
|
671 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
|
|
647 |
if len(to_embed) > 0:
|
648 |
with torch.no_grad():
|
649 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
650 |
+
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
651 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
652 |
x = self.embed(input_ids)
|
653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
|
|
665 |
conn.commit()
|
666 |
conn.close()
|
667 |
return None
|
668 |
+
|
669 |
embeddings_dict = {}
|
670 |
with torch.no_grad():
|
671 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|