# coding=utf-8
# Copyright 2023 The GIRT Authors.
# Lint as: python3
# This Space is based on the AMR-KELEG/ALDi and cis-lmu/GlotLID Spaces.
# GIRT Space
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
import pandas as pd
import base64
import json
@st.cache_data
def render_svg(svg):
    """Render the given SVG string inline in the Streamlit page.

    The SVG markup is base64-encoded and embedded as a data URI inside an
    ``<img>`` tag. Streamlit renders that tag because
    ``unsafe_allow_html=True`` is passed.

    Args:
        svg: The SVG document as a string.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # A data URI keeps the image self-contained: no separate file is served.
    # The whole literal must stay on one line; a newline inside a single-quoted
    # string is a SyntaxError.
    html = f'<p align="center"> <img src="data:image/svg+xml;base64,{b64}"/> </p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)
@st.cache_resource
def load_model(model_name):
    """Return the seq2seq model identified by ``model_name``.

    ``st.cache_resource`` keeps the loaded object per distinct
    ``model_name``, so the weights are loaded only once until the cache is
    cleared.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)
@st.cache_resource
def load_tokenizer(model_name):
    """Return the tokenizer that belongs to ``model_name``.

    Cached with ``st.cache_resource``: a single tokenizer object is kept per
    distinct ``model_name`` and reused across reruns.
    """
    return AutoTokenizer.from_pretrained(model_name)
@st.cache_resource
def load_examples(path="assets/examples.json"):
    """Load the example metadata bundled with the app.

    Cached with ``st.cache_resource``, which hands the same cached object to
    every caller. Callers should therefore treat the result as read-only.

    Args:
        path: JSON file to read. Defaults to the bundled
            ``assets/examples.json``, so existing no-argument calls are
            unchanged.

    Returns:
        The parsed JSON content.
    """
    # Explicit UTF-8: the default text encoding of open() is platform/locale
    # dependent, so non-ASCII examples could otherwise fail to decode.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# load resources
# Checkpoint id passed to `from_pretrained` for both the model and the
# tokenizer; a single constant keeps the pair from drifting apart.
MODEL_NAME = 'nafisehNik/girt-t5-base'
with st.spinner(text="Please wait while the model is loading...."):
    model = load_model(MODEL_NAME)
    tokenizer = load_tokenizer(MODEL_NAME)
    examples = load_examples()
# create instruction from metadata
def create_instruction(name, about, title, labels, assignees, headline_type, headline, summary):
    """Build the model prompt from issue-template metadata.

    Each argument becomes one ``key: value`` line, in a fixed order. Falsy
    metadata fields are replaced by the ``<|MASK|>`` token. A falsy
    ``summary`` is replaced by ``<|EMPTY|>`` instead.

    Returns:
        The newline-separated instruction string.
    """
    def _or_mask(field_value):
        # Empty/None fields are masked.
        return field_value if field_value else '<|MASK|>'

    # The prompt keys are fixed strings expected in the model input; note that
    # headline_type/headline are emitted as "headlines_type"/"headlines".
    rows = (
        ('name', _or_mask(name)),
        ('about', _or_mask(about)),
        ('title', _or_mask(title)),
        ('labels', _or_mask(labels)),
        ('assignees', _or_mask(assignees)),
        ('headlines_type', _or_mask(headline_type)),
        ('headlines', _or_mask(headline)),
        ('summary', summary if summary else '<|EMPTY|>'),
    )
    return '\n'.join(f'{key}: {value}' for key, value in rows)
# compute the output
def compute(sample, top_p, top_k, do_sample, max_length, min_length):
inputs = tokenizer(sample, return_tensors="pt").to('cpu')
outputs = model.generate(
**inputs,
min_length= min_length,
max_length=max_length,
do_sample=do_sample,
top_p=top_p,
top_k=top_k).to('cpu')
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
generated_text = generated_texts[0]
replace_dict = {
'\n ': '\n',
'': '',
' ': '',
'': '',
'!--': '