import streamlit as st
from datasets import load_dataset
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

@st.cache_resource
def load_and_fine_tune_model():
    # Load the dataset (a script-based Hub dataset; newer `datasets`
    # releases may require trust_remote_code=True to load it)
    dataset = load_dataset("blog_authorship_corpus")

    # Load the tokenizer; GPT-2 has no pad token, so reuse the EOS token
    # (the data collator needs one to batch variable-length examples)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # Drop all raw columns so only the tokenized fields reach the Trainer
    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # Data collator for causal language modeling (mlm=False: no token masking)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Load the model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_steps=10_000,
        save_total_limit=2,
    )

    # Initialize the Trainer; blog_authorship_corpus ships "train" and
    # "validation" splits (there is no "test" split)
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
    )

    # Fine-tune the model
    trainer.train()

    # Save the fine-tuned model and tokenizer
    model.save_pretrained("./fine-tuned-gpt2")
    tokenizer.save_pretrained("./fine-tuned-gpt2")
    return model, tokenizer
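
# Hypothetical alternative: once "./fine-tuned-gpt2" exists on disk, a
# restart could skip re-training by loading the saved weights directly:
#     model = GPT2LMHeadModel.from_pretrained("./fine-tuned-gpt2")
#     tokenizer = GPT2Tokenizer.from_pretrained("./fine-tuned-gpt2")
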
def generate_blog_post(prompt, model, tokenizer, max_length=500, temperature=0.7, top_k=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # do_sample=True is required for temperature/top_k to take effect;
    # without it, generate() falls back to greedy decoding and ignores both
    output = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        no_repeat_ngram_size=2,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
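
# Example usage (illustrative values, outside the Streamlit flow):
#     text = generate_blog_post("Why I love hiking", model, tokenizer,
#                               max_length=200, temperature=0.9)
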
# Streamlit UI
st.title("Blog Post Generator")

prompt = st.text_input("Enter a prompt for the blog post:", "The future of artificial intelligence in daily life")

if st.button("Generate Blog Post"):
    with st.spinner("Fine-tuning the model. This might take a few minutes..."):
        model, tokenizer = load_and_fine_tune_model()
    blog_post = generate_blog_post(prompt, model, tokenizer)
    st.subheader("Generated Blog Post")
    st.write(blog_post)
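
# To run the app locally (assuming this file is saved as app.py):
#     streamlit run app.py
# Note: @st.cache_resource means fine-tuning runs once per server process,
# but a full epoch over blog_authorship_corpus on CPU can take a very long
# time; for a quick demo, consider loading a slice such as
# load_dataset("blog_authorship_corpus", split="train[:1%]").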