# testing / app.py — Hugging Face Space by gofilipa
# Commit beda32e: "updating with bedtime story code"
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline
from accelerate import Accelerator
# Force CPU execution for everything prepared through this accelerator.
accelerator = Accelerator(cpu=True)

# Hugging Face Hub id of the checkpoint served by this app.
MODEL_NAME = "EleutherAI/gpt-neo-125m"

# Accelerator.prepare() only handles models, optimizers, dataloaders and
# schedulers; a tokenizer would be handed back unchanged, so load it directly.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The model goes through prepare() so it is placed according to the
# accelerator's settings (CPU, because of cpu=True above).
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained(MODEL_NAME))
def plex(input_text):
    """Continue ``input_text`` with the loaded model and return the first line.

    Uses the module-level ``tokenizer`` and ``model``. No sampling arguments
    are passed to ``generate()``, so decoding follows the model's default
    generation settings. Lengths are counted in tokens and include the prompt
    (at least 20, at most 150 in total).

    Args:
        input_text: The prompt typed into the Gradio textbox.

    Returns:
        The first line of the decoded text (prompt plus continuation), with
        special tokens such as ``<|endoftext|>`` removed. Returns ``""`` when
        the prompt is empty or the decoded text has no lines.
    """
    # An empty prompt tokenizes to zero input ids, leaving generate() nothing
    # to continue from; answer with an empty string instead of raising.
    if not input_text:
        return ""

    encoded = tokenizer(input_text, return_tensors="pt")
    prediction = model.generate(
        encoded["input_ids"],
        # Pass the mask explicitly. For a single unpadded prompt it is all ones,
        # so output is unchanged, but HF no longer warns about a missing mask.
        attention_mask=encoded["attention_mask"],
        # With batch size 1, padding is never emitted. Using EOS here matches
        # what generate() falls back to (with a warning) when no pad id is
        # configured, so this only silences that warning.
        pad_token_id=tokenizer.eos_token_id,
        min_length=20,
        max_length=150,
        num_return_sequences=1,
    )
    # Strip special tokens so "<|endoftext|>" never shows up in the UI.
    text = tokenizer.decode(prediction[0], skip_special_tokens=True)
    lines = text.splitlines()
    # splitlines() returns [] for an empty string; avoid an IndexError.
    return lines[0] if lines else ""
def _build_interface():
    """Return a Gradio Interface that routes one prompt textbox through plex()."""
    prompt_box = gr.Textbox(label="Prompt", value="Once upon a")
    output_box = gr.Textbox(label="Generated_Text")
    return gr.Interface(
        fn=plex,
        inputs=prompt_box,
        outputs=output_box,
        title="GPT-Neo-125M",
        description="Prompt",
    )


iface = _build_interface()

# Throttle the app. At most one request may wait in the queue. With
# api_open=False, the backend's REST routes are not opened for direct calls
# that would skip the queue (per Gradio's queue() docs). The server runs with
# a single worker thread.
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=1)