|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
# Text shown by the Gradio interface (passed to gr.Interface as title /
# description / article).
title = "Silly Ted-Talk snippet generator"

description = "Tap on the \"Submit\" button to generate a random text snippet."

# HTML footer crediting the base model and the fine-tuning dataset.
# The en dash is written as an escape ("\u2013") so the literal survives any
# future re-encoding of this file; a mis-decoded copy previously showed "β".
article = (
    "<p>Fine tuned "
    "<a href=\"https://huggingface.co/EleutherAI/gpt-neo-125M\">EleutherAI/gpt-neo-125M</a>"
    " upon a formatted "
    "<a href=\"https://www.kaggle.com/datasets/miguelcorraljr/ted-ultimate-dataset\">TED \u2013 Ultimate Dataset</a>"
    " (English)</p>"
)
|
|
|
# The fine-tuned checkpoint and its tokenizer are both loaded from the same
# local directory; the pipeline is built once, at import time.
model_id = "./model"
text_generator = pipeline("text-generation", model=model_id, tokenizer=model_id)

# Sampling settings forwarded to text_generator on every request.
max_length = 128
top_k = 40
top_p = 0.92
temperature = 1.0
|
|
|
def text_generation(input_text=None):
    """Generate a TED-talk style text snippet, optionally seeded by a prompt.

    Prompt handling:
      * ``None`` or ``""`` -> the prompt is a tab followed by an opening
        double quote, i.e. "start a fresh quoted snippet".
      * otherwise -> all double quotes are removed from the user text, and,
        unless the result already starts with ``<|startoftext|>``, the same
        tab + double-quote prefix is prepended.

    The single generated sequence is then cleaned: special tokens are
    removed or replaced, ``\\r`` is dropped, doubled newlines are collapsed,
    tabs become spaces and doubled quotes become single quotes.

    Args:
        input_text: Text from the Gradio textbox, or ``None``.

    Returns:
        The cleaned generated text.
    """
    if not input_text:
        prompt = "\t\""
    else:
        # str.replace returns a new string; it must be reassigned or the
        # quotes are never actually stripped.
        prompt = input_text.replace("\"", "")
        if not prompt.startswith("<|startoftext|>"):
            prompt = "\t\"" + prompt

    outputs = text_generator(
        prompt,
        max_length=max_length,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        do_sample=True,
        repetition_penalty=2.0,
        num_return_sequences=1,
    )
    text = outputs[0]["generated_text"]

    # Applied strictly in this order. Dropping "\r" before collapsing "\n\n"
    # lets "\r\n\r\n" blank lines collapse as well.
    replacements = (
        ("<|startoftext|>", ""),
        ("\r", ""),
        ("\n\n", "\n"),
        ("\t", " "),
        ("<|pad|>", " * "),
        ("\"\"", "\""),
    )
    for old, new in replacements:
        text = text.replace(old, new)
    return text
|
|
|
# Build the web UI around text_generation and start serving it.
# NOTE(review): gr.inputs / gr.outputs and a boolean allow_flagging are the
# legacy Gradio API; they are kept to match the Gradio version this script
# targets — confirm the pinned version before migrating to gr.Textbox.
demo = gr.Interface(
    fn=text_generation,
    inputs=[gr.inputs.Textbox(lines=1, label="Enter input text or leave blank")],
    outputs=[gr.outputs.Textbox(type="text", label="Generated Ted-Talk snippet")],
    title=title,
    description=description,
    article=article,
    theme="default",
    allow_flagging=False,
)
demo.launch()