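"""Streamlit app: free-form haiku generator for Hugging Face Spaces.

Lets the user pick a small open model and a prompt style, then generates a
three-line haiku about a named character. Run locally with:

    streamlit run app.py
"""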
import gc

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer


class SpacesHaikuGenerator:
    def __init__(self):
        # Display name -> Hugging Face Hub model id.
        self.models = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Flan-T5": "google/flan-t5-large",
            "GPT2-Medium": "gpt2-medium",
            "BART": "facebook/bart-large",
        }
        # Flan-T5 and BART are encoder-decoder models, so they must be
        # loaded with the seq2seq head rather than the causal-LM head.
        self.seq2seq_models = {"Flan-T5", "BART"}
        self.loaded_model = None
        self.loaded_tokenizer = None
        self.current_model = None
        self.style_prompts = {
            "Nature": "Write a nature-inspired haiku about {name}, who is {traits}",
            "Urban": "Create a modern city haiku about {name}, characterized by {traits}",
            "Emotional": "Compose an emotional haiku capturing {name}'s essence: {traits}",
            "Reflective": "Write a contemplative haiku about {name}, focusing on {traits}",
        }

    def load_model(self, model_name):
        """Load the requested model, caching it across Streamlit reruns."""

        # Arguments must not be underscore-prefixed: Streamlit skips hashing
        # underscored params, which made every model name hit the same cache
        # entry. Plain string/bool args give each model its own cache key.
        @st.cache_resource
        def _load_model_cached(model_id, is_seq2seq):
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
            model_cls = AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM
            model = model_cls.from_pretrained(
                model_id,
                trust_remote_code=True,
                # fp16 only helps on GPU; stay in fp32 on CPU-only Spaces.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
            )
            if torch.cuda.is_available():
                model = model.to("cuda")
            return model, tokenizer

        if self.current_model != model_name:
            # Drop local references to the previous model; the Streamlit
            # cache keeps loaded models alive for quick switching back.
            if self.loaded_model is not None:
                del self.loaded_model
                del self.loaded_tokenizer
                torch.cuda.empty_cache()
                gc.collect()
            # Load the new model through the cached function.
            self.loaded_model, self.loaded_tokenizer = _load_model_cached(
                self.models[model_name], model_name in self.seq2seq_models
            )
            self.current_model = model_name

    def generate_haiku(self, name, traits, model_name, style):
        """Generate a free-form haiku using the selected model."""
        self.load_model(model_name)

        # Format traits for the prompt
        traits_text = ", ".join(traits)

        # Construct the prompt from the selected style template
        base_prompt = self.style_prompts[style].format(name=name, traits=traits_text)
        prompt = f"""{base_prompt}

Create a free-form haiku that:
- Uses imagery and metaphor
- Captures a single moment
- Reflects the character's essence

Haiku:"""

        # Budget new tokens only; max_length would count the prompt tokens
        # too and could truncate generation before the haiku appears.
        max_new_tokens = 100
        if model_name == "Flan-T5":
            max_new_tokens = 50  # T5 tends to be more concise

        # Tokenize the prompt and move it to the model's device
        inputs = self.loaded_tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        # GPT-2 has no pad token, so fall back to EOS; keep the tokenizer's
        # own pad token (e.g. for T5/BART) when one exists.
        pad_token_id = self.loaded_tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.loaded_tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.loaded_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.9,
                top_p=0.9,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        generated_text = self.loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Causal LMs echo the prompt; keep only the text after "Haiku:".
        haiku_text = generated_text.split("Haiku:")[-1].strip()

        # Split into lines, keeping at most three
        lines = [line.strip() for line in haiku_text.split("\n") if line.strip()]
        return lines[:3]


def main():
    st.title("🎋 Free-Form Haiku Generator")
    st.write("Create unique AI-generated haikus about characters")

    # Initialize generator
    @st.cache_resource
    def get_generator():
        return SpacesHaikuGenerator()

    generator = get_generator()

    # Input fields
    col1, col2 = st.columns([1, 2])
    with col1:
        name = st.text_input("Character Name")

    # Four character details in a row
    traits = []
    labels = ["Trait 1", "Trait 2", "Hobby", "Physical trait"]
    for label, col in zip(labels, st.columns(4)):
        trait = col.text_input(label)
        if trait:
            traits.append(trait)

    # Model and style selection
    col1, col2 = st.columns(2)
    with col1:
        model = st.selectbox("Choose Model", list(generator.models.keys()))
    with col2:
        style = st.selectbox("Choose Style", list(generator.style_prompts.keys()))

    if name and len(traits) == 4:
        if st.button("Generate Haiku"):
            with st.spinner(f"Creating your haiku using {model}..."):
                try:
                    haiku_lines = generator.generate_haiku(name, traits, model, style)

                    # Display haiku
                    st.markdown("---")
                    for line in haiku_lines:
                        st.markdown(f"*{line}*")
                    st.markdown("---")

                    # Metadata
                    st.caption(f"Style: {style} | Model: {model}")

                    # Regenerate option (any button click triggers a rerun)
                    if st.button("Create Another"):
                        st.rerun()
                except Exception as e:
                    st.error(f"Generation error: {e}")
                    st.info("Try a different model or simplify your input.")

    # Tips sidebar
    st.sidebar.markdown("""
    ### Tips for Better Results:
    - Use vivid, descriptive traits
    - Mix concrete and abstract details
    - Try different models for variety
    - Experiment with styles
    """)


if __name__ == "__main__":
    main()