# Hugging Face Space: movie-emotion finder (Gradio app).
import gradio as gr
import requests
import json
import nltk
import os
from huggingface_hub import InferenceClient
from nltk import tokenize

# Sentence-splitting model required by tokenize.sent_tokenize() below.
nltk.download('punkt')

# Instruction-tuned LLM used by generate() for zero-shot emotion extraction.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")

# Hosted inference endpoint for per-sentence emotion classification.
API_URL = "https://api-inference.huggingface.co/models/SamLowe/roberta-base-go_emotions"
# NOTE(review): requires the 'hug_token' environment secret; raises KeyError if unset.
headers = {"Authorization": f"Bearer {os.environ['hug_token']}"}
# Zero-shot [INST]-wrapped prompt for Mistral-Instruct: asks the model to map a
# TMDB movie description onto Ekman's six basic emotions and emit a JSON list.
# {0} is filled in by format_prompt() with the movie description.
zero_shot_prompt = '''
[INST]
You are a helpful movie emotion finder, you receive the following TMDB movie description of a movie.
Your task is to identify the emotion movie represent based on this description.
The emotion should be only one or multiple from Ekman's 6 emotion classes : happiness, anger, sadness, fear, disgust, surprise.
If you cannot identify the emotion for some reason, simply respond with 'Unknown' Only use Ekman's emotions as output.
Output the emotions as json list with objects containing emotion and explanation as key value.
TMDB movie description: {0}
Emotion:
[/INST]
'''
def format_prompt(message):
    """Interpolate *message* (a movie description) into the zero-shot prompt."""
    return zero_shot_prompt.format(message)
def generate(
    text, temperature=0.1, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream an emotion analysis of *text* from the Mistral instruct model.

    Yields the cumulative generated text after each streamed token so a
    Gradio output widget can render the response incrementally.

    Args:
        text: TMDB movie description to analyze.
        temperature: sampling temperature; clamped to >= 0.01 because the
            endpoint rejects non-positive values.
        max_new_tokens: generation length cap.
        top_p: nucleus-sampling threshold.
        repetition_penalty: penalty applied to repeated tokens.
    """
    temperature = max(float(temperature), 1e-2)  # clamp: endpoint rejects ~0
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed for reproducible sampling
    )

    formatted_prompt = format_prompt(text)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
def query(payload):
    """POST *payload* to the go_emotions inference API and return parsed JSON.

    NOTE(review): no HTTP status check — an API error (e.g. model loading)
    is returned as-is to the caller.
    """
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
def classify(text):
    """Aggregate per-sentence go_emotions scores over *text*.

    Splits *text* into sentences, classifies each sentence via the hosted
    go_emotions model, sums the scores per emotion label across sentences,
    and returns the totals as a JSON string sorted by descending score.
    """
    sentences = tokenize.sent_tokenize(text)
    print(f"#sentences: (input) : {len(sentences)}")

    # One API call classifies every sentence in a single batch.
    output = query({"inputs": sentences})
    print(f"#sentences: (output) : {len(output)}")

    # Sum scores per emotion label across all sentences.
    # NOTE(review): assumes the API returned a list of per-sentence label/score
    # lists; an error payload (dict) would raise here — confirm upstream handling.
    emotions = {}
    for sentence_scores in output:
        for entry in sentence_scores:
            label = entry['label']
            emotions[label] = emotions.get(label, 0) + entry['score']

    sorted_emotions = dict(sorted(emotions.items(), key=lambda item: item[1], reverse=True))
    print(json.dumps(sorted_emotions))
    return json.dumps(sorted_emotions)
# Generation-parameter sliders for the generate()-based interface.
# NOTE(review): currently unused — the active gr.Interface below wraps
# classify() and does not pass these inputs.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.1,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Alternative interface wrapping the streaming generate() endpoint, kept for reference:
#iface = gr.Interface(fn=generate, inputs="text", outputs="text", additional_inputs=additional_inputs)

# Active app: classify a movie description into aggregated go_emotions scores.
iface = gr.Interface(fn=classify, inputs="text", outputs="text")
iface.launch()