Spaces:
Running
Running
File size: 904 Bytes
30e9731 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
import json
import numpy as np
from transformers import (BertForMaskedLM, BertTokenizer)
# Local directory holding the checkpoint (tokenizer vocab + masked-LM weights);
# both objects below are loaded from this same path.
modelpath = 'zari-bert-cda/'
tokenizer = BertTokenizer.from_pretrained(modelpath)
model = BertForMaskedLM.from_pretrained(modelpath)
# Inference only: eval mode disables dropout so predictions are deterministic.
model.eval()
# Token id of the mask token, taken from the tokenizer's own vocabulary instead
# of being hard-coded. 103 is [MASK] in the standard BERT vocab files, but a
# checkpoint with a different vocab would silently break a literal constant.
id_of_mask = tokenizer.mask_token_id
def get_embeddings(sentence):
    """Return the masked-LM prediction scores at the mask position of *sentence*.

    Despite the name, this does not return hidden-state embeddings. It returns
    the row of the model's output scores (one value per vocabulary entry) at
    the first mask token found in the encoded input.

    Args:
        sentence: Text containing at least one mask token, i.e. a token that
            encodes to ``id_of_mask`` (``[MASK]`` for a standard BertTokenizer).

    Returns:
        list[float]: Prediction scores for every vocabulary entry at the first
        mask position.

    Raises:
        ValueError: If the encoded sentence contains no mask token.
    """
    # encode() adds the model's special tokens ([CLS]/[SEP]) itself.
    tokenized = tokenizer.encode(sentence)
    # Locate the mask before running the model so a bad input fails fast with
    # a readable message instead of "103 is not in list".
    try:
        index_of_mask = tokenized.index(id_of_mask)
    except ValueError:
        raise ValueError(
            f"sentence contains no mask token (id {id_of_mask}): {sentence!r}"
        ) from None
    with torch.no_grad():
        input_ids = torch.tensor(tokenized).unsqueeze(0)  # batch size 1
        outputs = model(input_ids)
    # outputs[0]: (batch, tokens, vocab_size) -> take batch 0, the mask row.
    prediction_scores = outputs[0]
    return prediction_scores[0][index_of_mask].cpu().numpy().tolist()
import os
import shutil
# Optionally delete the checkpoint directory once it has been loaded, to free
# disk space. The weights already loaded into `model` remain in memory.
# Opt-in only: requires REMOVE_WEIGHTS=TRUE (exact, case-sensitive match).
if os.environ.get('REMOVE_WEIGHTS') == 'TRUE':
    print('removing zari-bert-cda from filesystem')
    # Reuse `modelpath` so the directory removed is the one that was loaded.
    # normpath drops the trailing slash, so rmtree sees the bare directory
    # name exactly as before (no following of a symlink through "dir/").
    # ignore_errors=True: best-effort cleanup, a missing directory is fine.
    shutil.rmtree(os.path.normpath(modelpath), ignore_errors=True)
|