"""Streamlit app: fine-tune GPT-2 on ``blog_authorship_corpus`` and generate
blog posts from a user-supplied prompt."""

import os

import streamlit as st
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

# Directory the fine-tuned model/tokenizer are saved to and reloaded from.
FINE_TUNED_DIR = "./fine-tuned-gpt2"


def _fine_tune():
    """Fine-tune base GPT-2 for one epoch, save it to FINE_TUNED_DIR, and
    return ``(model, tokenizer)``."""
    # Load the dataset
    dataset = load_dataset("blog_authorship_corpus")

    # Load the tokenizer. GPT-2 ships without a pad token, but the LM data
    # collator pads every batch, so reuse EOS as the pad token.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # Drop every raw column (text and metadata) so only token fields remain.
    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    # Empty posts tokenize to zero-length sequences, which the model cannot consume.
    tokenized_datasets = tokenized_datasets.filter(
        lambda batch: [len(ids) > 0 for ids in batch["input_ids"]],
        batched=True,
    )

    # Causal-LM collator (mlm=False): labels are the input ids themselves.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Load the model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        save_steps=10_000,
        save_total_limit=2,
    )

    # Not every dataset revision exposes a "test" split; use whichever
    # held-out split is present instead of failing with KeyError.
    eval_split = next(
        (name for name in ("test", "validation") if name in tokenized_datasets),
        None,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets[eval_split] if eval_split else None,
    )

    # Fine-tune the model
    trainer.train()

    # Save the fine-tuned model so later server starts can skip training.
    model.save_pretrained(FINE_TUNED_DIR)
    tokenizer.save_pretrained(FINE_TUNED_DIR)

    return model, tokenizer


@st.cache_resource
def load_and_fine_tune_model():
    """Return a GPT-2 model and tokenizer fine-tuned on ``blog_authorship_corpus``.

    If FINE_TUNED_DIR already exists, the saved model is loaded instead of
    training again. Otherwise GPT-2 is fine-tuned and saved there. Delete that
    directory to force retraining. ``st.cache_resource`` keeps the result in
    memory across reruns of this Streamlit server.

    Returns:
        tuple: ``(model, tokenizer)``, with the model switched to eval mode.
    """
    if os.path.isdir(FINE_TUNED_DIR):
        tokenizer = GPT2Tokenizer.from_pretrained(FINE_TUNED_DIR)
        model = GPT2LMHeadModel.from_pretrained(FINE_TUNED_DIR)
    else:
        model, tokenizer = _fine_tune()

    # Defensive: ensure a pad token even for checkpoints saved without one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Trainer leaves the model in train mode; disable dropout for generation.
    model.eval()
    return model, tokenizer


def generate_blog_post(prompt, model, tokenizer, max_length=500, temperature=0.7, top_k=50):
    """Generate a blog post that continues *prompt*.

    Args:
        prompt: Opening text of the post. An empty prompt starts from the
            tokenizer's BOS token.
        model: Causal language model, e.g. from load_and_fine_tune_model().
        tokenizer: Tokenizer matching *model*.
        max_length: Maximum total length in tokens, prompt included.
        temperature: Sampling temperature; lower values are more conservative.
        top_k: Sample only among the *top_k* most likely next tokens.

    Returns:
        str: The decoded prompt plus continuation, with special tokens removed.
    """
    # An empty string tokenizes to zero ids, which generate() cannot start from.
    text = prompt or tokenizer.bos_token

    # Move inputs to the model's device (Trainer may have placed the model on
    # a GPU) and pass the attention mask along with the input ids.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_length=max_length,
        # Without do_sample=True, generate() decodes greedily and ignores
        # temperature and top_k.
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        no_repeat_ngram_size=2,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


# Streamlit UI
st.title("Blog Post Generator")

prompt = st.text_input("Enter a prompt for the blog post:", "The future of artificial intelligence in daily life")

if st.button("Generate Blog Post"):
    with st.spinner("Fine-tuning the model. This might take a few minutes..."):
        model, tokenizer = load_and_fine_tune_model()
    with st.spinner("Generating blog post..."):
        blog_post = generate_blog_post(prompt, model, tokenizer)
    st.subheader("Generated Blog Post")
    st.write(blog_post)