aiqcamp commited on
Commit
dd5d6cc
Β·
verified Β·
1 Parent(s): 0d770e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -209
app.py CHANGED
@@ -1,229 +1,283 @@
1
- import random
2
  import os
3
- import uuid
4
- from datetime import datetime
5
- import gradio as gr
6
- import numpy as np
7
- import spaces
8
- import torch
9
- from diffusers import DiffusionPipeline
10
- from PIL import Image
11
-
12
- # Create permanent storage directory
13
- SAVE_DIR = "saved_images" # Gradio will handle the persistence
14
- if not os.path.exists(SAVE_DIR):
15
- os.makedirs(SAVE_DIR, exist_ok=True)
16
-
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
- repo_id = "black-forest-labs/FLUX.1-dev"
19
- adapter_id = "openfree/korea-president-yoon"
20
-
21
- pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
22
- pipeline.load_lora_weights(adapter_id)
23
- pipeline = pipeline.to(device)
24
-
25
- MAX_SEED = np.iinfo(np.int32).max
26
- MAX_IMAGE_SIZE = 1024
27
-
28
- def save_generated_image(image, prompt):
29
- # Generate unique filename with timestamp
30
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
31
- unique_id = str(uuid.uuid4())[:8]
32
- filename = f"{timestamp}_{unique_id}.png"
33
- filepath = os.path.join(SAVE_DIR, filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Save the image
36
- image.save(filepath)
 
 
37
 
38
- # Save metadata
39
- metadata_file = os.path.join(SAVE_DIR, "metadata.txt")
40
- with open(metadata_file, "a", encoding="utf-8") as f:
41
- f.write(f"{filename}|{prompt}|{timestamp}\n")
42
 
43
- return filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- def load_generated_images():
46
- if not os.path.exists(SAVE_DIR):
47
- return []
 
 
 
 
 
48
 
49
- # Load all images from the directory
50
- image_files = [os.path.join(SAVE_DIR, f) for f in os.listdir(SAVE_DIR)
51
- if f.endswith(('.png', '.jpg', '.jpeg', '.webp'))]
52
- # Sort by creation time (newest first)
53
- image_files.sort(key=lambda x: os.path.getctime(x), reverse=True)
54
- return image_files
55
-
56
- def load_predefined_images():
57
- # Return empty list since we're not using predefined images
58
- return []
59
-
60
- @spaces.GPU(duration=120)
61
- def inference(
62
- prompt: str,
63
- seed: int,
64
- randomize_seed: bool,
65
- width: int,
66
- height: int,
67
- guidance_scale: float,
68
- num_inference_steps: int,
69
- lora_scale: float,
70
- progress: gr.Progress = gr.Progress(track_tqdm=True),
71
- ):
72
- if randomize_seed:
73
- seed = random.randint(0, MAX_SEED)
74
- generator = torch.Generator(device=device).manual_seed(seed)
75
 
76
- image = pipeline(
77
- prompt=prompt,
78
- guidance_scale=guidance_scale,
79
- num_inference_steps=num_inference_steps,
80
- width=width,
81
- height=height,
82
- generator=generator,
83
- joint_attention_kwargs={"scale": lora_scale},
84
- ).images[0]
 
85
 
86
- # Save the generated image
87
- filepath = save_generated_image(image, prompt)
88
 
89
- # Return the image, seed, and updated gallery
90
- return image, seed, load_generated_images()
91
 
92
- examples = [
93
- "A man playing fetch with a golden retriever in a sunny park. He wears casual weekend clothes and throws a red frisbee with joy. The dog leaps gracefully through the air, tail wagging with excitement. Warm afternoon sunlight filters through the trees, creating a peaceful scene of companionship. [president yoon]",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- "A soldier standing at attention in full military gear, holding a standard-issue rifle. His uniform is crisp and properly adorned with medals. Behind him, other soldiers march in formation during a military parade. The scene conveys discipline and duty. [president yoon]",
96
 
97
- "A medieval knight in gleaming armor, holding an ornate sword and shield. He stands proudly in front of a majestic castle, his cape flowing in the wind. The shield bears intricate heraldic designs, and sunlight glints off his polished armor. [president yoon]",
 
 
 
 
98
 
99
- "A charismatic political leader addressing a crowd from a podium. He wears a well-fitted suit and gestures confidently while speaking. The audience fills a large plaza, holding supportive banners and signs. News cameras capture the moment as he delivers his speech. [president yoon]",
 
 
 
 
 
 
100
 
101
- "A man enjoying a peaceful morning at home, reading a newspaper at his breakfast table. He wears comfortable home clothes and sips coffee from a favorite mug. Sunlight streams through the kitchen window, and a house plant adds a touch of nature to the cozy domestic scene. [president yoon]",
 
102
 
103
- "A businessman walking confidently through a modern office building. He carries a leather briefcase and wears a tailored navy suit. Floor-to-ceiling windows reveal a cityscape behind him, and his expression shows determination and purpose. [president yoon]"
104
- ]
 
 
 
 
 
105
 
106
- css = """
107
- footer {
108
- visibility: hidden;
109
- }
110
- """
111
 
112
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css, analytics_enabled=False) as demo:
113
- gr.HTML('<div class="title"> President Yoon in KOREA </div>')
 
 
 
114
 
115
- gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fopenfree-korea-president-yoon.hf.space">
116
- <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fopenfree-korea-president-yoon.hf.space&countColor=%23263759" />
117
- </a>""")
118
-
119
- with gr.Tabs() as tabs:
120
- with gr.Tab("Generation"):
121
- with gr.Column(elem_id="col-container"):
122
- with gr.Row():
123
- prompt = gr.Text(
124
- label="Prompt",
125
- show_label=False,
126
- max_lines=1,
127
- placeholder="Enter your prompt",
128
- container=False,
129
- )
130
- run_button = gr.Button("Run", scale=0)
131
-
132
- result = gr.Image(label="Result", show_label=False)
133
-
134
- with gr.Accordion("Advanced Settings", open=False):
135
- seed = gr.Slider(
136
- label="Seed",
137
- minimum=0,
138
- maximum=MAX_SEED,
139
- step=1,
140
- value=42,
141
- )
142
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
143
-
144
- with gr.Row():
145
- width = gr.Slider(
146
- label="Width",
147
- minimum=256,
148
- maximum=MAX_IMAGE_SIZE,
149
- step=32,
150
- value=1024,
151
- )
152
- height = gr.Slider(
153
- label="Height",
154
- minimum=256,
155
- maximum=MAX_IMAGE_SIZE,
156
- step=32,
157
- value=768,
158
- )
159
-
160
- with gr.Row():
161
- guidance_scale = gr.Slider(
162
- label="Guidance scale",
163
- minimum=0.0,
164
- maximum=10.0,
165
- step=0.1,
166
- value=3.5,
167
- )
168
- num_inference_steps = gr.Slider(
169
- label="Number of inference steps",
170
- minimum=1,
171
- maximum=50,
172
- step=1,
173
- value=30,
174
- )
175
- lora_scale = gr.Slider(
176
- label="LoRA scale",
177
- minimum=0.0,
178
- maximum=1.0,
179
- step=0.1,
180
- value=1.0,
181
- )
182
-
183
- gr.Examples(
184
- examples=examples,
185
- inputs=[prompt],
186
- outputs=[result, seed],
187
- )
188
-
189
- with gr.Tab("Gallery"):
190
- gallery_header = gr.Markdown("### Generated Images Gallery")
191
- generated_gallery = gr.Gallery(
192
- label="Generated Images",
193
- columns=6,
194
- show_label=False,
195
- value=load_generated_images(),
196
- elem_id="generated_gallery",
197
- height="auto"
198
- )
199
- refresh_btn = gr.Button("πŸ”„ Refresh Gallery")
200
-
201
-
202
- # Event handlers
203
- def refresh_gallery():
204
- return load_generated_images()
205
-
206
- refresh_btn.click(
207
- fn=refresh_gallery,
208
- inputs=None,
209
- outputs=generated_gallery,
210
- )
211
 
212
- gr.on(
213
- triggers=[run_button.click, prompt.submit],
214
- fn=inference,
215
- inputs=[
216
- prompt,
217
- seed,
218
- randomize_seed,
219
- width,
220
- height,
221
- guidance_scale,
222
- num_inference_steps,
223
- lora_scale,
224
- ],
225
- outputs=[result, seed, generated_gallery],
226
- )
227
 
228
- demo.queue()
229
- demo.launch()
 
1
+ # app.py
2
  import os
3
+ import base64
4
+ import streamlit as st
5
+ from gradio_client import Client
6
+ from dotenv import load_dotenv
7
+ from pathlib import Path
8
+ import json
9
+ import hashlib
10
+ import time
11
+ from typing import Dict, Any
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+ HF_TOKEN = os.getenv("HF_TOKEN")
16
+
17
+ # Cache directory setup
18
+ CACHE_DIR = Path("./cache")
19
+ CACHE_DIR.mkdir(exist_ok=True)
20
+
21
+ # Cached example diagrams
22
+ CACHED_EXAMPLES = {
23
+ "literacy_mental": {
24
+ "title": "Literacy Mental Map",
25
+ "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, brain silhouette, text areas. must include the texts
26
+ LITERACY/MENTAL
27
+ β”œβ”€β”€ PEACE [Dove Icon]
28
+ β”œβ”€β”€ HEALTH [Vitruvian Man ~60px]
29
+ β”œβ”€β”€ CONNECT [Brain-Mind Connection Icon]
30
+ β”œβ”€β”€ INTELLIGENCE
31
+ β”‚ └── EVERYTHING [Globe Icon ~50px]
32
+ └── MEMORY
33
+ β”œβ”€β”€ READING [Book Icon ~40px]
34
+ β”œβ”€β”€ SPEED [Speedometer Icon]
35
+ └── CREATIVITY
36
+ └── INTELLIGENCE [Lightbulb + Infinity ~30px]""",
37
+ "width": 1024,
38
+ "height": 1024,
39
+ "seed": 1872187377,
40
+ "cache_path": "literacy_mental.png"
41
+ }
42
+ }
43
+
44
+ # Example diagrams for various use cases
45
+ DIAGRAM_EXAMPLES = [
46
+ {
47
+ "title": "Project Management Flow",
48
+ "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, project management flow.
49
+ PROJECT MANAGEMENT
50
+ β”œβ”€β”€ INITIATION [Rocket Icon]
51
+ β”œβ”€β”€ PLANNING [Calendar Icon]
52
+ β”œβ”€β”€ EXECUTION [Gear Icon]
53
+ β”œβ”€β”€ MONITORING
54
+ β”‚ └── CONTROL [Dashboard Icon]
55
+ └── CLOSURE [Checkmark Icon]""",
56
+ "width": 1024,
57
+ "height": 1024
58
+ },
59
+ {
60
+ "title": "Digital Marketing Strategy",
61
+ "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, modern style, marketing concept.
62
+ DIGITAL MARKETING
63
+ β”œβ”€β”€ SEO [Magnifying Glass]
64
+ β”œβ”€β”€ SOCIAL MEDIA [Network Icon]
65
+ β”œβ”€β”€ CONTENT
66
+ β”‚ β”œβ”€β”€ BLOG [Document Icon]
67
+ β”‚ └── VIDEO [Play Button]
68
+ └── ANALYTICS [Graph Icon]""",
69
+ "width": 1024,
70
+ "height": 1024
71
+ }
72
+ ]
73
+
74
+ # Add 15 more examples
75
+ ADDITIONAL_EXAMPLES = [
76
+ {
77
+ "title": "Health & Wellness",
78
+ "prompt": """A handrawn colorful mind map diagram, wellness-focused style, health aspects.
79
+ WELLNESS
80
+ β”œβ”€β”€ PHYSICAL [Dumbbell Icon]
81
+ β”œβ”€β”€ MENTAL [Brain Icon]
82
+ β”œβ”€β”€ NUTRITION [Apple Icon]
83
+ └── SLEEP
84
+ β”œβ”€β”€ QUALITY [Star Icon]
85
+ └── DURATION [Clock Icon]""",
86
+ "width": 1024,
87
+ "height": 1024
88
+ }
89
+ # ... (λ‚˜λ¨Έμ§€ μ˜ˆμ œλ“€)
90
+ ]
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+ class DiagramCache:
103
+ def __init__(self, cache_dir: Path):
104
+ self.cache_dir = cache_dir
105
+ self.cache_dir.mkdir(exist_ok=True)
106
+ self._load_cache()
107
+
108
+ def _load_cache(self):
109
+ """Load existing cache entries"""
110
+ self.cache_index = {}
111
+ if (self.cache_dir / "cache_index.json").exists():
112
+ with open(self.cache_dir / "cache_index.json", "r") as f:
113
+ self.cache_index = json.load(f)
114
 
115
+ def _save_cache_index(self):
116
+ """Save cache index to disk"""
117
+ with open(self.cache_dir / "cache_index.json", "w") as f:
118
+ json.dump(self.cache_index, f)
119
 
120
+ def _get_cache_key(self, params: Dict[str, Any]) -> str:
121
+ """Generate cache key from parameters"""
122
+ param_str = json.dumps(params, sort_keys=True)
123
+ return hashlib.md5(param_str.encode()).hexdigest()
124
 
125
+ def get(self, params: Dict[str, Any]) -> Path:
126
+ """Get cached result if exists"""
127
+ cache_key = self._get_cache_key(params)
128
+ cache_info = self.cache_index.get(cache_key)
129
+ if cache_info:
130
+ cache_path = self.cache_dir / cache_info["filename"]
131
+ if cache_path.exists():
132
+ return cache_path
133
+ return None
134
+
135
+ def put(self, params: Dict[str, Any], result_path: Path):
136
+ """Store result in cache"""
137
+ cache_key = self._get_cache_key(params)
138
+ filename = f"{cache_key}{result_path.suffix}"
139
+ cache_path = self.cache_dir / filename
140
+
141
+ # Copy result to cache
142
+ with open(result_path, "rb") as src, open(cache_path, "wb") as dst:
143
+ dst.write(src.read())
144
+
145
+ # Update index
146
+ self.cache_index[cache_key] = {
147
+ "filename": filename,
148
+ "timestamp": time.time(),
149
+ "params": params
150
+ }
151
+ self._save_cache_index()
152
+
153
 
154
+ # Initialize cache
155
+ diagram_cache = DiagramCache(CACHE_DIR)
156
+
157
+ @st.cache_data
158
+ def generate_cached_example(example_id: str) -> str:
159
+ """Generate and cache example diagram"""
160
+ example = CACHED_EXAMPLES[example_id]
161
+ client = Client("black-forest-labs/FLUX.1-schnell")
162
 
163
+ # Check cache first
164
+ cache_path = diagram_cache.get(example)
165
+ if cache_path:
166
+ with open(cache_path, "rb") as f:
167
+ return base64.b64encode(f.read()).decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ # Generate new image
170
+ result = client.predict(
171
+ prompt=example["prompt"],
172
+ seed=example["seed"],
173
+ randomize_seed=False,
174
+ width=example["width"],
175
+ height=example["height"],
176
+ num_inference_steps=4,
177
+ api_name="/infer"
178
+ )
179
 
180
+ # Cache the result
181
+ diagram_cache.put(example, Path(result))
182
 
183
+ with open(result, "rb") as f:
184
+ return base64.b64encode(f.read()).decode()
185
 
186
+ def generate_diagram(prompt: str, width: int, height: int, seed: int = None) -> str:
187
+ """Generate a new diagram"""
188
+ client = Client("black-forest-labs/FLUX.1-schnell")
189
+ params = {
190
+ "prompt": prompt,
191
+ "seed": seed if seed else 1872187377,
192
+ "width": width,
193
+ "height": height
194
+ }
195
+
196
+ # Check cache first
197
+ cache_path = diagram_cache.get(params)
198
+ if cache_path:
199
+ with open(cache_path, "rb") as f:
200
+ return base64.b64encode(f.read()).decode()
201
+
202
+ # Generate new image
203
+ try:
204
+ result = client.predict(
205
+ prompt=prompt,
206
+ seed=params["seed"],
207
+ randomize_seed=False,
208
+ width=width,
209
+ height=height,
210
+ num_inference_steps=4,
211
+ api_name="/infer"
212
+ )
213
+
214
+ # Cache the result
215
+ diagram_cache.put(params, Path(result))
216
+
217
+ with open(result, "rb") as f:
218
+ return base64.b64encode(f.read()).decode()
219
+ except Exception as e:
220
+ st.error(f"Error generating diagram: {str(e)}")
221
+ return None
222
 
 
223
 
224
+ def main():
225
+ st.set_page_config(page_title="FLUX Diagram Generator", layout="wide")
226
+
227
+ st.title("🎨 FLUX Diagram Generator")
228
+ st.markdown("Generate beautiful hand-drawn style diagrams using FLUX AI")
229
 
230
+ # Sidebar for examples
231
+ st.sidebar.title("πŸ“š Example Templates")
232
+ selected_example = st.sidebar.selectbox(
233
+ "Choose a template",
234
+ options=range(len(DIAGRAM_EXAMPLES)),
235
+ format_func=lambda x: DIAGRAM_EXAMPLES[x]["title"]
236
+ )
237
 
238
+ # Main content area
239
+ col1, col2 = st.columns([2, 1])
240
 
241
+ with col1:
242
+ # Input area
243
+ prompt = st.text_area(
244
+ "Diagram Prompt",
245
+ value=DIAGRAM_EXAMPLES[selected_example]["prompt"],
246
+ height=200
247
+ )
248
 
249
+ # Configuration
250
+ with st.expander("Advanced Configuration"):
251
+ width = st.number_input("Width", min_value=512, max_value=2048, value=1024, step=128)
252
+ height = st.number_input("Height", min_value=512, max_value=2048, value=1024, step=128)
253
+ seed = st.number_input("Seed (optional)", value=None, step=1)
254
 
255
+ if st.button("🎨 Generate Diagram"):
256
+ with st.spinner("Generating your diagram..."):
257
+ result = generate_diagram(prompt, width, height, seed)
258
+ if result:
259
+ st.image(result, caption="Generated Diagram", use_column_width=True)
260
 
261
+ with col2:
262
+ st.subheader("Tips for Better Results")
263
+ st.markdown("""
264
+ - Use clear hierarchical structures
265
+ - Include icon descriptions in brackets
266
+ - Keep text concise and meaningful
267
+ - Use consistent formatting
268
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
+ st.subheader("Template Structure")
271
+ st.code("""
272
+ MAIN TOPIC
273
+ β”œβ”€β”€ SUBTOPIC 1 [Icon]
274
+ β”œβ”€β”€ SUBTOPIC 2 [Icon]
275
+ └── SUBTOPIC 3
276
+ β”œβ”€β”€ DETAIL 1 [Icon]
277
+ └── DETAIL 2 [Icon]
278
+ """)
279
+
280
+ if __name__ == "__main__":
281
+ main()
 
 
 
282
 
283
+