lfcc's picture
Merge branch 'main' of https://huggingface.co/spaces/lfcc/Event-Identifier
25a8261
raw
history blame
5.97 kB
import streamlit as st
from annotated_text import annotated_text
import torch
from transformers import pipeline
from transformers import AutoModelForTokenClassification, AutoTokenizer
import json
# Configure the page layout before any other Streamlit call in this script.
st.set_page_config(layout="wide")

# Path handed to from_pretrained() for both the model and its tokenizer.
MODEL_DIR = "models/lusa"

model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, model_max_length=512)

# NER pipeline built on the objects above. The aggregation strategy in use
# is "first"; "max" was noted by the author as an alternative.
tagger = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="first",
)
def aggregate_subwords(input_tokens, labels):
    """Merge ``##``-prefixed continuation tokens into whole words.

    A token that starts with ``##`` is appended (without the marker) to the
    word currently being built. Every merged word keeps the label of its
    first piece; the labels of continuation pieces are discarded.

    Args:
        input_tokens: Token strings, in order.
        labels: Labels aligned index-by-index with ``input_tokens``.

    Returns:
        A ``(words, word_labels)`` tuple of two equally long lists. Both
        lists are empty when ``input_tokens`` is empty.

    Raises:
        IndexError: If ``labels`` is shorter than ``input_tokens``.
    """
    words = []
    word_labels = []
    pending_word = ""
    pending_label = ""
    saw_token = False

    for position, token in enumerate(input_tokens):
        saw_token = True
        label = labels[position]

        if token.startswith("##"):
            # Continuation piece: extend the word in progress, keep its label.
            pending_word += token[2:]
            continue

        # A new word begins: emit the one built so far (if any) first.
        if pending_word:
            words.append(pending_word)
            word_labels.append(pending_label)
        pending_word = token
        pending_label = label

    # With no input there is nothing pending; flushing anyway would emit a
    # spurious ("", "") word/label pair.
    if not saw_token:
        return words, word_labels

    # Flush the final word, which no later token could trigger.
    words.append(pending_word)
    word_labels.append(pending_label)
    return words, word_labels
def _merge_bio_labels(tokens, labels):
    """Collapse word-level BIO labels into ``(text, tag, entity_type)`` spans."""
    spans = []
    open_entity = ''
    for token, label in zip(tokens, labels):
        if label == 'O':
            spans.append((token, 'O', ''))
            open_entity = ''
        elif label.startswith('B-'):
            open_entity = label[2:]
            spans.append((token, 'B', open_entity))
        elif label.startswith('I-'):
            if open_entity == '':
                # An "inside" label with no open entity is a plausible model
                # output (e.g. "O" followed by "I-X"); start a new entity
                # instead of aborting the whole annotation.
                open_entity = label[2:]
                spans.append((token, 'B', open_entity))
            else:
                # Extend the text of the entity currently being built.
                spans[-1] = (spans[-1][0] + f" {token}", 'I', open_entity)
        else:
            raise ValueError(f"Invalid label: {label}")
    return spans


def annotateTriggers(line):
    """Label every word of *line* with the token-classification model.

    The text is tokenized, classified, re-assembled into whole words with
    ``aggregate_subwords`` and then consecutive ``B-``/``I-`` words are
    merged into a single span.

    Args:
        line: Text to annotate; surrounding whitespace is stripped first.

    Returns:
        A list of ``(text, tag, entity_type)`` tuples where ``tag`` is
        ``'O'``, ``'B'`` (single-word entity) or ``'I'`` (multi-word entity)
        and ``entity_type`` is ``''`` for ``'O'`` spans.

    Raises:
        ValueError: If the model emits a label that is neither ``'O'`` nor
            prefixed with ``B-``/``I-``.
    """
    line = line.strip()
    # truncation=True caps the input at the tokenizer's model_max_length so
    # that over-long text does not overflow the model.
    inputs = tokenizer(line, return_tensors="pt", truncation=True)
    input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)
    predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]

    input_tokens, predicted_token_class = aggregate_subwords(input_tokens, predicted_token_class)

    # Drop the first and last entries BEFORE merging so that whatever label
    # the model predicted for them can neither swallow a neighbouring real
    # word nor trigger an error.
    # NOTE(review): assumes the tokenizer wraps the input in exactly one
    # leading and one trailing special token - confirm for this tokenizer.
    return _merge_bio_labels(input_tokens[1:-1], predicted_token_class[1:-1])
def joinEntities(entities):
    """Fuse each ``B-`` entity with the ``I-`` entities that directly follow it.

    Args:
        entities: List of dicts with the keys ``'entity'``, ``'score'``,
            ``'index'``, ``'word'``, ``'start'`` and ``'end'``.

    Returns:
        A new list with one dict per ``B-`` entry. The merged dict carries
        the label without its two-character prefix, the highest score and
        lowest index of the group, the words joined by single spaces, the
        ``'start'`` of the first member and the ``'end'`` of the last.
        Entries that do not begin with ``'B'`` and are not part of such a
        group are left out.
    """
    merged = []
    total = len(entities)
    first = 0
    while first < total:
        head = entities[first]
        if head['entity'][0] != 'B':
            # Not the start of a group: skip it.
            first += 1
            continue

        # Find the end of the run of 'I' entries that follows the 'B'.
        stop = first + 1
        while stop < total and entities[stop]['entity'][0] == 'I':
            stop += 1

        group = entities[first:stop]
        merged.append({
            'entity': head['entity'][2:],
            'score': max(member['score'] for member in group),
            'index': min(member['index'] for member in group),
            'word': ' '.join(member['word'] for member in group),
            'start': group[0]['start'],
            'end': group[-1]['end'],
        })
        first = stop
    return merged
import pysbd
# Module-level pysbd sentence segmenter used by sent_tokenize().
# NOTE(review): the segmenter is configured with language="es" while the
# example sentences in this file are Portuguese - confirm this is intended.
seg = pysbd.Segmenter(language="es", clean=False)
def sent_tokenize(text):
    """Return the sentences of *text* as produced by the module-level ``seg``."""
    sentences = seg.segment(text)
    return sentences
def getSentenceIndex(lines, span):
    """Return the index of the sentence that contains character offset *span*.

    Sentence ``k`` is taken to cover the offsets from the summed length of
    the sentences before it up to, but not including, that sum plus its own
    length. An offset that lands exactly on a sentence boundary therefore
    belongs to the sentence that starts there.

    Args:
        lines: Sentences in document order.
        span: Zero-based character offset to locate.

    Returns:
        The index of the containing sentence. Offsets past the end of the
        last sentence map to the last index, and ``0`` is returned when
        ``lines`` is empty.
    """
    # NOTE(review): offsets only line up with the original text if the
    # sentences preserve every character of it (including the whitespace
    # between sentences) - confirm for the segmenter in use.
    consumed = 0
    for index, sentence in enumerate(lines):
        consumed += len(sentence)
        if span < consumed:
            return index
    # Offset beyond the text (or no sentences at all): clamp instead of
    # running off the end of the list.
    return max(len(lines) - 1, 0)
def generateContext(text, window, span):
    """Return the sentences surrounding character offset *span* in *text*.

    The text is split with ``sent_tokenize``, the sentence holding *span* is
    located with ``getSentenceIndex``, and that sentence plus up to *window*
    sentences on either side are joined with single spaces.
    """
    sentences = sent_tokenize(text)
    centre = getSentenceIndex(sentences, span)
    first = max(0, centre - window)
    last = centre + window + 1
    return " ".join(sentences[first:last])
def annotateEvents(text, squad, window):
    """Run the NER pipeline on *text* and describe every detected trigger.

    Args:
        text: Input text; surrounding whitespace is stripped first.
        squad: Unused. Kept so that existing call sites keep working.
        window: Number of sentences to keep on each side of the trigger's
            sentence when building its context.

    Returns:
        A list with one dict per pipeline result, holding the keys
        ``"trigger"`` (matched text), ``"type"`` (entity label without a
        ``B-``/``I-`` prefix), ``"score"`` and ``"context"``.
    """
    text = text.strip()

    events = []
    for trigger in tagger(text):
        label = trigger["entity_group"]
        # Remove a BIO prefix as a *prefix*. str.lstrip("B-") strips any
        # leading run of the characters "B" and "-", which would mangle a
        # label such as "B-Body" into "ody".
        if label.startswith(("B-", "I-")):
            label = label[2:]

        events.append({
            "trigger": trigger["word"],
            "type": label,
            "score": trigger["score"],
            "context": generateContext(text, window, trigger["start"]),
        })
    return events
#"A Joana foi atacada pelo João nas ruas do Porto, com uma faca."
st.title('Extract Events')

# Example sentences offered in the drop-down below.
options = [
    "O presidente da Federação Haitiana de Futebol, Yves Jean-Bart, foi banido para sempre de toda a atividade ligada ao futebol, por ter sido considerado culpado de abuso sexual sistemático de jogadoras, anunciou hoje a FIFA.",
    "O barco de pesca Figaro ainda está a flutuar, embora esteja à deriva e ainda a arder.",
    "O navio 'Figaro', no qual viajavam 30 tripulantes - 16 angolanos, cinco espanhóis, cinco senegaleses, três peruanos e um do Gana - acionou por telefone o alarme de incêndio a bordo.",
    "A Polícia Judiciária (PJ) está a investigar o aparecimento de ossadas que foram hoje avistadas pelo proprietário de um terreno na freguesia de Meadela, em Viana do Castelo, disse à Lusa fonte daquela força policial.",
]
option = st.selectbox('Select examples', options)

# The selected example is the initial content of the text area.
line = st.text_area("Insert Text", option)

# The button's return value is not used by this script.
st.button('Run')

# Context window passed to annotateEvents below.
window = 1
# Annotate the current content of the text area, unless it is empty.
if line != "":
    st.header("Triggers:")
    triggerss = annotateTriggers(line)
    # Plain words are rendered as text; labelled spans as (text, type) pairs.
    annotated_text(*[
        word[0] + " " if word[1] == 'O' else (word[0] + " ", word[2])
        for word in triggerss
    ])

    # annotateEvents does not use its second argument, so calling it once
    # with 1 and once with 2 ran the NER pipeline twice over the same text
    # for the same result. Run it once and keep both names bound.
    eventos_1 = annotateEvents(line, 1, window)
    eventos_2 = eventos_1

    for mention1 in eventos_1:
        st.text(f"| Trigger: {mention1['trigger']:20} | Type: {mention1['type']:10} | Score: {str(round(mention1['score'],3)):5} |")
        st.markdown("""---""")