import gc

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer


class SpacesHaikuGenerator:
    """Generate short free-form haikus about a character with Hugging Face models.

    At most one model is held by the instance at a time: switching models
    drops the previous one before the next is loaded.
    """

    def __init__(self):
        # Display name -> Hugging Face Hub model id.
        self.models = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Flan-T5": "google/flan-t5-large",
            "GPT2-Medium": "gpt2-medium",
            "BART": "facebook/bart-large",
        }
        # Encoder-decoder checkpoints. AutoModelForCausalLM does not support T5
        # and would load only BART's decoder, so these use AutoModelForSeq2SeqLM;
        # their generated ids also do not echo the prompt.
        # NOTE(review): facebook/bart-large is not instruction-tuned and tends to
        # reconstruct its input rather than follow it.
        self.seq2seq_models = {"Flan-T5", "BART"}
        self.loaded_model = None
        self.loaded_tokenizer = None
        self.current_model = None
        # Display name -> prompt template, formatted with name= and traits=.
        self.style_prompts = {
            "Nature": "Write a nature-inspired haiku about {name}, who is {traits}",
            "Urban": "Create a modern city haiku about {name}, characterized by {traits}",
            "Emotional": "Compose an emotional haiku capturing {name}'s essence: {traits}",
            "Reflective": "Write a contemplative haiku about {name}, focusing on {traits}",
        }

    def _unload_model(self):
        """Drop references to the active model/tokenizer and release GPU cache."""
        self.loaded_model = None
        self.loaded_tokenizer = None
        self.current_model = None
        # Collect first so the tensors are actually freed; only then can the
        # CUDA caching allocator hand the now-unused blocks back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self, model_name):
        """Make ``model_name`` the active model, loading it if necessary.

        The previously active model is released first so only one model is
        referenced by this instance. No per-model cache is used on purpose: a
        cache would keep every model ever loaded alive and defeat the release.
        The instance itself is cached by ``get_generator``, so the active model
        survives Streamlit reruns.

        Args:
            model_name: A key of ``self.models``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``self.models``.
        """
        if self.current_model == model_name and self.loaded_model is not None:
            return

        model_id = self.models[model_name]
        # Reset state before loading so a failed load never leaves a
        # half-cleared instance behind (current_model is None until success).
        self._unload_model()

        use_cuda = torch.cuda.is_available()
        model_cls = (
            AutoModelForSeq2SeqLM
            if model_name in self.seq2seq_models
            else AutoModelForCausalLM
        )
        # trust_remote_code is left at its default (False): these are standard
        # architectures built into transformers, and enabling it would execute
        # arbitrary Python shipped in the Hub repository.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = model_cls.from_pretrained(
            model_id,
            # Half precision only on GPU; many CPU kernels lack float16 support.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            low_cpu_mem_usage=True,
        )
        if use_cuda:
            model = model.to("cuda")

        self.loaded_model = model
        self.loaded_tokenizer = tokenizer
        self.current_model = model_name

    def generate_haiku(self, name, traits, model_name, style):
        """Generate a free-form haiku about a character with the selected model.

        Args:
            name: Character name inserted into the prompt.
            traits: Iterable of trait strings, joined with ", ".
            model_name: A key of ``self.models``.
            style: A key of ``self.style_prompts``.

        Returns:
            Up to three stripped, non-empty lines of generated text. The list
            may be shorter, or empty, when the model produced fewer lines.
        """
        self.load_model(model_name)
        model = self.loaded_model
        tokenizer = self.loaded_tokenizer

        traits_text = ", ".join(traits)
        base_prompt = self.style_prompts[style].format(name=name, traits=traits_text)
        prompt = (
            f"{base_prompt}\n"
            "\n"
            "Create a free-form haiku that:\n"
            "- Uses imagery and metaphor\n"
            "- Captures a single moment\n"
            "- Reflects the character's essence\n"
            "\n"
            "Haiku:"
        )

        # Budget generated tokens only. A total max_length would also count the
        # prompt (itself ~100 tokens) and leave little or no room for output.
        max_new_tokens = 50 if model_name == "Flan-T5" else 100  # T5 is concise

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            # GPT-2 style tokenizers define no pad token; fall back to EOS.
            pad_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.9,
                top_p=0.9,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        generated_ids = outputs[0]
        if model_name not in self.seq2seq_models:
            # Decoder-only models return prompt + continuation; keep the latter.
            generated_ids = generated_ids[inputs["input_ids"].shape[-1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        # Models may restate the prompt; keep only what follows the last "Haiku:".
        haiku_text = generated_text.split("Haiku:")[-1].strip()

        lines = [line.strip() for line in haiku_text.splitlines() if line.strip()]
        return lines[:3]  # At most three lines; may be fewer.


@st.cache_resource
def get_generator():
    """Return the cached generator shared across reruns (and sessions)."""
    return SpacesHaikuGenerator()


def _request_regeneration():
    """Callback for "Create Another": ask the next script run for a new haiku."""
    st.session_state["regenerate_haiku"] = True


def _generate_and_show(generator, name, traits, model, style):
    """Generate a haiku, render it with metadata, and offer "Create Another"."""
    with st.spinner(f"Creating your haiku using {model}..."):
        try:
            haiku_lines = generator.generate_haiku(name, traits, model, style)
        except Exception as e:  # UI boundary: surface any load/generation failure.
            st.error(f"Generation error: {str(e)}")
            st.info("Try a different model or simplify your input.")
            return

    if haiku_lines:
        st.markdown("---")
        for line in haiku_lines:
            st.markdown(f"*{line}*")
        st.markdown("---")
        st.caption(f"Style: {style} | Model: {model}")
    else:
        st.warning("The model returned an empty haiku. Try again or choose another model.")

    # A button nested in the "Generate Haiku" branch could never fire: that
    # branch is False on the rerun its click triggers. Record the click with a
    # callback instead; main() consumes the flag on the next run.
    st.button("Create Another", on_click=_request_regeneration)


def main():
    """Render the haiku generator page."""
    st.title("🎋 Free-Form Haiku Generator")
    st.write("Create unique AI-generated haikus about characters")

    generator = get_generator()
    # Set by the "Create Another" callback; popped so it applies to this run only.
    regenerate = st.session_state.pop("regenerate_haiku", False)

    # Input fields
    name_col, _ = st.columns([1, 2])
    with name_col:
        name = st.text_input("Character Name")

    # Four traits in a grid
    traits = []
    cols = st.columns(4)
    for i, col in enumerate(cols):
        label = "Trait" if i < 2 else "Hobby" if i == 2 else "Physical"
        trait = col.text_input(f"{label} {i + 1}")
        if trait:
            traits.append(trait)

    # Model and style selection
    col1, col2 = st.columns(2)
    with col1:
        model = st.selectbox("Choose Model", list(generator.models.keys()))
    with col2:
        style = st.selectbox("Choose Style", list(generator.style_prompts.keys()))

    if name and len(traits) == 4:
        generate_clicked = st.button("Generate Haiku")
        if generate_clicked or regenerate:
            _generate_and_show(generator, name, traits, model, style)
    else:
        st.info("Enter a character name and all four traits to generate a haiku.")

    # Tips sidebar
    st.sidebar.markdown(
        "### Tips for Better Results:\n"
        "- Use vivid, descriptive traits\n"
        "- Mix concrete and abstract details\n"
        "- Try different models for variety\n"
        "- Experiment with styles\n"
    )


if __name__ == "__main__":
    main()