lhallee committed · Commit 1725572 · verified · 1 Parent(s): 005c0c6

Update modeling_esm_plusplus.py
Files changed (1):
  1. modeling_esm_plusplus.py +3 -3
modeling_esm_plusplus.py CHANGED

@@ -575,7 +575,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
     seqs = sequences[i * batch_size:(i + 1) * batch_size]
     input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
     x = self.embed(input_ids)
-    residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
+    residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
     embeddings = get_embeddings(residue_embeddings, attention_mask)

     for seq, emb in zip(seqs, embeddings):
@@ -595,10 +595,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
     seqs = sequences[i * batch_size:(i + 1) * batch_size]
     input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
     x = self.embed(input_ids)
-    residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+    residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
     if full_precision:
         residue_embeddings = residue_embeddings.float()
-    embeddings = get_embeddings(residue_embeddings, attention_mask)
+    embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
     for seq, emb in zip(seqs, embeddings):
         embeddings_dict[seq] = emb
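
The diff adds .detach() to the transformer's last_hidden_state and, in the second loop, .cpu() to the pooled embeddings before they are stored. A minimal, self-contained sketch of that accumulation pattern follows; the toy encoder, the mean_pool helper (standing in for get_embeddings), and the fake batches are illustrative assumptions, not the repository's actual API.

import torch
import torch.nn as nn

# Toy stand-in for the ESM++ transformer stack; only the accumulation pattern matters here.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).to(device).eval()

def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Average residue embeddings over non-padding positions (stand-in for get_embeddings).
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

embeddings_dict = {}
sequences = [f"seq_{i}" for i in range(8)]
batch_size = 2
for i in range(4):  # pretend each iteration is one tokenized batch
    seqs = sequences[i * batch_size:(i + 1) * batch_size]
    x = torch.randn(batch_size, 10, 64, device=device)          # fake embedded inputs
    attention_mask = torch.ones(batch_size, 10, device=device)  # no padding in this toy batch
    residue_embeddings = encoder(x).detach()                     # .detach(): drop the autograd graph for this batch
    embeddings = mean_pool(residue_embeddings, attention_mask).cpu()  # .cpu(): free GPU memory as results accumulate
    for seq, emb in zip(seqs, embeddings):
        embeddings_dict[seq] = emb

print(len(embeddings_dict), embeddings_dict["seq_0"].shape)  # -> 8 torch.Size([64])

Detaching drops the autograd graph attached to each batch's activations, and moving the pooled vectors to CPU keeps GPU memory roughly flat while the dictionary grows across batches.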