import streamlit as st
from annotated_text import annotated_text
import torch
from transformers import pipeline
from transformers import AutoModelForTokenClassification, AutoTokenizer
import json

st.set_page_config(layout="wide")

model = AutoModelForTokenClassification.from_pretrained("models/lusa", use_safetensors=True)
tokenizer = AutoTokenizer.from_pretrained("models/lusa", model_max_length=512)
tagger = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy='first')  # aggregation_strategy='max' is an alternative
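
# Note: with aggregation_strategy='first' the pipeline merges subword tokens itself and
# returns dicts of the form {"entity_group": ..., "score": ..., "word": ..., "start": ..., "end": ...}
# (field names per the Hugging Face token-classification pipeline; the actual labels depend
# on the fine-tuned model). annotateTriggers() below repeats that merging manually on raw logits.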
def aggregate_subwords(input_tokens, labels):
    """Merge WordPiece subwords ('##'-prefixed) back into whole words,
    keeping the label of each word's first subword."""
    new_inputs = []
    new_labels = []
    current_word = ""
    current_label = ""
    for i, token in enumerate(input_tokens):
        label = labels[i]
        # Handle subwords
        if token.startswith('##'):
            current_word += token[2:]
        else:
            # Finish previous word
            if current_word:
                new_inputs.append(current_word)
                new_labels.append(current_label)
            # Start new word
            current_word = token
            current_label = label
    # Flush the last word
    new_inputs.append(current_word)
    new_labels.append(current_label)
    return new_inputs, new_labels
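
# Illustrative example (hypothetical tokens/labels):
#   aggregate_subwords(["A", "Joana", "ata", "##cada"], ["O", "O", "B-X", "I-X"])
#   -> (["A", "Joana", "atacada"], ["O", "O", "B-X"])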
def annotateTriggers(line):
    """Run the token-classification model on one text and return a list of
    (word, BIO-tag, entity-type) triples, with subwords merged."""
    line = line.strip()
    inputs = tokenizer(line, return_tensors="pt")
    input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)
    predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]
    input_tokens, predicted_token_class = aggregate_subwords(input_tokens, predicted_token_class)
    token_labels = []
    current_entity = ''
    for i, label in enumerate(predicted_token_class):
        token = input_tokens[i]
        if label == 'O':
            token_labels.append((token, 'O', ''))
            current_entity = ''
        elif label.startswith('B-'):
            current_entity = label[2:]
            token_labels.append((token, 'B', current_entity))
        elif label.startswith('I-'):
            if current_entity == '':
                # Dangling I- tag without a preceding B-: skip it instead of raising
                # raise ValueError(f"Invalid label sequence: {predicted_token_class}")
                continue
            # Append the token to the current entity span
            token_labels[-1] = (token_labels[-1][0] + f" {token}", 'I', current_entity)
        else:
            raise ValueError(f"Invalid label: {label}")
    # Drop the [CLS] and [SEP] special tokens
    return token_labels[1:-1]
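
# Illustrative return value for a single-token trigger (entity labels depend on the fine-tuned model):
#   [("A", "O", ""), ("Joana", "O", ""), ("foi", "O", ""), ("atacada", "B", "EVENT"), ...]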
def joinEntities(entities):
    """Manually merge B-/I- pipeline entities into single spans
    (alternative to aggregation_strategy; currently unused)."""
    joined_entities = []
    i = 0
    while i < len(entities):
        curr_entity = entities[i]
        if curr_entity['entity'][0] == 'B':
            label = curr_entity['entity'][2:]
            j = i + 1
            while j < len(entities) and entities[j]['entity'][0] == 'I':
                j += 1
            joined_entity = {
                'entity': label,
                'score': max(e['score'] for e in entities[i:j]),
                'index': min(e['index'] for e in entities[i:j]),
                'word': ' '.join(e['word'] for e in entities[i:j]),
                'start': entities[i]['start'],
                'end': entities[j-1]['end']
            }
            joined_entities.append(joined_entity)
            i = j - 1
        i += 1
    return joined_entities
import pysbd

seg = pysbd.Segmenter(language="es", clean=False)

def sent_tokenize(text):
    return seg.segment(text)

def getSentenceIndex(lines, span):
    """Return the index of the sentence that contains character offset `span`."""
    i = 1
    total = len(lines[0])
    while total < span:
        total += len(lines[i])
        i = i + 1
    return i - 1

def generateContext(text, window, span):
    """Return the sentence containing `span` plus `window` sentences on each side."""
    lines = sent_tokenize(text)
    index = getSentenceIndex(lines, span)
    text = " ".join(lines[max(0, index - window):index + window + 1])
    return text
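
# Illustrative behaviour: for a three-sentence text, window=1 and a character span that
# falls inside the second sentence, generateContext returns sentences 1-3 joined by spaces.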
def annotateEvents(text, squad, window):
    """Run the NER pipeline over `text` and build one event dict per detected trigger.
    (`squad` is kept for compatibility with the callers below but is not used here.)"""
    text = text.strip()
    ner_results = tagger(text)
    #print(ner_results)
    #ner_results = joinEntities(ner_results)
    i = 0
    #exit()
    while i < len(ner_results):
        entity = ner_results[i]["entity_group"]
        # Strip a leading "B-"/"I-" prefix if present (str.lstrip removes characters, not a prefix)
        if entity.startswith("B-") or entity.startswith("I-"):
            entity = entity[2:]
        ner_results[i]["entity"] = entity
        i = i + 1
    events = []
    for trigger in ner_results:
        tipo = trigger["entity_group"]
        context = generateContext(text, window, trigger["start"])
        event = {
            "trigger": trigger["word"],
            "type": tipo,
            "score": trigger["score"],
            "context": context,
        }
        events.append(event)
    return events
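
# Each event dict looks roughly like (values illustrative, labels depend on the model):
#   {"trigger": "atacada", "type": "EVENT", "score": 0.97, "context": "...neighbouring sentences..."}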
#"A Joana foi atacada pelo João nas ruas do Porto, com uma faca." | |
st.title('Identify Events') | |
options = ["O presidente da Federação Haitiana de Futebol, Yves Jean-Bart, foi banido para sempre de toda a atividade ligada ao futebol, por ter sido considerado culpado de abuso sexual sistemático de jogadoras, anunciou hoje a FIFA.", | |
"O navio 'Figaro', no qual viajavam 30 tripulantes - 16 angolanos, cinco espanhóis, cinco senegaleses, três peruanos e um do Gana - acionou por telefone o alarme de incêndio a bordo.", "A Polícia Judiciária (PJ) está a investigar o aparecimento de ossadas que foram hoje avistadas pelo proprietário de um terreno na freguesia de Meadela, em Viana do Castelo, disse à Lusa fonte daquela força policial."] | |
option = st.selectbox( | |
'Select examples', | |
options) | |
#option = options [index] | |
line = st.text_area("Insert Text",option) | |
st.button('Run') | |
window = 1 | |
if line != "": | |
st.header("Triggers:") | |
triggerss = annotateTriggers(line) | |
annotated_text(*[word[0]+" " if word[1] == 'O' else (word[0]+" ",word[2]) for word in triggerss ]) | |
eventos_1 = annotateEvents(line,1,window) | |
eventos_2 = annotateEvents(line,2,window) | |
for mention1, mention2 in zip(eventos_1,eventos_2): | |
st.text(f"| Trigger: {mention1['trigger']:20} | Type: {mention1['type']:10} | Score: {str(round(mention1['score'],3)):5} |") | |
st.markdown("""---""") | |