NadaAljohani's picture
Update app.py
41b4109 verified
import gradio as gr
from transformers import pipeline
import re
# Function to clean the output by truncating at the last full sentence and formatting paragraphs
def clean_output(text, max_paragraphs=4):
    """Tidy generated text and lay it out as a few paragraphs.

    The text is stripped of stray symbols, cut off after its last complete
    sentence, and its sentences are grouped into at most ``max_paragraphs``
    paragraphs separated by blank lines. No sentence is dropped by the
    grouping step.

    Args:
        text: Raw text to clean.
        max_paragraphs: Upper bound on the number of paragraphs produced.

    Returns:
        The cleaned text with paragraphs joined by ``'\\n\\n'``; an empty
        string when ``text`` is empty.

    Raises:
        ValueError: If ``max_paragraphs`` is smaller than 1.
    """
    if max_paragraphs < 1:
        raise ValueError("max_paragraphs must be at least 1")

    # Remove unwanted symbols and normalise punctuation/spacing.
    text = re.sub(r'\.{2,}', '.', text)  # Collapse runs of periods into one
    text = re.sub(r'[:\-]+', '', text)  # Remove colons and dashes
    text = re.sub(r'[()]+', '', text)  # Remove parentheses
    # One pass covers spaces, tabs and newlines alike.
    text = re.sub(r'\s+', ' ', text).strip()

    # Keep everything up to the last sentence terminator so the text does not
    # end mid-sentence; leave the text alone when it has no terminator at all.
    last_end = max(text.rfind(mark) for mark in '.?!')
    if last_end != -1:
        text = text[:last_end + 1]

    # Split on whitespace that follows a terminator, skipping abbreviations
    # such as "e.g." or "Mr." (handled by the two negative look-behinds).
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[.?!])\s', text)

    # Ceiling division keeps the paragraph count within max_paragraphs without
    # discarding sentences; max(1, ...) keeps the range() step non-zero when
    # there are fewer sentences than paragraphs.
    paragraph_size = max(1, -(-len(sentences) // max_paragraphs))
    paragraphs = [
        ' '.join(sentences[start:start + paragraph_size])
        for start in range(0, len(sentences), paragraph_size)
    ]

    # Join paragraphs with double line breaks.
    return '\n\n'.join(paragraphs).strip()
# Function to generate the story
def generate_story(title, model_name="gpt2"):
    """Generate a story that continues ``title`` and return it formatted.

    A Hugging Face text-generation pipeline is constructed for ``model_name``
    on every call, the prompt is continued with sampling, and the result is
    passed through ``clean_output``.

    Args:
        title: Prompt text handed to the generator.
        model_name: Model identifier passed to ``pipeline``.

    Returns:
        The first generated sequence after cleaning by ``clean_output``.
    """
    # Use text-generation pipeline from Hugging Face
    generator = pipeline('text-generation', model=model_name)

    outputs = generator(
        title,
        max_length=500,  # Maximum length setting for the generated sequence
        no_repeat_ngram_size=3,  # Avoid repeating any 3-gram
        # temperature/top_p only apply in sampling mode; enable it explicitly
        # rather than relying on each model's default generation config.
        do_sample=True,
        temperature=0.8,  # Introduce some randomness
        top_p=0.95,  # Use nucleus sampling
    )
    story = outputs[0]['generated_text']

    # Trim to complete sentences and lay the story out in paragraphs.
    return clean_output(story)
# Create the Gradio interface using gr.Interface
# Free-text box for the story title / opening prompt.
_title_input = gr.Textbox(
    label="Enter Story Title",
    placeholder="Type a title here...",
)

# Model picker; 'gpt2' is preselected.
# NOTE(review): 'gpt-neo-2.7B' has no namespace, unlike
# 'EleutherAI/gpt-neo-2.7B' above; confirm it resolves to a real model id.
_model_input = gr.Dropdown(
    choices=[
        'gpt2',
        'gpt2-large',
        'EleutherAI/gpt-neo-2.7B',
        'EleutherAI/gpt-j-6B',
        'maldv/badger-writer-llama-3-8b',
        'gpt-neo-2.7B',
    ],
    value='gpt2',
    label="Choose Model",
)

# Example prompt offered in the UI.
_example_prompts = [
    ["Sara burst into her friend's house, only to find it plunged into darkness. A strange, pulsing glow flickered from the corner, casting eerie shadows on the walls. Her heart raced as she called out, but there was no answer. Something wasn’t right. On the table sat an unfamiliar, glowing device—humming with energy. With a deep breath, Sara stepped closer, knowing that once she touched it, there would be no turning back."],
]

# Wire generate_story to the two inputs and a plain-text output.
demo = gr.Interface(
    fn=generate_story,
    inputs=[_title_input, _model_input],
    outputs="text",
    title="AI Story Generator",
    description="Generate a creative story using different AI models.",
    examples=_example_prompts,
)

# Launch the interface with sharing enabled
demo.launch(share=True)