azeus committed on
Commit
a92e324
·
1 Parent(s): dab360e

better models

Browse files
Files changed (2) hide show
  1. app.py +129 -97
  2. requirements.txt +5 -4
app.py CHANGED
@@ -1,130 +1,162 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
- import random
 
4
 
5
 
6
class FreeFormHaikuGenerator:
    """Generate free-form (non 5-7-5) haiku about a character.

    Pipelines are created lazily on first use and cached in ``self.models``.
    Building all three pipelines in ``__init__`` (as before) downloads and
    loads every checkpoint up front even though one generation only ever
    uses a single model.
    """

    def __init__(self):
        # UI model name -> (pipeline task, HF checkpoint). Actual pipelines
        # are built on demand by _get_model and cached in self.models.
        self._model_specs = {
            "GPT2": ("text-generation", "gpt2"),
            "BERT": ("fill-mask", "bert-base-uncased"),
            "RoBERTa": ("fill-mask", "roberta-base"),
        }
        self.models = {}  # lazily populated cache of built pipelines

        self.style_prompts = {
            "Minimalist": "Write a short profound observation about",
            "Emotional": "Express deep feelings about",
            "Nature": "Describe the natural essence of",
            "Urban": "Capture a city moment with",
            "Surreal": "Create a dreamlike scene with"
        }

    def _get_model(self, model_name):
        """Build (once) and return the pipeline for *model_name*."""
        if model_name not in self.models:
            task, checkpoint = self._model_specs[model_name]
            self.models[model_name] = pipeline(task, model=checkpoint)
        return self.models[model_name]

    def generate_associations(self, word, model_name):
        """Generate associated words and phrases using different models."""
        model = self._get_model(model_name)
        if model_name == "GPT2":
            prompt = f"{word} makes me think of"
            response = model(prompt,
                             max_length=20,
                             num_return_sequences=3,
                             temperature=0.9)
            # Use the last token of each continuation as a loose association.
            return [r['generated_text'].split()[-1] for r in response]

        else:  # BERT or RoBERTa — note they use different mask-token strings
            mask_token = "[MASK]" if model_name == "BERT" else "<mask>"
            prompts = [
                f"The {word} is like {mask_token}",
                f"{word} reminds me of {mask_token}",
                f"{word} feels {mask_token}"
            ]
            associations = []
            for prompt in prompts:
                result = model(prompt)
                # Keep the two highest-scoring fill-ins per prompt.
                associations.extend([r['token_str'] for r in result[:2]])
            return associations

    def create_freeform_haiku(self, name, traits, model_name, style):
        """Create a free-form haiku based on character traits and selected style.

        Returns a list of three short lines (strings).
        """
        # Generate word associations for each trait (and the name itself).
        word_pool = []
        for trait in [name] + traits:
            word_pool.extend(self.generate_associations(trait, model_name))

        # Create base prompt for the style
        base_prompt = f"{self.style_prompts[style]} {name} who is {', '.join(traits)}"

        if model_name == "GPT2":
            # Generate three short phrases for our haiku
            lines = []
            for i in range(3):
                prompt = f"{base_prompt}. Line {i + 1}:"
                response = self._get_model(model_name)(prompt,
                                                      max_length=30,
                                                      temperature=0.9,
                                                      num_return_sequences=1)
                # Extract a short phrase: first output line, last 4 words.
                generated_text = response[0]['generated_text'].split('\n')[0]
                clean_line = ' '.join(generated_text.split()[-4:])
                lines.append(clean_line)
        else:
            # For BERT/RoBERTa, construct lines from the masked predictions.
            lines = []
            random.shuffle(word_pool)
            for i in range(3):
                if i == 0:
                    line = f"{name} {random.choice(word_pool)}"
                else:
                    line = f"{random.choice(word_pool)} {random.choice(word_pool)}"
                lines.append(line.strip())

        return lines
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
def main():
    """Streamlit entry point: collect character details and show a haiku."""
    st.title("🎋 Free-Form Character Haiku Generator")
    st.write("Breaking traditional rules to create unique, AI-generated haiku!")

    # Initialize generator
    generator = FreeFormHaikuGenerator()

    # Input fields
    col1, col2 = st.columns([1, 2])
    with col1:
        name = st.text_input("Character Name")

    traits = []
    cols = st.columns(4)
    for i, col in enumerate(cols):
        trait = col.text_input(f"{'Trait' if i < 2 else 'Hobby' if i == 2 else 'Physical'} {i + 1}")
        if trait:
            traits.append(trait)

    # Style and model selection
    col1, col2 = st.columns(2)
    with col1:
        style = st.selectbox("Choose Style",
                             list(generator.style_prompts.keys()))
    with col2:
        model = st.selectbox("Choose AI Model",
                             ["GPT2", "BERT", "RoBERTa"])

    if name and len(traits) == 4 and st.button("Generate Free-Form Haiku"):
        st.subheader("Your Generated Haiku:")

        with st.spinner(f"Crafting a {style} haiku using {model}..."):
            haiku_lines = generator.create_freeform_haiku(name, traits, model, style)

        # Display haiku with styling
        st.markdown("---")
        for line in haiku_lines:
            st.markdown(f"*{line}*")
        st.markdown("---")

        # Add some context
        st.caption(f"Style: {style} | Model: {model}")

        # NOTE: the previous nested "Generate Another Version" button could
        # never work: clicking any button reruns the script, the outer
        # button then reads False, and this whole branch is skipped before
        # the inner button is even drawn. Pressing the main button again
        # already regenerates, so the dead control (and its redundant
        # st.experimental_rerun call) was removed.
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
 
130
  if __name__ == "__main__":
 
1
import gc

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)
5
 
6
 
7
class SpacesHaikuGenerator:
    """Generate free-form haiku on Spaces, keeping one model in memory.

    The class does its own single-slot model caching: loading a new model
    evicts the previous one to keep the memory footprint of a Space low.
    """

    # Encoder-decoder checkpoints; AutoModelForCausalLM cannot load these.
    _SEQ2SEQ_MODELS = ("Flan-T5", "BART")

    def __init__(self):
        # UI model name -> HF checkpoint id.
        self.models = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Flan-T5": "google/flan-t5-large",
            "GPT2-Medium": "gpt2-medium",
            "BART": "facebook/bart-large"
        }

        # Single-slot cache of the currently loaded model/tokenizer.
        self.loaded_model = None
        self.loaded_tokenizer = None
        self.current_model = None

        self.style_prompts = {
            "Nature": "Write a nature-inspired haiku about {name}, who is {traits}",
            "Urban": "Create a modern city haiku about {name}, characterized by {traits}",
            "Emotional": "Compose an emotional haiku capturing {name}'s essence: {traits}",
            "Reflective": "Write a contemplative haiku about {name}, focusing on {traits}"
        }

    def load_model(self, model_name):
        """Load *model_name*'s checkpoint, evicting the previous one.

        NOTE: the previous @st.cache_resource decorator was removed — on a
        bound instance method Streamlit tries to hash ``self`` (which fails
        for this class) and the cache would also fight the manual eviction
        below, which is the intended memory-management strategy here.
        """
        if self.current_model == model_name:
            return  # already loaded

        # Clear previous model before loading the next to bound peak memory.
        if self.loaded_model is not None:
            del self.loaded_model
            del self.loaded_tokenizer
            self.loaded_model = None
            self.loaded_tokenizer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        checkpoint = self.models[model_name]
        self.loaded_tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            trust_remote_code=True
        )
        # Flan-T5 and BART are encoder-decoder models, so they must be
        # loaded with the seq2seq auto class; AutoModelForCausalLM raises
        # for those checkpoints.
        auto_cls = (AutoModelForSeq2SeqLM
                    if model_name in self._SEQ2SEQ_MODELS
                    else AutoModelForCausalLM)
        # fp16 is only reliable on GPU; CPU-only Spaces get fp32.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.loaded_model = auto_cls.from_pretrained(
            checkpoint,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True
        )
        if torch.cuda.is_available():
            self.loaded_model = self.loaded_model.to("cuda")
        self.current_model = model_name

    def generate_haiku(self, name, traits, model_name, style):
        """Generate a free-form haiku; returns up to three lines (strings)."""
        self.load_model(model_name)

        # Format traits for prompt
        traits_text = ", ".join(traits)

        base_prompt = self.style_prompts[style].format(name=name, traits=traits_text)
        prompt = f"""{base_prompt}

Create a free-form haiku that:
- Uses imagery and metaphor
- Captures a single moment
- Reflects the character's essence

Haiku:"""

        # Budget for NEW tokens only. The old code passed max_length, which
        # also counts the (long) prompt for causal models, so the limit of
        # 50/100 tokens was mostly consumed by the prompt itself.
        max_new_tokens = 40 if model_name == "Flan-T5" else 80  # T5 is concise

        inputs = self.loaded_tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        with torch.no_grad():
            outputs = self.loaded_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.9,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.loaded_tokenizer.eos_token_id
            )

        generated_text = self.loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Causal models echo the prompt; seq2seq output has no "Haiku:"
        # prefix, in which case split()[-1] is a harmless no-op.
        haiku_text = generated_text.split("Haiku:")[-1].strip()

        # Format into three lines
        lines = [line.strip() for line in haiku_text.split('\n') if line.strip()]
        return lines[:3]  # Ensure at most 3 lines
99
 
100
 
101
def main():
    """Streamlit entry point: collect character details and render a haiku."""
    st.title("🎋 Free-Form Haiku Generator")
    st.write("Create unique AI-generated haikus about characters")

    # Initialize generator
    generator = SpacesHaikuGenerator()

    # Input fields
    col1, col2 = st.columns([1, 2])
    with col1:
        name = st.text_input("Character Name")

    # Four traits in a grid
    traits = []
    cols = st.columns(4)
    for i, col in enumerate(cols):
        label = "Trait" if i < 2 else "Hobby" if i == 2 else "Physical"
        trait = col.text_input(f"{label} {i + 1}")
        if trait:
            traits.append(trait)

    # Model and style selection
    col1, col2 = st.columns(2)
    with col1:
        model = st.selectbox("Choose Model", list(generator.models.keys()))
    with col2:
        style = st.selectbox("Choose Style", list(generator.style_prompts.keys()))

    if name and len(traits) == 4:
        if st.button("Generate Haiku"):
            with st.spinner(f"Creating your haiku using {model}..."):
                try:
                    haiku_lines = generator.generate_haiku(name, traits, model, style)

                    # Display haiku
                    st.markdown("---")
                    for line in haiku_lines:
                        st.markdown(f"*{line}*")
                    st.markdown("---")

                    # Metadata
                    st.caption(f"Style: {style} | Model: {model}")

                    # NOTE: the nested "Create Another" button was removed —
                    # clicking any button reruns the script, the outer
                    # "Generate Haiku" button then reads False, so the inner
                    # button's branch could never execute. Its
                    # st.experimental_rerun call is also gone from current
                    # Streamlit (requirements.txt is now unpinned); pressing
                    # "Generate Haiku" again already regenerates.

                except Exception as e:
                    st.error(f"Generation error: {str(e)}")
                    st.info("Try a different model or simplify your input.")

    # Tips sidebar
    st.sidebar.markdown("""
    ### Tips for Better Results:
    - Use vivid, descriptive traits
    - Mix concrete and abstract details
    - Try different models for variety
    - Experiment with styles
    """)
160
 
161
 
162
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- streamlit==1.31.0
2
- transformers==4.36.0
3
- torch==2.1.0
4
- nltk==3.8.1
 
 
1
+ streamlit
2
+ transformers
3
+ torch
4
+ sentencepiece
5
+ accelerate