import streamlit as st
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import gc

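
# Generator backend for the Streamlit UI below: wraps model loading/caching
# and prompt-based haiku generation (the class name suggests a Hugging Face
# Spaces deployment).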
class SpacesHaikuGenerator:
    def __init__(self):
        self.models = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Flan-T5": "google/flan-t5-large",
            "GPT2-Medium": "gpt2-medium",
            "BART": "facebook/bart-large"
        }
        # Flan-T5 and BART are encoder-decoder models and must be loaded with
        # AutoModelForSeq2SeqLM; the causal-LM class would reject them.
        self.seq2seq_models = {"Flan-T5", "BART"}

        self.loaded_model = None
        self.loaded_tokenizer = None
        self.current_model = None

        self.style_prompts = {
            "Nature": "Write a nature-inspired haiku about {name}, who is {traits}",
            "Urban": "Create a modern city haiku about {name}, characterized by {traits}",
            "Emotional": "Compose an emotional haiku capturing {name}'s essence: {traits}",
            "Reflective": "Write a contemplative haiku about {name}, focusing on {traits}"
        }
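
    # Loading is funneled through an st.cache_resource helper so that a
    # previously used model can be re-selected without reloading its weights.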
    def load_model(self, model_name):
        """Load the requested model, caching it across Streamlit reruns."""

        @st.cache_resource
        def _load_model_cached(repo_id, is_seq2seq):
            # repo_id has no leading underscore, so Streamlit hashes it into
            # the cache key and each model gets its own cache entry.
            tokenizer = AutoTokenizer.from_pretrained(
                repo_id,
                trust_remote_code=True
            )
            model_cls = AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM
            model = model_cls.from_pretrained(
                repo_id,
                trust_remote_code=True,
                # Half precision saves memory on GPU; fall back to float32 on
                # CPU, where float16 support is patchy.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True
            )
            if torch.cuda.is_available():
                model = model.to("cuda")
            return model, tokenizer

        if self.current_model != model_name:
            if self.loaded_model is not None:
                # Drop our references before switching models. st.cache_resource
                # keeps its own reference, so memory is only fully reclaimed
                # once the cache entry itself is evicted.
                del self.loaded_model
                del self.loaded_tokenizer
                torch.cuda.empty_cache()
                gc.collect()

            self.loaded_model, self.loaded_tokenizer = _load_model_cached(
                self.models[model_name],
                model_name in self.seq2seq_models
            )
            self.current_model = model_name

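    # Generation: build a style-specific prompt, sample with nucleus sampling,
    # and keep at most the first three non-empty lines as the haiku.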
    def generate_haiku(self, name, traits, model_name, style):
        """Generate a free-form haiku using the selected model."""
        self.load_model(model_name)

        traits_text = ", ".join(traits)
        base_prompt = self.style_prompts[style].format(name=name, traits=traits_text)
        prompt = f"""{base_prompt}

Create a free-form haiku that:
- Uses imagery and metaphor
- Captures a single moment
- Reflects the character's essence

Haiku:"""

        # Budget new tokens only; max_length would also count the prompt,
        # which is itself close to 100 tokens long.
        max_new_tokens = 100
        if model_name == "Flan-T5":
            max_new_tokens = 50

        inputs = self.loaded_tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        with torch.no_grad():
            outputs = self.loaded_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.9,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.loaded_tokenizer.eos_token_id
            )

        generated_text = self.loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Causal models echo the prompt, so keep only what follows "Haiku:";
        # seq2seq models return just the completion, which the split passes through.
        haiku_text = generated_text.split("Haiku:")[-1].strip()

        lines = [line.strip() for line in haiku_text.split('\n') if line.strip()]
        return lines[:3]

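
# Quick local smoke test (hypothetical values; the Streamlit UI below is the
# normal entry point):
#   gen = SpacesHaikuGenerator()
#   print(gen.generate_haiku("Mira", ["curious", "gentle", "gardening",
#                                     "silver hair"], "TinyLlama", "Nature"))
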
def main():
    st.title("🌸 Free-Form Haiku Generator")
    st.write("Create unique AI-generated haikus about characters")

    # Build the generator once and share it across Streamlit reruns.
    @st.cache_resource
    def get_generator():
        return SpacesHaikuGenerator()

    generator = get_generator()

    col1, col2 = st.columns([1, 2])
    with col1:
        name = st.text_input("Character Name")

    # Four descriptor fields: two personality traits, a hobby, and a
    # physical detail.
    traits = []
    cols = st.columns(4)
    for i, col in enumerate(cols):
        label = "Trait" if i < 2 else "Hobby" if i == 2 else "Physical"
        # Number only the repeated "Trait" fields; "Hobby 3" would read oddly.
        trait = col.text_input(f"{label} {i + 1}" if i < 2 else label)
        if trait:
            traits.append(trait)

    col1, col2 = st.columns(2)
    with col1:
        model = st.selectbox("Choose Model", list(generator.models.keys()))
    with col2:
        style = st.selectbox("Choose Style", list(generator.style_prompts.keys()))

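    # Enable generation only once a name and all four descriptors are filled in.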
    if name and len(traits) == 4:
        if st.button("Generate Haiku"):
            with st.spinner(f"Creating your haiku using {model}..."):
                try:
                    haiku_lines = generator.generate_haiku(name, traits, model, style)

                    st.markdown("---")
                    for line in haiku_lines:
                        st.markdown(f"*{line}*")
                    st.markdown("---")

                    st.caption(f"Style: {style} | Model: {model}")

                    if st.button("Create Another"):
                        # st.experimental_rerun() has been removed from
                        # Streamlit; st.rerun() is its replacement.
                        st.rerun()

                except Exception as e:
                    st.error(f"Generation error: {str(e)}")
                    st.info("Try a different model or simplify your input.")

    st.sidebar.markdown("""
    ### Tips for Better Results:
    - Use vivid, descriptive traits
    - Mix concrete and abstract details
    - Try different models for variety
    - Experiment with styles
    """)


if __name__ == "__main__":
    main()