lhallee committed (verified)
Commit 5ae9ea4 · Parent(s): 7f0b439

Update modeling_esm_plusplus.py

Files changed (1): modeling_esm_plusplus.py (+2 −2)
modeling_esm_plusplus.py CHANGED

@@ -647,7 +647,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             if len(to_embed) > 0:
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
-                        seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                        seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                         input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                         x = self.embed(input_ids)
                         residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
@@ -665,7 +665,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             conn.commit()
             conn.close()
             return None
-
+
        embeddings_dict = {}
        with torch.no_grad():
            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
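
In this branch the dataloader is built over `to_embed`, the subset of `sequences` not already present in the SQL cache, so slicing `sequences` by batch index paired embeddings with the wrong keys whenever any sequence was skipped. A minimal sketch of the misalignment, with hypothetical stand-in data (the cache and dataloader setup below are assumed for illustration, not taken from the file):

```python
# Hypothetical example of the off-by-alignment bug the commit fixes.
sequences = ["AAAA", "CCCC", "GGGG", "TTTT"]
already_embedded = {"AAAA", "GGGG"}  # e.g. found in the SQL cache
to_embed = [s for s in sequences if s not in already_embedded]

batch_size = 1
# The batches iterate over to_embed, so batch i must be keyed by
# to_embed[i * batch_size:(i + 1) * batch_size].
for i in range(len(to_embed)):
    wrong = sequences[i * batch_size:(i + 1) * batch_size]  # pre-fix: misaligned keys
    right = to_embed[i * batch_size:(i + 1) * batch_size]   # post-fix: matches the batch
    print(i, wrong, right)
# i=0 -> wrong ['AAAA'] (already cached!), right ['CCCC']
# i=1 -> wrong ['CCCC'],                   right ['TTTT']
```

Slicing the same list the dataloader was built from keeps the sequence-to-embedding pairing stable no matter how many sequences the cache already covered.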