AndreaSimeri's picture
Update app.py
4b5a30d verified
raw
history blame
No virus
2.53 kB
import pandas as pd
import re
import gradio as gr
import torch
from transformers import BertTokenizerFast, BertForSequenceClassification
# Select the compute device: use CUDA when PyTorch can see a GPU, otherwise
# fall back to the CPU. The choice is reported on stdout at start-up.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

# Pipe-separated CSV file that load_CC_from_CSV reads the articles from.
dataset_path = './codice_civile_ITA_LIBRI_2_withArtRef_v2.csv'
def load_CC_from_CSV(path, max_articles=60):
    """Load articles from a pipe-separated CSV and build lookup tables.

    The file is read without a header row; columns 1, 2 and 3 are taken as
    the article reference, its title and its text. Each article reference is
    lower-cased and stripped of whitespace, dots and dashes to form its key.

    Args:
        path: Path (or file-like object) accepted by ``pandas.read_csv``.
        max_articles: Maximum number of leading rows to load. Defaults to 60,
            which matches the previously hard-coded cut-off. Pass ``None`` to
            load every row. Expected to be non-negative.

    Returns:
        A tuple ``(article_id, id_article, article_text, num_articles)``:
        ``article_id`` maps normalized article key -> row index,
        ``id_article`` maps row index -> normalized article key,
        ``article_text`` maps normalized article key ->
        ``"<title> -> <text>"`` (both lower-cased), and ``num_articles`` is
        the number of rows that were loaded.
    """
    cc = pd.read_csv(path, header=None, sep='|', usecols=[1,2,3], names=['art','title','text'], engine='python')
    if max_articles is not None:
        cc = cc.head(max_articles)

    # Compiled once; strips every run of whitespace, '.' or '-' from the key.
    separators = re.compile(r'[\s.\-]+')

    article_id = {}
    id_article = {}
    article_text = {}
    rows = zip(cc['art'], cc['title'], cc['text'])
    for i, (raw_art, title, text) in enumerate(rows):
        art = separators.sub('', str(raw_art).lower())
        # A repeated key overwrites the earlier row in article_id/article_text.
        article_id[art] = i
        id_article[i] = art
        article_text[art] = str(title).lower() + " -> " + str(text).lower()
    return article_id, id_article, article_text, len(cc)
# Hub identifier shared by the classifier weights and their tokenizer.
_CHECKPOINT = "AndreaSimeri/LamBERTa_v5"

# Build the article lookup tables once, at import time.
article_id, id_article, article_text, NUM_ART = load_CC_from_CSV(dataset_path)

# Tokenizer and sequence classifier are both loaded from the same checkpoint.
tokenizer = BertTokenizerFast.from_pretrained(_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(_CHECKPOINT)
def LamBERTa_v5_placeholder(query):
    """Return random confidences for up to five loaded articles.

    Stand-in for the real classifier: the query is ignored and a random
    probability distribution is drawn over the articles in ``id_article``.

    Args:
        query: Unused; kept so the signature matches ``LamBERTa``.

    Returns:
        Dict mapping article key -> probability for the top (at most five)
        randomly scored articles; empty when no articles are loaded.
    """
    # Size the distribution to the articles actually loaded so every sampled
    # index has an entry in id_article (a fixed class count larger than the
    # mapping raised KeyError).
    n = len(id_article)
    if n == 0:
        return {}
    predictions = torch.softmax(torch.randn(n), dim=0)
    values, indices = torch.topk(predictions, min(5, n))
    return {id_article[i.item()]: v.item() for i, v in zip(indices, values)}
def LamBERTa(query):
    """Classify a query and return the three most probable articles.

    Args:
        query: Free-text question to classify.

    Returns:
        A tuple ``(confidences, texts)`` where ``confidences`` maps article
        key -> softmax probability for the top predictions (at most three),
        and ``texts`` is a list of ``{"art": key, "text": article_text[key]}``
        dicts in the same order.

    Raises:
        KeyError: If a predicted class id has no entry in ``id_article``.
            NOTE(review): ``id_article`` only covers the rows loaded at
            start-up -- confirm it spans every label the model can emit.
    """
    # Batch of size 1; truncation guards against queries longer than the
    # tokenizer's configured maximum length.
    token_ids = tokenizer.encode(query, add_special_tokens=True, truncation=True)
    input_ids = torch.tensor(token_ids).unsqueeze(0).to(model.device)

    # Inference only: no labels (so no loss is computed) and no gradients.
    with torch.no_grad():
        outputs = model(input_ids)
    logits = outputs[0]

    probs = torch.softmax(logits, dim=1)
    top_k = min(3, probs.size(1))
    values, indices = torch.topk(probs, top_k, dim=1)
    confidences = {id_article[i.item()]: v.item() for i, v in zip(indices[0], values[0])}
    texts = [{"art": art, "text": article_text[art]} for art in confidences]
    return confidences, texts
# Gradio UI: a text box feeding LamBERTa, showing the confidences as a label
# widget and the matching article texts as JSON. live=True re-runs the
# function as the input changes.
demo = gr.Interface(
    fn=LamBERTa,
    inputs="text",
    outputs=["label", "json"],
    examples=[
        "Quando si apre la successione",
        "Dove si apre la successione",
        "In quali casi, alla morte, non spetta l'eredità",
    ],
    live=True,
)

# Launch exactly once. The previous second call, launch(share=True), could
# only run after the first server had shut down, and then started a new one.
# Pass share=True here instead if a public share link is wanted when running
# locally.
demo.launch()