azeus commited on
Commit
3599259
Β·
1 Parent(s): 556c085

initial commit

Browse files
.idea/Haiku.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="2">
8
+ <item index="0" class="java.lang.String" itemvalue="streamlit" />
9
+ <item index="1" class="java.lang.String" itemvalue="plotly" />
10
+ </list>
11
+ </value>
12
+ </option>
13
+ </inspection_tool>
14
+ </profile>
15
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Haiku.iml" filepath="$PROJECT_DIR$/.idea/Haiku.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ import nltk
4
+ from nltk.corpus import cmudict
5
+ import random
6
+ import torch
7
+
8
+ # Download required NLTK data
9
+ nltk.download('cmudict')
10
+ d = cmudict.dict()
11
+
12
+
13
+ class HaikuGenerator:
14
+ def __init__(self):
15
+ # Initialize different language models
16
+ self.models = {
17
+ "BERT": pipeline('fill-mask', model='bert-base-uncased'),
18
+ "RoBERTa": pipeline('fill-mask', model='roberta-base'),
19
+ "GPT2": pipeline('text-generation', model='gpt2'),
20
+ "DistilBERT": pipeline('fill-mask', model='distilbert-base-uncased')
21
+ }
22
+
23
+ def count_syllables(self, word):
24
+ """Count syllables in a word using CMU dictionary."""
25
+ try:
26
+ return len([x for x in d[word.lower()][0] if x[-1].isdigit()])
27
+ except KeyError:
28
+ return len(''.join(c for c in word if c in 'aeiouAEIOU'))
29
+
30
+ def get_related_words(self, word, model_name):
31
+ """Get related words using different models."""
32
+ if model_name in ["BERT", "RoBERTa", "DistilBERT"]:
33
+ if model_name == "RoBERTa":
34
+ masked_text = f"The {word} is <mask>."
35
+ else:
36
+ masked_text = f"The {word} is [MASK]."
37
+ predictions = self.models[model_name](masked_text)
38
+ return [pred['token_str'].strip() for pred in predictions]
39
+ else: # GPT2
40
+ prompt = f"The word {word} reminds me of"
41
+ predictions = self.models[model_name](prompt, max_length=20, num_return_sequences=5)
42
+ words = []
43
+ for pred in predictions:
44
+ text = pred['generated_text'].split()
45
+ if len(text) > 6: # Get the first new word after the prompt
46
+ words.append(text[6])
47
+ return words
48
+
49
+ def create_themed_haiku(self, character_name, traits, model_name):
50
+ """Generate a themed haiku about a character using specified model."""
51
+ syllable_targets = [5, 7, 5]
52
+ lines = []
53
+
54
+ # Create word pool from character traits and related words
55
+ word_pool = set(traits)
56
+ word_pool.add(character_name)
57
+
58
+ for word in traits:
59
+ related = self.get_related_words(word, model_name)
60
+ word_pool.update(related)
61
+
62
+ # Add character-specific words based on traits
63
+ for trait in traits:
64
+ themed_words = self.get_related_words(trait, model_name)
65
+ word_pool.update(themed_words)
66
+
67
+ # Generate each line ensuring character name appears
68
+ for i, target in enumerate(syllable_targets):
69
+ current_line = []
70
+ current_syllables = 0
71
+
72
+ # Ensure character name appears in one of the lines (preferably first)
73
+ if i == 0 and self.count_syllables(character_name) <= target:
74
+ current_line.append(character_name)
75
+ current_syllables += self.count_syllables(character_name)
76
+
77
+ while current_syllables < target:
78
+ available_words = [w for w in word_pool
79
+ if current_syllables + self.count_syllables(w) <= target]
80
+ if not available_words:
81
+ break
82
+
83
+ word = random.choice(available_words)
84
+ current_line.append(word)
85
+ current_syllables += self.count_syllables(word)
86
+
87
+ lines.append(' '.join(current_line))
88
+
89
+ return lines
90
+
91
+
92
+ def main():
93
+ st.title("πŸŽ‹ Character Haiku Generator")
94
+ st.write("Generate unique haikus about a character using different AI models!")
95
+
96
+ # Initialize generator
97
+ generator = HaikuGenerator()
98
+
99
+ # Input fields
100
+ character_name = st.text_input("Character Name:")
101
+
102
+ # Four traits/characteristics
103
+ cols = st.columns(4)
104
+ traits = []
105
+ for i, col in enumerate(cols):
106
+ trait = col.text_input(f"Trait/Characteristic {i + 1}")
107
+ if trait:
108
+ traits.append(trait)
109
+
110
+ # Model selection
111
+ model_options = ["BERT", "RoBERTa", "DistilBERT", "GPT2"]
112
+ selected_models = st.multiselect(
113
+ "Select AI Models to Generate Haikus",
114
+ model_options,
115
+ default=["BERT"]
116
+ )
117
+
118
+ if character_name and len(traits) == 4 and st.button("Generate Haikus"):
119
+ st.subheader("Your Generated Haikus:")
120
+
121
+ for model in selected_models:
122
+ with st.expander(f"🎯 {model} Generated Haiku"):
123
+ with st.spinner(f"Crafting haiku using {model}..."):
124
+ haiku_lines = generator.create_themed_haiku(character_name, traits, model)
125
+
126
+ # Display haiku with styling
127
+ for line in haiku_lines:
128
+ st.write(line)
129
+
130
+ # Display syllable count
131
+ st.caption("Syllable count verification:")
132
+ for i, line in enumerate(haiku_lines):
133
+ syllables = sum(generator.count_syllables(word)
134
+ for word in line.split())
135
+ st.caption(f"Line {i + 1}: {syllables} syllables")
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit==1.31.0
2
+ transformers==4.36.0
3
+ torch==2.1.0
4
+ nltk==3.8.1