# Spaces status banner captured with this file: Runtime error
import gradio as gr | |
import numpy as np | |
import os | |
import requests | |
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
from typing import List | |
# Model identifier handed to AutoModelForTokenClassification / AutoTokenizer
# when the NER pipeline is built.
NER_MODEL_PATH = 'dell-research-harvard/historical_newspaper_ner'
# Model identifier handed to SentenceTransformer / AutoTokenizer for embedding.
EMBED_MODEL_PATH = 'dell-research-harvard/same-story'
# Host part of the retrieval endpoint URL, read from the environment.
# os.environ.get returns None when the variable is not set.
AZURE_VM_ALABAMA = os.environ.get('AZURE_VM_ALABAMA')
def find_sep_token(tokenizer):
    """
    Build the space-padded separator string for the given tokenizer.

    If the tokenizer's special-token map has an ``eos_token`` entry, the
    result is `` <eos> <sep> ``; otherwise it is `` <sep> ``.
    """
    special = tokenizer.special_tokens_map
    if 'eos_token' not in special:
        return " ".join(("", special['sep_token'], ""))
    # Joining with empty edge items yields the leading and trailing spaces.
    return " ".join(("", special['eos_token'], special['sep_token'], ""))
def find_mask_token(tokenizer):
    """
    Look up the ``mask_token`` entry in the tokenizer's special-token map.
    """
    return tokenizer.special_tokens_map['mask_token']
# Build the NER pipeline and the embedding model once at import time.
# The gr.NO_RELOAD guard is Gradio's reload-mode switch; presumably it is here
# so the expensive model loads are not repeated on hot reload (see Gradio docs).
# NOTE(review): indentation was lost in this copy; every assignment below is
# assumed to sit inside the guard - confirm against the original file.
if gr.NO_RELOAD:
    ner_model=AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
    ner_tokenizer=AutoTokenizer.from_pretrained(NER_MODEL_PATH, return_tensors = "pt",
                                                max_length=256, truncation = True)
    # ignore_labels = [] - presumably so that non-entity words stay in the
    # pipeline output; the masking code below re-joins every word it receives.
    token_classifier = pipeline(task = "ner",
                                model = ner_model, tokenizer = ner_tokenizer,
                                ignore_labels = [], aggregation_strategy='max')
    # The embedding tokenizer is only used to look up its special tokens.
    embedding_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_PATH)
    embedding_model = SentenceTransformer(EMBED_MODEL_PATH)
    # Tokenizer-specific replacements for the generic "[MASK]" / "[SEP]"
    # placeholders, substituted into the text by embed().
    embed_mask_tok = find_mask_token(embedding_tokenizer)
    embed_sep_tok = find_sep_token(embedding_tokenizer)
# with open(REF_INDEX_PATH, 'r') as f: | |
# news_paths = [l.strip() for l in f.readlines()] | |
def handle_punctuation_for_generic_mask(word):
    """Return a generic ``[MASK]`` placeholder for *word*, keeping edge punctuation.

    If *word* starts with one of ``. , ! ?`` that character is emitted before
    the mask, separated by a space. Otherwise, if *word* ends with one of
    them, the character is appended directly after the mask. Any other word,
    including the empty string, yields a bare ``[MASK]``.
    """
    punctuation = (".", ",", "!", "?")
    # startswith/endswith accept a tuple and are safe on an empty string,
    # whereas indexing word[0] / word[-1] would raise IndexError.
    if word.startswith(punctuation):
        return word[0] + " [MASK]"
    if word.endswith(punctuation):
        return "[MASK]" + word[-1]
    return "[MASK]"
def handle_punctuation_for_entity_mask(word,entity_group):
    """Return *entity_group* as the placeholder for *word*, keeping edge punctuation.

    If *word* starts with one of ``. , ! ?`` that character is emitted before
    the entity label, separated by a space. Otherwise, if *word* ends with one
    of them, the character is appended directly after the label. Any other
    word, including the empty string, yields the bare label.
    """
    punctuation = (".", ",", "!", "?")
    # startswith/endswith accept a tuple and are safe on an empty string,
    # whereas indexing word[0] / word[-1] would raise IndexError.
    if word.startswith(punctuation):
        return word[0] + " " + entity_group
    if word.endswith(punctuation):
        return entity_group + word[-1]
    return entity_group
def replace_words_with_entity_tokens(ner_output_dict: List[dict],
                                     desired_labels: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
                                     all_masks_same: bool = True) -> str:
    """Re-join NER output into one string with selected entities masked.

    Each item of *ner_output_dict* must provide ``"entity_group"`` and
    ``"word"``. Words whose group is not in *desired_labels* are kept verbatim.
    The others are replaced by a generic mask when *all_masks_same* is true,
    or by their entity label otherwise. Pieces are joined with single spaces.
    """
    pieces = []
    for item in ner_output_dict:
        label = item["entity_group"]
        if label not in desired_labels:
            pieces.append(item["word"])
        elif all_masks_same:
            pieces.append(handle_punctuation_for_generic_mask(item["word"]))
        else:
            pieces.append(handle_punctuation_for_entity_mask(item["word"], item["entity_group"]))
    return " ".join(pieces)
def mask(ner_output_list: List[dict], desired_labels: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
         all_masks_same: bool = True) -> str:
    """Mask entities in one NER result.

    Thin wrapper that forwards all three arguments unchanged to
    replace_words_with_entity_tokens and returns its joined string.
    """
    return replace_words_with_entity_tokens(ner_output_list, desired_labels, all_masks_same)
def ner(text: List[str]) -> List[dict]:
    """Run the module-level NER pipeline on *text* and return its first result.

    Only element 0 of the pipeline output is returned - presumably the entity
    dicts for the first input text; results for any further texts are dropped.
    """
    results = token_classifier(text)
    return results[0]
def ner_and_mask(text: List[str], labels_to_mask: List[str] = ['PER', 'ORG', 'LOC', 'MISC'], all_masks_same: bool = True) -> str:
    """Run NER on *text* and return the masked, re-joined string.

    Chains ner() and mask(); *labels_to_mask* and *all_masks_same* are passed
    straight through to mask().
    """
    ner_output_list = ner(text)
    return mask(ner_output_list, labels_to_mask, all_masks_same)
def embed(text: str) -> np.ndarray:
    """Embed one masked text and scale the result to unit length.

    The generic ``[MASK]`` and ``[SEP]`` placeholders in *text* are first
    rewritten to the embedding tokenizer's own mask and separator strings
    (``embed_mask_tok`` / ``embed_sep_tok``).

    Returns:
        A 2-D array with one row, because a single text is encoded; query()
        expects shape ``(1, 768)``. Each row is divided by its L2 norm.
    """
    # Correct [MASK] / [SEP] tokens for the embedding tokenizer.
    text = text.replace('[MASK]', embed_mask_tok).replace('[SEP]', embed_sep_tok)
    embedding = embedding_model.encode([text], show_progress_bar=False, batch_size=1)
    # keepdims=True keeps the norm as a column so the division broadcasts per row.
    return embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
def query(sentence: str) -> str:
    """Retrieve the raw text of the stored article that best matches *sentence*.

    Steps: NER-mask the sentence, embed the masked text, POST the vector to
    ``http://<AZURE_VM_ALABAMA>/retrieve`` asking for 5 neighbours, then return
    the ``raw_text`` of the entry in ``bboxes`` that the response's
    ``article_id`` points at (presumably a string, since the interface output
    is text).

    Raises:
        gr.Error: if the retrieval host is not configured, the embedding has
            an unexpected shape, or the HTTP request fails or times out.
    """
    if not AZURE_VM_ALABAMA:
        # Otherwise the request would be sent to the literal host "None".
        raise gr.Error("Retrieval backend is not configured (AZURE_VM_ALABAMA is unset).")

    mask_results = ner_and_mask([sentence])
    embedding = embed(mask_results)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if embedding.shape != (1, 768):
        raise gr.Error(f"Unexpected embedding shape {embedding.shape}; expected (1, 768).")

    # float64 + tolist() gives plain Python floats that the JSON encoder accepts.
    req = {"vector": embedding[0].astype(np.float64).tolist(), 'nn': 5}

    # Send embedding to Azure VM. The timeout keeps an unreachable backend from
    # blocking the worker forever; raise_for_status surfaces HTTP errors here
    # rather than as a confusing JSON/KeyError further down.
    try:
        response = requests.post(f"http://{AZURE_VM_ALABAMA}/retrieve", json=req, timeout=30)
        response.raise_for_status()
    except requests.RequestException as err:
        # Keep the user-facing message generic: the exception text contains the backend host.
        raise gr.Error("Retrieval request failed; please try again later.") from err

    doc = response.json()
    article = doc['bboxes'][doc['article_id']]
    return article['raw_text']
# Script entry point: expose query() through a Gradio interface with one text
# input and one text output, then start the app.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=query,
        inputs=["text"],
        outputs=["text"],
    )
    demo.launch()