ginipick commited on
Commit
95a9680
ยท
verified ยท
1 Parent(s): 903bf81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -622
app.py CHANGED
@@ -1,622 +0,0 @@
1
- import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
- import os
5
- import time
6
- from os import path
7
- import shutil
8
- from datetime import datetime
9
- from safetensors.torch import load_file
10
- from huggingface_hub import hf_hub_download
11
- import torch
12
- import numpy as np
13
- import imageio
14
- import uuid
15
- from easydict import EasyDict as edict
16
- from PIL import Image
17
- from trellis.pipelines import TrellisImageTo3DPipeline
18
- from trellis.representations import Gaussian, MeshExtractResult
19
- from trellis.utils import render_utils, postprocessing_utils
20
- from diffusers import FluxPipeline
21
- from typing import Tuple, Dict, Any # Tuple import ์ถ”๊ฐ€
22
- # ํŒŒ์ผ ์ƒ๋‹จ์˜ import ๋ฌธ
23
- import transformers
24
- from transformers import pipeline as transformers_pipeline
25
- from transformers import Pipeline
26
- import gc # ํŒŒ์ผ ์ƒ๋‹จ์— ์ถ”๊ฐ€
27
-
28
- # ์ „์—ญ ๋ณ€์ˆ˜ ์ดˆ๊ธฐํ™”
29
- class GlobalVars:
30
- def __init__(self):
31
- self.translator = None
32
- self.trellis_pipeline = None
33
- self.flux_pipe = None
34
-
35
- g = GlobalVars()
36
-
37
- # ํŒŒ์ผ ์ƒ๋‹จ์— ์ถ”๊ฐ€
38
- torch.backends.cudnn.benchmark = False # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๊ฐ์†Œ
39
- torch.backends.cudnn.deterministic = True
40
- torch.cuda.set_per_process_memory_fraction(0.7) # GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ œํ•œ
41
-
42
- def initialize_models(device):
43
- try:
44
- print("Initializing models...")
45
- g.translator = transformers_pipeline(
46
- "translation",
47
- model="Helsinki-NLP/opus-mt-ko-en",
48
- device=device
49
- )
50
- print("Model initialization completed successfully")
51
-
52
- # 3D ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ
53
- g.trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
54
- "JeffreyXiang/TRELLIS-image-large"
55
- )
56
- print("TrellisImageTo3DPipeline loaded successfully")
57
-
58
- # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ
59
- print("Loading flux_pipe...")
60
- g.flux_pipe = FluxPipeline.from_pretrained(
61
- "black-forest-labs/FLUX.1-dev",
62
- torch_dtype=torch.bfloat16,
63
- device_map="balanced"
64
- )
65
- print("FluxPipeline loaded successfully")
66
-
67
- # Hyper-SD LoRA ๋กœ๋“œ
68
- print("Loading LoRA weights...")
69
- lora_path = hf_hub_download(
70
- "ByteDance/Hyper-SD",
71
- "Hyper-FLUX.1-dev-8steps-lora.safetensors",
72
- use_auth_token=HF_TOKEN
73
- )
74
- g.flux_pipe.load_lora_weights(lora_path)
75
- g.flux_pipe.fuse_lora(lora_scale=0.125)
76
- print("LoRA weights loaded successfully")
77
-
78
- # ๋ฒˆ์—ญ๊ธฐ ์ดˆ๊ธฐํ™”
79
- print("Initializing translator...")
80
- g.translator = transformers_pipeline(
81
- "translation",
82
- model="Helsinki-NLP/opus-mt-ko-en",
83
- device=device
84
- )
85
- print("Model initialization completed successfully")
86
-
87
- except Exception as e:
88
- print(f"Error during model initialization: {str(e)}")
89
- raise
90
-
91
-
92
- # ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
93
- # ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
94
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
95
- os.environ['SPCONV_ALGO'] = 'native'
96
- os.environ['SPARSE_BACKEND'] = 'native'
97
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
98
- os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1'
99
- os.environ['XFORMERS_ENABLE_FLASH_ATTENTION'] = '1'
100
- os.environ['TORCH_CUDA_MEMORY_ALLOCATOR'] = 'native'
101
- os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
102
-
103
- # CUDA ์ดˆ๊ธฐํ™” ๋ฐฉ์ง€
104
- torch.set_grad_enabled(False)
105
-
106
- # Hugging Face ํ† ํฐ ์„ค์ •
107
- HF_TOKEN = os.getenv("HF_TOKEN")
108
- if HF_TOKEN is None:
109
- raise ValueError("HF_TOKEN environment variable is not set")
110
-
111
- MAX_SEED = np.iinfo(np.int32).max
112
- TMP_DIR = "/tmp/Trellis-demo"
113
- os.makedirs(TMP_DIR, exist_ok=True)
114
-
115
- # Setup and initialization code
116
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
117
- PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
118
- gallery_path = path.join(PERSISTENT_DIR, "gallery")
119
-
120
- os.environ["TRANSFORMERS_CACHE"] = cache_path
121
- os.environ["HF_HUB_CACHE"] = cache_path
122
- os.environ["HF_HOME"] = cache_path
123
- os.environ['SPCONV_ALGO'] = 'native'
124
-
125
- torch.backends.cuda.matmul.allow_tf32 = True
126
-
127
-
128
-
129
- class timer:
130
- def __init__(self, method_name="timed process"):
131
- self.method = method_name
132
- def __enter__(self):
133
- self.start = time.time()
134
- print(f"{self.method} starts")
135
- def __exit__(self, exc_type, exc_val, exc_tb):
136
- end = time.time()
137
- print(f"{self.method} took {str(round(end - self.start, 2))}s")
138
-
139
- def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
140
- if image is None:
141
- print("Error: Input image is None")
142
- return "", None
143
-
144
- try:
145
- if g.trellis_pipeline is None:
146
- print("Error: trellis_pipeline is not initialized")
147
- return "", None
148
-
149
- # webp ์ด๋ฏธ์ง€๋ฅผ RGB๋กœ ๋ณ€ํ™˜
150
- if isinstance(image, str) and image.endswith('.webp'):
151
- image = Image.open(image).convert('RGB')
152
- elif isinstance(image, Image.Image):
153
- image = image.convert('RGB')
154
-
155
- trial_id = str(uuid.uuid4())
156
- processed_image = g.trellis_pipeline.preprocess_image(image)
157
- if processed_image is not None:
158
- save_path = f"{TMP_DIR}/{trial_id}.png"
159
- processed_image.save(save_path)
160
- print(f"Saved processed image to: {save_path}")
161
- return trial_id, processed_image
162
- else:
163
- print("Error: Processed image is None")
164
- return "", None
165
- except Exception as e:
166
- print(f"Error in image preprocessing: {str(e)}")
167
- return "", None
168
-
169
- def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
170
- return {
171
- 'gaussian': {
172
- **gs.init_params,
173
- '_xyz': gs._xyz.cpu().numpy(),
174
- '_features_dc': gs._features_dc.cpu().numpy(),
175
- '_scaling': gs._scaling.cpu().numpy(),
176
- '_rotation': gs._rotation.cpu().numpy(),
177
- '_opacity': gs._opacity.cpu().numpy(),
178
- },
179
- 'mesh': {
180
- 'vertices': mesh.vertices.cpu().numpy(),
181
- 'faces': mesh.faces.cpu().numpy(),
182
- },
183
- 'trial_id': trial_id,
184
- }
185
-
186
-
187
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
188
- gs = Gaussian(
189
- aabb=state['gaussian']['aabb'],
190
- sh_degree=state['gaussian']['sh_degree'],
191
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
192
- scaling_bias=state['gaussian']['scaling_bias'],
193
- opacity_bias=state['gaussian']['opacity_bias'],
194
- scaling_activation=state['gaussian']['scaling_activation'],
195
- )
196
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
197
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
198
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
199
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
200
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
201
-
202
- mesh = edict(
203
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
204
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
205
- )
206
-
207
- return gs, mesh, state['trial_id']
208
-
209
- @spaces.GPU
210
- def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
211
- ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
212
- try:
213
- # ์ดˆ๊ธฐ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
214
- clear_gpu_memory()
215
-
216
- if not trial_id or trial_id.strip() == "":
217
- return None, None
218
-
219
- image_path = f"{TMP_DIR}/{trial_id}.png"
220
- if not os.path.exists(image_path):
221
- return None, None
222
-
223
- image = Image.open(image_path)
224
-
225
- # ์ด๋ฏธ์ง€ ํฌ๊ธฐ ์ œํ•œ ๊ฐ•ํ™”
226
- max_size = 384 # ๋” ์ž‘์€ ํฌ๊ธฐ๋กœ ์ œํ•œ
227
- if max(image.size) > max_size:
228
- ratio = max_size / max(image.size)
229
- new_size = tuple(int(dim * ratio) for dim in image.size)
230
- image = image.resize(new_size, Image.LANCZOS)
231
-
232
- with torch.inference_mode():
233
- try:
234
- # ํŒŒ์ดํ”„๋ผ์ธ์„ GPU๋กœ ์ด๋™
235
- g.trellis_pipeline.to('cuda')
236
-
237
- outputs = g.trellis_pipeline.run(
238
- image,
239
- seed=seed,
240
- formats=["gaussian", "mesh"],
241
- preprocess_image=False,
242
- sparse_structure_sampler_params={
243
- "steps": min(ss_sampling_steps, 8),
244
- "cfg_strength": ss_guidance_strength
245
- },
246
- slat_sampler_params={
247
- "steps": min(slat_sampling_steps, 8),
248
- "cfg_strength": slat_guidance_strength
249
- }
250
- )
251
-
252
- # ์ค‘๊ฐ„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
253
- clear_gpu_memory()
254
-
255
- # ๋น„๋””์˜ค ๋ Œ๋”๋ง ์ตœ์ ํ™”
256
- video = render_utils.render_video(
257
- outputs['gaussian'][0],
258
- num_frames=30,
259
- resolution=384
260
- )['color']
261
-
262
- video_geo = render_utils.render_video(
263
- outputs['mesh'][0],
264
- num_frames=30,
265
- resolution=384
266
- )['normal']
267
-
268
- # tensor๋ฅผ numpy๋กœ ๋ณ€ํ™˜
269
- if torch.is_tensor(video[0]):
270
- video = [v.cpu().numpy() if torch.is_tensor(v) else v for v in video]
271
- if torch.is_tensor(video_geo[0]):
272
- video_geo = [v.cpu().numpy() if torch.is_tensor(v) else v for v in video_geo]
273
-
274
- clear_gpu_memory()
275
-
276
- # ๋น„๋””์˜ค ์ƒ์„ฑ ๋ฐ ์ €์žฅ
277
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
278
- new_trial_id = str(uuid.uuid4())
279
- video_path = f"{TMP_DIR}/{new_trial_id}.mp4"
280
- imageio.mimsave(video_path, video, fps=15)
281
-
282
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], new_trial_id)
283
- return state, video_path
284
-
285
- finally:
286
- # ์ •๋ฆฌ ์ž‘์—…
287
- g.trellis_pipeline.to('cpu')
288
- clear_gpu_memory()
289
-
290
- except Exception as e:
291
- print(f"Error in image_to_3d: {str(e)}")
292
- if hasattr(g.trellis_pipeline, 'to'):
293
- g.trellis_pipeline.to('cpu')
294
- clear_gpu_memory()
295
- return None, None
296
-
297
- def clear_gpu_memory():
298
- """GPU ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋” ์ฒ ์ €ํ•˜๊ฒŒ ์ •๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜"""
299
- try:
300
- if torch.cuda.is_available():
301
- # ๋ชจ๋“  GPU ์บ์‹œ ์ •๋ฆฌ
302
- torch.cuda.empty_cache()
303
- torch.cuda.synchronize()
304
-
305
- # ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š” ์บ์‹œ๋œ ๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ
306
- for i in range(torch.cuda.device_count()):
307
- with torch.cuda.device(i):
308
- torch.cuda.empty_cache()
309
- torch.cuda.ipc_collect()
310
-
311
- # Python ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰ํ„ฐ ์‹คํ–‰
312
- gc.collect()
313
- except Exception as e:
314
- print(f"Error in clear_gpu_memory: {e}")
315
-
316
- def move_to_device(model, device):
317
- """๋ชจ๋ธ์„ ์•ˆ์ „ํ•˜๊ฒŒ ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™ํ•˜๋Š” ํ•จ์ˆ˜"""
318
- try:
319
- if hasattr(model, 'to'):
320
- clear_gpu_memory()
321
- model.to(device)
322
- if device == 'cuda':
323
- torch.cuda.synchronize()
324
- clear_gpu_memory()
325
- except Exception as e:
326
- print(f"Error moving model to {device}: {str(e)}")
327
-
328
- @spaces.GPU
329
- def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
330
- """
331
- 3D ๋ชจ๋ธ์—์„œ GLB ํŒŒ์ผ ์ถ”์ถœ
332
- """
333
- try:
334
- gs, mesh, trial_id = unpack_state(state)
335
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
336
- glb_path = f"{TMP_DIR}/{trial_id}.glb"
337
- glb.export(glb_path)
338
- return glb_path, glb_path
339
- except Exception as e:
340
- print(f"GLB ์ถ”์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
341
- return None, None
342
-
343
-
344
-
345
- def activate_button() -> gr.Button:
346
- return gr.Button(interactive=True)
347
-
348
-
349
- def deactivate_button() -> gr.Button:
350
- return gr.Button(interactive=False)
351
-
352
- @spaces.GPU
353
- def text_to_image(prompt: str, height: int, width: int, steps: int, scales: float, seed: int) -> Image.Image:
354
- try:
355
- # CUDA ๋ฉ”๋ชจ๋ฆฌ ์ดˆ๊ธฐํ™”
356
- if torch.cuda.is_available():
357
- torch.cuda.empty_cache()
358
- torch.cuda.synchronize()
359
- gc.collect()
360
-
361
- # ํ•œ๊ธ€ ๊ฐ์ง€ ๋ฐ ๋ฒˆ์—ญ
362
- def contains_korean(text):
363
- return any(ord('๊ฐ€') <= ord(c) <= ord('ํžฃ') for c in text)
364
-
365
- if contains_korean(prompt):
366
- translated = g.translator(prompt)[0]['translation_text']
367
- prompt = translated
368
- print(f"Translated prompt: {prompt}")
369
-
370
- formatted_prompt = f"wbgmsst, 3D, {prompt}, white background"
371
-
372
- # ํฌ๊ธฐ ์ œํ•œ
373
- height = min(height, 512)
374
- width = min(width, 512)
375
- steps = min(steps, 12)
376
-
377
- with torch.inference_mode():
378
- generated_image = g.flux_pipe(
379
- prompt=[formatted_prompt],
380
- generator=torch.Generator('cuda').manual_seed(int(seed)),
381
- num_inference_steps=int(steps),
382
- guidance_scale=float(scales),
383
- height=int(height),
384
- width=int(width),
385
- max_sequence_length=256
386
- ).images[0]
387
-
388
- if generated_image is not None:
389
- trial_id = str(uuid.uuid4())
390
- save_path = f"{TMP_DIR}/{trial_id}.png"
391
- generated_image.save(save_path)
392
- print(f"Saved generated image to: {save_path}")
393
- return generated_image
394
- else:
395
- print("Error: Generated image is None")
396
- return None
397
-
398
- except Exception as e:
399
- print(f"Error in image generation: {str(e)}")
400
- return None
401
- finally:
402
- if torch.cuda.is_available():
403
- torch.cuda.empty_cache()
404
- torch.cuda.synchronize()
405
- gc.collect()
406
-
407
- css = """
408
- footer {
409
- visibility: hidden;
410
- }
411
- """
412
-
413
- def periodic_cleanup():
414
- """์ฃผ๊ธฐ์ ์œผ๋กœ ์‹คํ–‰๋  ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ•จ์ˆ˜"""
415
- clear_gpu_memory()
416
- return None
417
-
418
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
419
- gr.Markdown("""## Roblox3D GEN""")
420
-
421
- # Examples ์ด๋ฏธ์ง€ ๋กœ๋“œ
422
- example_dir = "assets/example_image/"
423
- example_images = []
424
- if os.path.exists(example_dir):
425
- for file in os.listdir(example_dir):
426
- if file.endswith('.webp'):
427
- example_images.append(os.path.join(example_dir, file))
428
-
429
- with gr.Row():
430
- with gr.Column():
431
- text_prompt = gr.Textbox(
432
- label="Text Prompt",
433
- placeholder="Describe what you want to create...",
434
- lines=3
435
- )
436
-
437
- # ์ด๋ฏธ์ง€ ํ”„๋กฌํ”„ํŠธ
438
- image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
439
-
440
- with gr.Accordion("Image Generation Settings", open=False):
441
- with gr.Row():
442
- height = gr.Slider(
443
- label="Height",
444
- minimum=256,
445
- maximum=1152,
446
- step=64,
447
- value=1024
448
- )
449
- width = gr.Slider(
450
- label="Width",
451
- minimum=256,
452
- maximum=1152,
453
- step=64,
454
- value=1024
455
- )
456
-
457
- with gr.Row():
458
- steps = gr.Slider(
459
- label="Inference Steps",
460
- minimum=6,
461
- maximum=25,
462
- step=1,
463
- value=8
464
- )
465
- scales = gr.Slider(
466
- label="Guidance Scale",
467
- minimum=0.0,
468
- maximum=5.0,
469
- step=0.1,
470
- value=3.5
471
- )
472
-
473
- seed = gr.Number(
474
- label="Seed",
475
- value=lambda: torch.randint(0, MAX_SEED, (1,)).item(),
476
- precision=0
477
- )
478
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
479
-
480
- generate_image_btn = gr.Button("Generate Image")
481
-
482
- with gr.Accordion("3D Generation Settings", open=False):
483
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Structure Guidance Strength", value=7.5, step=0.1)
484
- ss_sampling_steps = gr.Slider(1, 50, label="Structure Sampling Steps", value=12, step=1)
485
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Latent Guidance Strength", value=3.0, step=0.1)
486
- slat_sampling_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12, step=1)
487
-
488
- generate_3d_btn = gr.Button("Generate 3D")
489
-
490
- with gr.Accordion("GLB Extraction Settings", open=False):
491
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
492
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
493
-
494
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
495
-
496
- with gr.Column():
497
- # ์ƒ๋‹จ์— 3d.mp4 ์ž๋™์žฌ์ƒ ๋น„๋””์˜ค ์ถ”๊ฐ€
498
- gr.Video(
499
- "3d.mp4",
500
- label="3D Asset Preview",
501
- autoplay=True,
502
- loop=True,
503
- height=300,
504
- width="100%"
505
- )
506
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
507
- model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
508
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
509
-
510
- trial_id = gr.Textbox(visible=False)
511
- output_buf = gr.State()
512
-
513
- # Examples ๊ฐค๋Ÿฌ๋ฆฌ๋ฅผ ๋งจ ์•„๋ž˜๋กœ ์ด๋™
514
- if example_images:
515
- gr.Markdown("""### Example Images""")
516
- with gr.Row():
517
- gallery = gr.Gallery(
518
- value=example_images,
519
- label="Click an image to use it",
520
- show_label=True,
521
- elem_id="gallery",
522
- columns=11, # ํ•œ ์ค„์— 12๊ฐœ
523
- rows=3, # 2์ค„
524
- height=400, # ๋†’์ด ์กฐ์ •
525
- allow_preview=True,
526
- object_fit="contain" # ์ด๋ฏธ์ง€ ๋น„์œจ ์œ ์ง€
527
- )
528
-
529
- def load_example(evt: gr.SelectData):
530
- selected_image = Image.open(example_images[evt.index])
531
- trial_id_val, processed_image = preprocess_image(selected_image)
532
- return selected_image, trial_id_val
533
-
534
- gallery.select(
535
- load_example,
536
- None,
537
- [image_prompt, trial_id],
538
- show_progress=True
539
- )
540
-
541
- # Handlers
542
- generate_image_btn.click(
543
- text_to_image,
544
- inputs=[text_prompt, height, width, steps, scales, seed],
545
- outputs=[image_prompt]
546
- ).then(
547
- preprocess_image,
548
- inputs=[image_prompt],
549
- outputs=[trial_id, image_prompt]
550
- )
551
-
552
- # ๋‚˜๋จธ์ง€ ํ•ธ๋“ค๋Ÿฌ๋“ค
553
- image_prompt.upload(
554
- preprocess_image,
555
- inputs=[image_prompt],
556
- outputs=[trial_id, image_prompt],
557
- )
558
-
559
- image_prompt.clear(
560
- lambda: '',
561
- outputs=[trial_id],
562
- )
563
-
564
- generate_3d_btn.click(
565
- image_to_3d,
566
- inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
567
- outputs=[output_buf, video_output],
568
- ).then(
569
- activate_button,
570
- outputs=[extract_glb_btn],
571
- )
572
-
573
- video_output.clear(
574
- deactivate_button,
575
- outputs=[extract_glb_btn],
576
- )
577
-
578
- extract_glb_btn.click(
579
- extract_glb,
580
- inputs=[output_buf, mesh_simplify, texture_size],
581
- outputs=[model_output, download_glb],
582
- ).then(
583
- activate_button,
584
- outputs=[download_glb],
585
- )
586
-
587
- model_output.clear(
588
- deactivate_button,
589
- outputs=[download_glb],
590
- )
591
-
592
- if __name__ == "__main__":
593
- try:
594
- # CPU๋กœ ์ดˆ๊ธฐํ™”
595
- device = "cpu"
596
- print(f"Using device: {device}")
597
-
598
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
599
- initialize_models(device)
600
-
601
- # ์ดˆ๊ธฐ ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ํ…Œ์ŠคํŠธ
602
- try:
603
- test_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
604
- if g.trellis_pipeline is not None:
605
- g.trellis_pipeline.preprocess_image(test_image)
606
- else:
607
- print("Warning: trellis_pipeline is None")
608
- except Exception as e:
609
- print(f"Warning: Initial preprocessing test failed: {e}")
610
-
611
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰
612
- demo.queue() # ํ ๊ธฐ๋Šฅ ํ™œ์„ฑํ™”
613
- demo.launch(
614
- allowed_paths=[PERSISTENT_DIR, TMP_DIR],
615
- server_name="0.0.0.0",
616
- server_port=7860,
617
- show_error=True
618
- )
619
-
620
- except Exception as e:
621
- print(f"Error during initialization: {e}")
622
- raise