# Spaces status: Runtime error
import gradio as gr | |
import torch | |
import os | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline | |
from accelerate import Accelerator | |
# Force execution on the CPU regardless of which devices are available.
accelerator = Accelerator(cpu=True)

# Output directory handed to TrainingArguments further down.
cwd = "./models"

# Load the base checkpoint's tokenizer and weights, then route each object
# through the accelerator before it is used for fine-tuning.
_base_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer = accelerator.prepare(_base_tokenizer)

_base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
model = accelerator.prepare(_base_model)
# Training corpus: a plain-text file tokenized with the tokenizer loaded
# above; block_size fixes the length of each training example.
train_dataset = TextDataset(tokenizer=tokenizer, file_path='./train_text.txt', block_size=128)

# mlm=False selects causal (next-token) language modelling rather than
# masked language modelling when batches are assembled.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Fine-tuning schedule.  Checkpoints are written to `cwd` every 500 steps
# and at most five of them are kept on disk.
training_args = TrainingArguments(
    output_dir=cwd,
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=5,
    save_steps=500,
    save_total_limit=5,
)

# Wire model, data and schedule together, then fine-tune immediately at
# module load time.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()
# Persist the tokenizer and the fine-tuned weights side by side so that
# ./models can be loaded with from_pretrained().
tokenizer.save_pretrained('./models')
# save_model() accepts just the output directory (its second parameter is a
# private flag), so no file name is passed.
trainer.save_model('./models')

# save_model() already writes the model's config.json into ./models.  A
# config.json sitting in the working directory is still moved into place
# when present, but its absence must not abort start-up.
src = './config.json'
des = './models/config.json'
if os.path.isfile(src):
    # os.replace overwrites an existing destination on every platform.
    os.replace(src, des)

# Reload for inference.  The tokenizer comes from the base checkpoint (no
# tokens were added during training); the weights come from the directory
# written above.
tokenizer = accelerator.prepare(AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m"))
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("./models"))
def plex(input_text):
    """Generate a continuation of *input_text* and return its first line.

    Args:
        input_text: Prompt text to continue.

    Returns:
        The first line of the decoded sequence (prompt followed by the
        generated continuation), or an empty string when the prompt is
        empty or nothing was decoded.
    """
    # An empty prompt tokenizes to zero tokens, which generate() cannot
    # continue from, so answer immediately instead of raising.
    if not input_text:
        return ""

    encoded = tokenizer(input_text, return_tensors='pt')
    prediction = model.generate(
        encoded['input_ids'],
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=encoded.get('attention_mask'),
        # The GPT-Neo tokenizer ships without a pad token; padding with EOS
        # is what generate() falls back to, made explicit here.
        pad_token_id=tokenizer.eos_token_id,
        min_length=20,
        max_length=150,
        num_return_sequences=1,
    )

    # Drop markers such as <|endoftext|> so they never reach the UI.
    text = tokenizer.decode(prediction[0], skip_special_tokens=True)
    lines = text.splitlines()
    # splitlines() yields [] for an empty string; guard the index.
    return lines[0] if lines else ""
# Input widget: free-text prompt, pre-filled with a story opener.
_prompt_box = gr.Textbox(
    label="Prompt Finetuned Model Exmpl: a cat, Exmpl: 3 little bears, Exmpl: once upon a time",
    value="Once upon a",
)
# Output widget: the text returned by plex().
_story_box = gr.Textbox(label="Generated_Text")

# Single-function web UI around plex().
iface = gr.Interface(
    fn=plex,
    inputs=_prompt_box,
    outputs=_story_box,
    title="GPT-Neo-125M fine-tuned on a small set of shortstories with Gradio",
    description="Prompt for a short bedtime story.",
)

# Queue holds at most one pending request and is not opened to API
# callers; the server is launched with a single worker thread.
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=1)