# tamilatis / app.py
# Author: seanbenhur · commit 2705a3e ("add code")
import functools
import os

import gradio as gr
import numpy as np
from sklearn.preprocessing import LabelEncoder

from tamilatis.model import JointATISModel
from tamilatis.predict import TamilATISPredictor
# Hugging Face model id used to build JointATISModel, and the tokenizer id
# handed to TamilATISPredictor (kept as separate settings on purpose).
model_name = "microsoft/xlm-align-base"
tokenizer_name = "microsoft/xlm-align-base"

# Size of the slot-tag (NER) label space and of the intent label space.
# NOTE(review): presumably these must match the encoder arrays and the
# checkpoint loaded below — confirm before changing either side.
num_labels = 78
num_intents = 23

# Directory holding the fine-tuned checkpoint and the label-encoder arrays.
# Defaults to the original Colab location; set TAMIL_ATIS_DIR to run the app
# from anywhere else (e.g. a Space or a local checkout) without editing code.
_data_dir = os.environ.get("TAMIL_ATIS_DIR", "/content/tamil_atis")
checkpoint_path = f"{_data_dir}/xlm_align_base.bin"
intent_encoder_path = f"{_data_dir}/intent_annotations.npy"
ner_encoder_path = f"{_data_dir}/ner_annotations.npy"
@functools.lru_cache(maxsize=1)
def _load_predictor():
    """Build the TamilATISPredictor once and reuse it for every request.

    Loads the NER (slot) and intent label encoders from their ``.npy`` class
    arrays, constructs the joint model, and wraps everything in a predictor.
    The original code did all of this on every call, repeating the file I/O
    and model construction per request. ``lru_cache`` does not cache
    exceptions, so a failed load is simply retried on the next request.

    NOTE(review): assumes ``TamilATISPredictor.get_predictions`` does not
    mutate predictor state between calls, so sharing one instance is safe;
    confirm.
    """
    label_encoder = LabelEncoder()
    # Assigning classes_ directly restores a fitted encoder without refitting.
    label_encoder.classes_ = np.load(ner_encoder_path)
    intent_encoder = LabelEncoder()
    intent_encoder.classes_ = np.load(intent_encoder_path)
    model = JointATISModel(model_name, num_labels, num_intents)
    return TamilATISPredictor(model, checkpoint_path, tokenizer_name,
                              label_encoder, intent_encoder, num_labels)


def predict_function(text):
    """Predict slot tags and the intent for one Tamil utterance.

    Args:
        text: The raw input string from the Gradio text box.

    Returns:
        A ``(slot_prediction, intent_preds)`` tuple, exactly as produced by
        ``TamilATISPredictor.get_predictions``.
    """
    predictor = _load_predictor()
    slot_prediction, intent_preds = predictor.get_predictions(text)
    return slot_prediction, intent_preds
# Heading displayed on the demo page.
title = "Tamil ATIS"

# A single free-text input and two text outputs, in the order that
# predict_function returns them: slot predictions first, then the intent.
# NOTE(review): string theme names such as "huggingface" may not be accepted
# by newer Gradio releases — confirm against the pinned Gradio version.
iface = gr.Interface(
    fn=predict_function,
    inputs="text",
    outputs=["text", "text"],
    title=title,
    theme="huggingface",
)

# share=True requests a public Gradio share link alongside the local server.
iface.launch(share=True)