Spaces:
Build error
Build error
from transformers import pipeline | |
from transformers import AutoModelForSeq2SeqLM | |
from transformers import AutoTokenizer | |
from textblob import TextBlob | |
from hatesonar import Sonar | |
import gradio as gr | |
import torch | |
# Load the fine-tuned seq2seq reframing checkpoint from the local directory and
# wrap it in a "summarization" pipeline, which reframe() calls on each request.
_REFRAMER_DIR = "output/reframer"
model = AutoModelForSeq2SeqLM.from_pretrained(_REFRAMER_DIR)
tokenizer = AutoTokenizer.from_pretrained(_REFRAMER_DIR)
reframer = pipeline(task="summarization", model=model, tokenizer=tokenizer)
# Input-validation thresholds.
CHAR_LENGTH_LOWER_BOUND = 15  # Minimum accepted character length of the input text
CHAR_LENGTH_HIGHER_BOUND = 150  # Maximum accepted character length of the input text
SENTIMENT_THRESHOLD = 0.2  # Maximum accepted TextBlob sentiment polarity of the input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8  # Inputs whose offensive-language confidence exceeds this are rejected

# User-facing explanations shown (via input_error_message) when an input is rejected.
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputting text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputting text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputting non-offensive text."

CACHE = []  # Most recent [input_text, strategy, reframed_text] records, oldest first
MAX_STORE = 5  # Maximum number of reframing records kept in CACHE
BEST_N = 3  # Number of sampled decodes to generate and show per request
def input_error_message(error_type):
    # type: (str) -> str
    """Build the user-facing message for a rejected input.

    The fixed "[Error]: Invalid Input. " prefix is placed in front of
    ``error_type`` (normally one of the *_ERROR constants).
    """
    prefix = "[Error]: Invalid Input. "
    return "".join([prefix, error_type])
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Return a new history list with ``new_record`` appended, keeping only the latest records.

    Args:
        cache: Existing records, oldest first. It is not modified.
        new_record: The record to append, e.g. [input_text, strategy, reframed_text].
        max_store: Maximum number of records to keep. Defaults to the module-level
            MAX_STORE when omitted.

    Returns:
        A new list holding at most ``max_store`` of the most recent records, oldest first.
    """
    if max_store is None:
        max_store = MAX_STORE
    # Build a fresh list so the caller's list is never left half-updated.
    updated = list(cache) + [new_record]
    # Slice from an explicit start index: updated[-0:] would keep everything when max_store == 0.
    return updated[max(len(updated) - max_store, 0):]
_SONAR = None  # Shared hatesonar classifier, created on first use by _get_sonar()


def _get_sonar():
    # type: () -> Sonar
    """Return a shared Sonar instance, constructing it only on the first call.

    NOTE(review): Sonar() presumably loads its classifier from disk on construction,
    so one instance is reused across requests instead of building one per call.
    """
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe ``input_text`` positively using the chosen ``strategy``.

    The strategy is appended to the input text and passed to the fine-tuned
    ``reframer`` pipeline. On success, the reframed text is returned and the
    [input_text, strategy, reframed_text] record is pushed into CACHE.

    If the input is rejected, the model is not run and the return value is the
    input text followed by an input_error_message(...) string.
    """
    global CACHE

    # Input control.
    # gr.Radio passes None when no strategy is selected; concatenating None would raise TypeError.
    if not strategy:
        return input_text + input_error_message("Please select a reframing strategy.")
    # The text must not be too short (it needs substantial content to reframe) nor too
    # long (it should express a single focused idea).
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return input_text + input_error_message(LENGTH_ERROR)
    # The text must not already be too positive, otherwise there is nothing to reframe.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_text + input_error_message(SENTIMENT_ERROR)
    # The text must not be offensive. sonar.ping(...) returns a dict whose second entry
    # under 'classes' holds the confidence that the text is offensive language.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_text + input_error_message(OFFENSIVE_ERROR)

    # Reframing.
    # NOTE(review): no separator precedes "Strategy:"; kept as-is because the fine-tuned
    # model presumably saw exactly this prompt format during training — confirm.
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    # The pipeline returns a one-element list of dicts; the output is under 'summary_text'.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, Optional[str]]]
    """Reframe ``input_text`` and express the result as a character-level diff against it.

    Returns a list of (character, marker) pairs built from difflib.Differ output over
    the two strings: the marker is Differ's "+" / "-" code, or None for characters
    common to both texts. The list is the value displayed by the HighlightedText output.
    """
    reframed_text = reframe(input_text, strategy)
    import difflib

    pairs = []
    for diff_line in difflib.Differ().compare(input_text, reframed_text):
        marker = diff_line[0]
        pairs.append((diff_line[2:], None if marker == " " else marker))
    return pairs
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample BEST_N reframings of ``input_text`` under ``strategy`` from the model.

    No input validation is performed and CACHE is not updated. Returns one numbered
    line per generated sequence ("1 <text>", "2 <text>", ...), joined by newlines
    with no trailing newline.
    """
    # Same "<text>Strategy: ['<strategy>']" prompt format that reframe() feeds the pipeline.
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    input_ids = torch.tensor(tokenizer(prompt, padding=True)['input_ids'])
    # do_sample=True draws random samples, so the output differs between calls.
    n_best_decodes = model.generate(input_ids, do_sample=True, num_return_sequences=BEST_N)
    # Separators follow the sequences actually returned rather than assuming BEST_N of them.
    return "\n".join(
        str(rank) + " " + tokenizer.decode(sequence, skip_special_tokens=True)
        for rank, sequence in enumerate(n_best_decodes, start=1)
    )
def show_history(cache): | |
# type: List[List[str, str, str]] -> str | |
history = "" | |
for i in cache: | |
input_text, strategy, reframed_text = i | |
history += "Input text: " + input_text + " Strategy: " + strategy + " -> Reframed text: " + reframed_text + "\n" | |
return gr.Textbox.update(value=history, visible=True) | |
# Reframing strategies offered in the UI; the selected one is appended to the model prompt.
STRATEGIES = ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"]

demo = gr.Interface(
    fn=show_reframe_change,
    inputs=[
        gr.Textbox(lines=2, placeholder="Please input the sentence to be reframed.", label="Original Text"),
        gr.Radio(STRATEGIES, label="Strategy to use?"),
    ],
    # show_reframe_change marks added characters with "+" and deleted ones with "-".
    # color_map goes in the constructor: Component.style() no longer exists in Gradio 4.
    outputs=gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
        color_map={"+": "green", "-": "red"},
    ),
)

if __name__ == "__main__":
    demo.launch(show_api=True)