# Spaces: Sleeping
import streamlit as st | |
from datasets import load_dataset | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling, Trainer, TrainingArguments | |
def load_and_fine_tune_model(max_train_samples=None):
    """Fine-tune GPT-2 on the blog authorship corpus and save the result.

    Loads the ``blog_authorship_corpus`` dataset and the pretrained ``gpt2``
    tokenizer/model, runs one epoch of causal language-model fine-tuning with
    the Hugging Face ``Trainer``, and writes the model and tokenizer to
    ``./fine-tuned-gpt2``.

    Args:
        max_train_samples: Optional cap on the number of training examples.
            ``None`` (the default) trains on the whole ``train`` split, as the
            original code did; a small value makes a quick run feasible on
            limited hardware.

    Returns:
        A ``(model, tokenizer)`` tuple: the fine-tuned ``GPT2LMHeadModel`` and
        the ``GPT2Tokenizer`` used to train it.
    """
    # Load the dataset
    dataset = load_dataset("blog_authorship_corpus")

    # Load the tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a padding token, but DataCollatorForLanguageModeling
    # pads each batch; without a pad token training fails on the first batch.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # Drop every raw column (not only "text") so that just the token fields
    # (input_ids / attention_mask) reach the data collator.
    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    train_dataset = tokenized_datasets["train"]
    if max_train_samples is not None:
        train_dataset = train_dataset.select(
            range(min(max_train_samples, len(train_dataset)))
        )

    # NOTE(review): the corpus may not publish a "test" split (it has been
    # distributed with train/validation only), so indexing ['test'] directly
    # can raise KeyError. Prefer "test", fall back to "validation", else none.
    eval_split = next(
        (name for name in ("test", "validation") if name in tokenized_datasets),
        None,
    )
    eval_dataset = tokenized_datasets[eval_split] if eval_split is not None else None

    # Data collator: mlm=False -> causal (next-token) LM objective, not masked LM
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Load the model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_steps=10_000,
        save_total_limit=2,
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Fine-tune the model
    trainer.train()

    # Save the fine-tuned model (and tokenizer, including the pad token set above)
    model.save_pretrained("./fine-tuned-gpt2")
    tokenizer.save_pretrained("./fine-tuned-gpt2")
    return model, tokenizer
def generate_blog_post(prompt, model, tokenizer, max_length=500, temperature=0.7, top_k=50):
    """Generate a blog post continuing ``prompt`` with a causal LM.

    Args:
        prompt: Non-empty text that seeds the generation.
        model: A Hugging Face causal LM (e.g. ``GPT2LMHeadModel``).
        tokenizer: The tokenizer matching ``model``.
        max_length: Maximum total length in tokens, prompt included.
        temperature: Sampling temperature. Values <= 0 (or ``None``) fall back
            to greedy decoding.
        top_k: Sample only from the ``top_k`` most likely next tokens
            (used only when sampling).

    Returns:
        The decoded text of the first generated sequence, special tokens removed.

    Raises:
        ValueError: If ``prompt`` is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise ValueError("prompt must be a non-empty string")

    # The model may still be in training mode right after Trainer.train();
    # switch to eval mode so dropout does not perturb generation.
    model.eval()

    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep inputs on the same device as the model (the Trainer may have moved
    # it to a GPU); otherwise generate() fails with a device mismatch.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    # temperature/top_k only take effect when sampling is enabled; under the
    # default greedy decoding they are silently ignored.
    if temperature is not None and temperature > 0:
        decoding = {"do_sample": True, "temperature": temperature, "top_k": top_k}
    else:
        decoding = {"do_sample": False}

    # GPT-2 has no pad token by default; use EOS so generate() does not warn.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        no_repeat_ngram_size=2,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
        **decoding,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Streamlit UI
@st.cache_resource(show_spinner=False)
def _get_fine_tuned_model():
    """Run the expensive fine-tuning once per server process and reuse it.

    Streamlit re-executes this whole script on every interaction, so without
    caching each button click would repeat the full fine-tuning run.
    """
    return load_and_fine_tune_model()


st.title("Blog Post Generator")
prompt = st.text_input("Enter a prompt for the blog post:", "The future of artificial intelligence in daily life")
if st.button("Generate Blog Post"):
    if not prompt.strip():
        # An empty prompt gives the model nothing to condition on.
        st.warning("Please enter a prompt before generating.")
    else:
        with st.spinner("Fine-tuning the model. This might take a few minutes..."):
            model, tokenizer = _get_fine_tuned_model()
        with st.spinner("Generating blog post..."):
            blog_post = generate_blog_post(prompt, model, tokenizer)
        st.subheader("Generated Blog Post")
        st.write(blog_post)