ginipick commited on
Commit
625f4fd
·
verified ·
1 Parent(s): e21788a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (28).py +295 -0
  2. requirements (9).txt +14 -0
app (28).py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import argparse
3
+ import os
4
+ import time
5
+ from os import path
6
+ import shutil
7
+ from datetime import datetime
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download
10
+ import gradio as gr
11
+ import torch
12
+ from diffusers import FluxPipeline
13
+ from PIL import Image
14
+ from transformers import pipeline
15
+
16
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
17
+
18
+ # Hugging Face 토큰 설정
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
+ if HF_TOKEN is None:
21
+ raise ValueError("HF_TOKEN environment variable is not set")
22
+
23
+ # Setup and initialization code
24
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
+ PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
26
+ gallery_path = path.join(PERSISTENT_DIR, "gallery")
27
+
28
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
29
+ os.environ["HF_HUB_CACHE"] = cache_path
30
+ os.environ["HF_HOME"] = cache_path
31
+
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+
34
+ # Create gallery directory if it doesn't exist
35
+ if not path.exists(gallery_path):
36
+ os.makedirs(gallery_path, exist_ok=True)
37
+
38
+ class timer:
39
+ def __init__(self, method_name="timed process"):
40
+ self.method = method_name
41
+ def __enter__(self):
42
+ self.start = time.time()
43
+ print(f"{self.method} starts")
44
+ def __exit__(self, exc_type, exc_val, exc_tb):
45
+ end = time.time()
46
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
47
+
48
+ # Model initialization
49
+ if not path.exists(cache_path):
50
+ os.makedirs(cache_path, exist_ok=True)
51
+
52
+ # 인증된 모델 로드
53
+ pipe = FluxPipeline.from_pretrained(
54
+ "black-forest-labs/FLUX.1-dev",
55
+ torch_dtype=torch.bfloat16,
56
+ use_auth_token=HF_TOKEN
57
+ )
58
+
59
+ # Hyper-SD LoRA 로드
60
+ pipe.load_lora_weights(
61
+ hf_hub_download(
62
+ "ByteDance/Hyper-SD",
63
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors",
64
+ use_auth_token=HF_TOKEN
65
+ )
66
+ )
67
+ pipe.fuse_lora(lora_scale=0.125)
68
+ pipe.to(device="cuda", dtype=torch.bfloat16)
69
+
70
+ def save_image(image):
71
+ """Save the generated image and return the path"""
72
+ try:
73
+ if not os.path.exists(gallery_path):
74
+ try:
75
+ os.makedirs(gallery_path, exist_ok=True)
76
+ except Exception as e:
77
+ print(f"Failed to create gallery directory: {str(e)}")
78
+ return None
79
+
80
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
81
+ random_suffix = os.urandom(4).hex()
82
+ filename = f"generated_{timestamp}_{random_suffix}.png"
83
+ filepath = os.path.join(gallery_path, filename)
84
+
85
+ try:
86
+ if isinstance(image, Image.Image):
87
+ image.save(filepath, "PNG", quality=100)
88
+ else:
89
+ image = Image.fromarray(image)
90
+ image.save(filepath, "PNG", quality=100)
91
+
92
+ return filepath
93
+ except Exception as e:
94
+ print(f"Failed to save image: {str(e)}")
95
+ return None
96
+
97
+ except Exception as e:
98
+ print(f"Error in save_image: {str(e)}")
99
+ return None
100
+
101
+ # 예시 프롬프트 정의
102
+ examples = [
103
+ ["A 3D Star Wars Darth Vader helmet, highly detailed metallic finish"],
104
+ ["A 3D Iron Man mask with glowing eyes and metallic red-gold finish"],
105
+ ["A detailed 3D Pokemon Pikachu figure with glossy surface"],
106
+ ["A 3D geometric abstract cube transforming into a sphere, metallic finish"],
107
+ ["A 3D steampunk mechanical heart with brass and copper details"],
108
+ ["A 3D crystal dragon with transparent iridescent scales"],
109
+ ["A 3D futuristic hovering drone with neon light accents"],
110
+ ["A 3D ancient Greek warrior helmet with ornate details"],
111
+ ["A 3D robotic butterfly with mechanical wings and metallic finish"],
112
+ ["A 3D floating magical crystal orb with internal energy swirls"]
113
+ ]
114
+
115
+ @spaces.GPU
116
+ def process_and_save_image(height=1024, width=1024, steps=8, scales=3.5, prompt="", seed=None):
117
+ global pipe
118
+
119
+ if seed is None:
120
+ seed = torch.randint(0, 1000000, (1,)).item()
121
+
122
+ # 한글 감지 및 번역
123
+ def contains_korean(text):
124
+ return any(ord('가') <= ord(c) <= ord('힣') for c in text)
125
+
126
+ # 프롬프트 전처리
127
+ if contains_korean(prompt):
128
+ translated = translator(prompt)[0]['translation_text']
129
+ prompt = translated
130
+
131
+ formatted_prompt = f"wbgmsst, 3D, {prompt} ,white background"
132
+
133
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
134
+ try:
135
+ generated_image = pipe(
136
+ prompt=[formatted_prompt],
137
+ generator=torch.Generator().manual_seed(int(seed)),
138
+ num_inference_steps=int(steps),
139
+ guidance_scale=float(scales),
140
+ height=int(height),
141
+ width=int(width),
142
+ max_sequence_length=256
143
+ ).images[0]
144
+
145
+ saved_path = save_image(generated_image)
146
+ if saved_path is None:
147
+ print("Warning: Failed to save generated image")
148
+
149
+ return generated_image
150
+ except Exception as e:
151
+ print(f"Error in image generation: {str(e)}")
152
+ return None
153
+
154
+ def get_random_seed():
155
+ return torch.randint(0, 1000000, (1,)).item()
156
+
157
+
158
+ def process_example(prompt):
159
+ return process_and_save_image(
160
+ height=1024,
161
+ width=1024,
162
+ steps=8,
163
+ scales=3.5,
164
+ prompt=prompt,
165
+ seed=get_random_seed()
166
+ )
167
+
168
+
169
+ # Gradio 인터페이스
170
+ with gr.Blocks(
171
+ theme=gr.themes.Soft(),
172
+ css="""
173
+ .container {
174
+ background: linear-gradient(to bottom right, #1a1a1a, #4a4a4a);
175
+ border-radius: 20px;
176
+ padding: 20px;
177
+ }
178
+ .generate-btn {
179
+ background: linear-gradient(45deg, #2196F3, #00BCD4);
180
+ border: none;
181
+ color: white;
182
+ font-weight: bold;
183
+ border-radius: 10px;
184
+ }
185
+ .output-image {
186
+ border-radius: 15px;
187
+ box-shadow: 0 8px 16px rgba(0,0,0,0.2);
188
+ }
189
+ .fixed-width {
190
+ max-width: 1024px;
191
+ margin: auto;
192
+ }
193
+ """
194
+ ) as demo:
195
+ gr.HTML(
196
+ """
197
+ <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
198
+ <h1 style="font-size: 2.5rem; color: #2196F3;">3D Style Image Generator</h1>
199
+ <p style="font-size: 1.2rem; color: #666;">Create amazing 3D-style images with AI</p>
200
+ </div>
201
+ """
202
+ )
203
+
204
+ with gr.Row(elem_classes="container"):
205
+ with gr.Column(scale=3):
206
+ prompt = gr.Textbox(
207
+ label="Image Description",
208
+ placeholder="Describe the 3D image you want to create...",
209
+ lines=3
210
+ )
211
+
212
+ with gr.Accordion("Advanced Settings", open=False):
213
+ with gr.Row():
214
+ height = gr.Slider(
215
+ label="Height",
216
+ minimum=256,
217
+ maximum=1152,
218
+ step=64,
219
+ value=1024
220
+ )
221
+ width = gr.Slider(
222
+ label="Width",
223
+ minimum=256,
224
+ maximum=1152,
225
+ step=64,
226
+ value=1024
227
+ )
228
+
229
+ with gr.Row():
230
+ steps = gr.Slider(
231
+ label="Inference Steps",
232
+ minimum=6,
233
+ maximum=25,
234
+ step=1,
235
+ value=8
236
+ )
237
+ scales = gr.Slider(
238
+ label="Guidance Scale",
239
+ minimum=0.0,
240
+ maximum=5.0,
241
+ step=0.1,
242
+ value=3.5
243
+ )
244
+
245
+ seed = gr.Number(
246
+ label="Seed (random by default, set for reproducibility)",
247
+ value=get_random_seed(),
248
+ precision=0
249
+ )
250
+
251
+ randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
252
+
253
+ generate_btn = gr.Button(
254
+ "✨ Generate Image",
255
+ elem_classes=["generate-btn"]
256
+ )
257
+
258
+ with gr.Column(scale=4, elem_classes=["fixed-width"]):
259
+ output = gr.Image(
260
+ label="Generated Image",
261
+ elem_id="output-image",
262
+ elem_classes=["output-image", "fixed-width"],
263
+ value="3d.webp"
264
+ )
265
+
266
+ # Examples 섹션
267
+ gr.Examples(
268
+ examples=examples,
269
+ inputs=prompt,
270
+ outputs=output,
271
+ fn=process_example, # 수정된 함수 사용
272
+ cache_examples=False,
273
+ examples_per_page=5
274
+ )
275
+
276
+ def update_seed():
277
+ return get_random_seed()
278
+
279
+ # 이벤트 핸들러
280
+ generate_btn.click(
281
+ process_and_save_image,
282
+ inputs=[height, width, steps, scales, prompt, seed],
283
+ outputs=output
284
+ ).then(
285
+ update_seed,
286
+ outputs=[seed]
287
+ )
288
+
289
+ randomize_seed.click(
290
+ update_seed,
291
+ outputs=[seed]
292
+ )
293
+
294
+ if __name__ == "__main__":
295
+ demo.launch(allowed_paths=[PERSISTENT_DIR])
requirements (9).txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ diffusers==0.30.0
3
+ invisible_watermark
4
+ torch
5
+ transformers==4.43.3
6
+ xformers
7
+ sentencepiece
8
+ peft
9
+
10
+ safetensors
11
+ gradio>=4.4.0
12
+ Pillow>=9.0.0
13
+ huggingface-hub>=0.19.0
14
+ sacremoses