# Provenance: author lfcc, commit de6996a ("requirements fix"), 8.28 kB.
import streamlit as st
from annotated_text import annotated_text
import torch
from transformers import pipeline
from transformers import AutoModelForTokenClassification, AutoTokenizer
import spacy
import json
# Use the full browser width for the page layout.
st.set_page_config(layout="wide")
# Token-classification model and its tokenizer, both loaded from a local directory;
# model_max_length=512 sets the tokenizer's maximum sequence length.
# NOTE(review): Streamlit re-executes this script on every widget interaction, so these
# loads presumably repeat on each rerun; consider caching them (e.g. st.cache_resource) — confirm.
model = AutoModelForTokenClassification.from_pretrained("./models/lusa_prepo", use_safetensors=True)
tokenizer = AutoTokenizer.from_pretrained("./models/lusa_prepo", model_max_length=512)
# NER pipeline over the same model (used by annotateEvents). Per the transformers docs,
# aggregation_strategy='first' groups word pieces into words using the first piece's label
# and reports the grouped label under 'entity_group'.
tagger = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy='first') #aggregation_strategy='max'
from spacy.matcher import PhraseMatcher
# spaCy pipeline used for word tokenization (annotateTriggers) and contraction matching.
# NOTE(review): en_core_web_sm is an English pipeline while the input text is Portuguese;
# annotateTriggers only reads the resulting token texts — confirm the English model is intended.
nlp = spacy.load("en_core_web_sm")
# Portuguese preposition + article/pronoun contractions, mapped to the two
# pieces each one is split into. Every key is exactly the concatenation of its
# pieces, as spaCy's Retokenizer.split requires the new orths to spell the
# original token text. "à"/"às" are deliberately absent — presumably because
# "a" + "a" does not spell "à".
tokenization_contractions = {
    prefix + rest: [prefix, rest]
    for prefix, rests in (
        ("n", ("o", "a", "os", "as")),
        ("a", ("o", "os")),
        ("d", ("o", "a", "os", "as")),
        ("pel", ("o", "a", "os", "as")),
        ("d", ("um", "uma", "uns", "umas")),
        ("n", ("um", "uma", "uns", "umas")),
        ("d", ("ele", "ela", "eles", "elas")),
        ("d", ("este", "esta", "estes", "estas")),
        ("d", ("esse", "essa", "esses", "essas")),
        ("d", ("aquele", "aquela", "aqueles", "aquelas")),
    )
    for rest in rests
}
def tokenize_contractions(doc, tokenization_contractions):
    """Split Portuguese contractions in *doc* into their component tokens, in place.

    Every token whose text is a key of *tokenization_contractions* is split
    into the two pieces listed for it (e.g. "do" -> "d", "o").

    Args:
        doc: spaCy ``Doc`` built with the module-level ``nlp`` pipeline.
        tokenization_contractions: mapping from contraction text to the
            two-element list of pieces it is split into; the pieces must
            concatenate back to the key.

    Returns:
        The same ``Doc`` object, retokenized in place.
    """
    splits = tokenization_contractions
    matcher = PhraseMatcher(nlp.vocab)
    patterns = [nlp.make_doc(text) for text in splits]
    matcher.add("Terminology", None, *patterns)
    matches = matcher(doc)
    with doc.retokenize() as retokenizer:
        for _match_id, start, end in matches:
            token = doc[start]
            orths = splits[doc[start:end].text]
            # The first piece (preposition) attaches to the second piece; the
            # second piece takes over the unsplit token's head. Pointing the
            # second piece back at `token` would resolve to the first piece and
            # form a two-token dependency cycle. A token that is its own head is
            # the root, so in that case the second piece becomes the root.
            if token.head.i == token.i:
                second_head = (token, 1)
            else:
                second_head = token.head
            heads = [(token, 1), second_head]
            attrs = {"POS": ["ADP", "DET"], "DEP": ["pobj", "compound"]}
            retokenizer.split(token, orths=orths, heads=heads, attrs=attrs)
    return doc
def aggregate_subwords(input_tokens, labels):
    """Merge WordPiece continuation pieces ("##...") back into whole words.

    Each word keeps the label of its first piece; labels of continuation
    pieces are discarded.

    Args:
        input_tokens: tokenizer pieces, where a leading "##" marks a piece
            that continues the previous one.
        labels: one label per piece, aligned by position with *input_tokens*.

    Returns:
        A ``(words, word_labels)`` pair of equal-length lists. Empty input
        yields two empty lists.
    """
    new_inputs = []
    new_labels = []
    for i, token in enumerate(input_tokens):
        label = labels[i]
        is_continuation = token.startswith('##')
        if is_continuation and new_inputs:
            new_inputs[-1] += token[2:]
        else:
            # A word start. A stray leading "##" piece has nothing to attach to,
            # so it starts a word of its own and keeps its real label.
            new_inputs.append(token[2:] if is_continuation else token)
            new_labels.append(label)
    return new_inputs, new_labels
def annotateTriggers(line):
    """Tag event triggers in *line* and group the tags into display spans.

    The text is tokenized with spaCy (contractions split), re-tokenized by the
    model's tokenizer, classified piece by piece, merged back from WordPiece
    pieces, and finally grouped following the B-/I- prefixes of the labels.

    Args:
        line: raw input text.

    Returns:
        A list of ``(text, tag, entity_type)`` tuples in text order. *tag* is
        ``'O'`` for words outside any entity (``entity_type`` is ``''``),
        ``'B'`` for a single-word entity, or ``'I'`` for a multi-word entity
        whose *text* joins its words with spaces. Blank input yields ``[]``.

    Raises:
        ValueError: if the model predicts a label that is neither ``'O'`` nor
            ``B-``/``I-`` prefixed.
    """
    line = line.strip()
    doc = nlp(line)
    doc = tokenize_contractions(doc, tokenization_contractions)
    tokens = [token.text for token in doc]
    if not tokens:
        return []
    # truncation=True keeps very long inputs within the tokenizer's
    # model_max_length instead of failing inside the model; words beyond the
    # limit are simply not shown.
    inputs = tokenizer(tokens, is_split_into_words=True, truncation=True, return_tensors="pt")
    input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)
    predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]
    input_tokens, predicted_token_class = aggregate_subwords(input_tokens, predicted_token_class)
    # Drop the leading and trailing special tokens added by the tokenizer.
    input_tokens = input_tokens[1:-1]
    predicted_token_class = predicted_token_class[1:-1]

    token_labels = []
    current_entity = ''
    for token, label in zip(input_tokens, predicted_token_class):
        if label == 'O':
            token_labels.append((token, 'O', ''))
            current_entity = ''
        elif label.startswith('B-'):
            current_entity = label[2:]
            token_labels.append((token, 'B', current_entity))
        elif label.startswith('I-'):
            entity = label[2:]
            if current_entity and entity == current_entity:
                # Continuation of the open entity: extend its text.
                token_labels[-1] = (token_labels[-1][0] + f" {token}", 'I', current_entity)
            else:
                # I- without an open entity of the same type: start a new span
                # instead of dropping the word from the displayed text.
                current_entity = entity
                token_labels.append((token, 'B', current_entity))
        else:
            raise ValueError(f"Invalid label: {label}")
    return token_labels
def joinEntities(entities):
    """Merge token-level B-/I- entity predictions into entity spans.

    A span starts at a ``B`` entity, or at an ``I`` entity that does not
    continue a span of the same type. It extends over the directly following
    ``I`` entities with the same type. Entities whose label starts with
    neither ``B`` nor ``I`` are skipped.

    Args:
        entities: list of dicts with keys ``'entity'`` (e.g. ``'B-X'``),
            ``'score'``, ``'index'``, ``'word'``, ``'start'`` and ``'end'``.

    Returns:
        A list of dicts with ``'entity'`` (label without its prefix),
        ``'score'`` (max over the span), ``'index'`` (min over the span),
        ``'word'`` (words joined by spaces), and the span's ``'start'`` and
        ``'end'``.
    """
    joined_entities = []
    i = 0
    while i < len(entities):
        tag = entities[i]['entity']
        if tag[0] not in ('B', 'I'):
            i += 1
            continue
        # An 'I' with no preceding span of its type opens its own span rather
        # than being dropped.
        label = tag[2:]
        j = i + 1
        while (j < len(entities)
               and entities[j]['entity'][0] == 'I'
               and entities[j]['entity'][2:] == label):
            j += 1
        span = entities[i:j]
        joined_entities.append({
            'entity': label,
            'score': max(e['score'] for e in span),
            'index': min(e['index'] for e in span),
            'word': ' '.join(e['word'] for e in span),
            'start': span[0]['start'],
            'end': span[-1]['end'],
        })
        i = j
    return joined_entities
import pysbd
# Shared sentence segmenter. clean=False leaves the text unmodified,
# presumably so segment lengths line up with character offsets (see
# getSentenceIndex). NOTE(review): the Spanish ("es") rules are applied to
# Portuguese text — presumably as the closest available language; confirm.
seg = pysbd.Segmenter(language="es", clean=False)


def sent_tokenize(text):
    """Split *text* into a list of sentence strings using ``seg``."""
    sentences = seg.segment(text)
    return sentences
def getSentenceIndex(lines, span):
    """Return the index of the sentence in *lines* that contains offset *span*.

    *lines* are assumed to be consecutive pieces of one text (presumably true
    for the pysbd segments with clean=False — confirm), so sentence ``k``
    covers offsets from ``len(''.join(lines[:k]))`` up to but excluding
    ``len(''.join(lines[:k + 1]))``.

    Args:
        lines: sentence strings whose concatenation is the original text.
        span: character offset into that text.

    Returns:
        The sentence index. Offsets past the end map to the last sentence;
        an empty *lines* yields 0.
    """
    end = 0
    for index, sentence in enumerate(lines):
        end += len(sentence)
        # An offset equal to `end` is the first character of the next sentence.
        if span < end:
            return index
    return max(len(lines) - 1, 0)
def generateContext(text, window, span):
    """Return the sentences of *text* surrounding character offset *span*.

    The sentence containing *span* is kept together with up to *window*
    sentences on each side, joined with single spaces.
    """
    sentences = sent_tokenize(text)
    center = getSentenceIndex(sentences, span)
    lower = max(0, center - window)
    upper = center + window + 1
    neighbourhood = sentences[lower:upper]
    return " ".join(neighbourhood)
def annotateEvents(text, squad, window):
    """Run the NER pipeline on *text* and describe each detected trigger.

    Args:
        text: raw input text. Surrounding whitespace is stripped first, so
            pipeline offsets and contexts refer to the stripped text.
        squad: unused; kept for compatibility with existing callers.
        window: number of neighbouring sentences on each side to include in
            each trigger's context.

    Returns:
        A list of dicts, one per pipeline result, with keys ``'trigger'``
        (the result's ``'word'``), ``'type'`` (its ``'entity_group'``),
        ``'score'`` and ``'context'``.
    """
    text = text.strip()
    ner_results = tagger(text)
    # No label clean-up here: with aggregation_strategy='first' the pipeline
    # already reports the grouped label in 'entity_group'. (Prefix removal via
    # str.lstrip("B-") would be wrong anyway: lstrip strips a set of
    # characters, not a prefix — e.g. "Invasion".lstrip("I-") == "nvasion".)
    return [
        {
            "trigger": trigger["word"],
            "type": trigger["entity_group"],
            "score": trigger["score"],
            "context": generateContext(text, window, trigger["start"]),
        }
        for trigger in ner_results
    ]
# Another example input: "A Joana foi atacada pelo João nas ruas do Porto, com uma faca."
st.title('Identify Events')
options = ["Naquele ano o rei morreu na batalha em Almograve. A rainha casou com o irmão dele.","O presidente da Federação Haitiana de Futebol, Yves Jean-Bart, foi banido para sempre de toda a atividade ligada ao futebol, por ter sido considerado culpado de abuso sexual sistemático de jogadoras, anunciou hoje a FIFA.",
"O navio 'Figaro', no qual viajavam 30 tripulantes - 16 angolanos, cinco espanhóis, cinco senegaleses, três peruanos e um do Gana - acionou por telefone o alarme de incêndio a bordo.", "A Polícia Judiciária (PJ) está a investigar o aparecimento de ossadas que foram hoje avistadas pelo proprietário de um terreno na freguesia de Meadela, em Viana do Castelo, disse à Lusa fonte daquela força policial."]
option = st.selectbox(
    'Select examples',
    options)
line = st.text_area("Insert Text", option)
# The button's return value is unused; clicking it only triggers a rerun.
st.button('Run')
window = 1
# Skip empty and whitespace-only input: there is nothing to tag.
if line.strip():
    st.header("Triggers:")
    triggerss = annotateTriggers(line)
    # Plain strings for 'O' words, (text, type) tuples for entity spans.
    annotated_text(*[word[0] + " " if word[1] == 'O' else (word[0] + " ", word[2]) for word in triggerss])
    # Run the NER pipeline once; only trigger, type and score are displayed.
    eventos = annotateEvents(line, 1, window)
    for mention in eventos:
        st.text(f"| Trigger: {mention['trigger']:20} | Type: {mention['type']:10} | Score: {str(round(mention['score'],3)):5} |")
    st.markdown("""---""")