aiqtech commited on
Commit
f96a94d
ยท
verified ยท
1 Parent(s): 52f4e8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -236
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
- from diffusers import StableDiffusionPipeline
5
-
6
  import os
7
- os.environ['SPCONV_ALGO'] = 'native'
8
- from typing import *
 
 
 
 
9
  import torch
10
  import numpy as np
11
  import imageio
@@ -15,209 +17,150 @@ from PIL import Image
15
  from trellis.pipelines import TrellisImageTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
 
 
18
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = "/tmp/Trellis-demo"
22
-
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
25
- # ์ƒ๋‹จ์˜ import ๋ถ€๋ถ„์— ์ถ”๊ฐ€
26
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
27
 
28
  # ๋ฒˆ์—ญ๊ธฐ ์ดˆ๊ธฐํ™”
29
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
30
 
31
- # ํ•œ๊ธ€ ๊ฐ์ง€ ๋ฐ ๋ฒˆ์—ญ ํ•จ์ˆ˜
32
- def translate_korean_prompt(prompt: str) -> str:
33
- """
34
- ํ•œ๊ธ€์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์˜์–ด๋กœ ๋ฒˆ์—ญ
35
- """
36
- def contains_korean(text):
37
- return any(ord('๊ฐ€') <= ord(c) <= ord('ํžฃ') for c in text)
38
-
39
- if contains_korean(prompt):
40
- translated = translator(prompt)[0]['translation_text']
41
- return translated
42
- return prompt
43
 
44
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
45
- """
46
- Preprocess the input image.
47
- Args:
48
- image (Image.Image): The input image.
49
- Returns:
50
- str: uuid of the trial.
51
- Image.Image: The preprocessed image.
52
- """
53
  trial_id = str(uuid.uuid4())
54
  processed_image = pipeline.preprocess_image(image)
55
  processed_image.save(f"{TMP_DIR}/{trial_id}.png")
56
  return trial_id, processed_image
57
 
58
-
59
- def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
60
- return {
61
- 'gaussian': {
62
- **gs.init_params,
63
- '_xyz': gs._xyz.cpu().numpy(),
64
- '_features_dc': gs._features_dc.cpu().numpy(),
65
- '_scaling': gs._scaling.cpu().numpy(),
66
- '_rotation': gs._rotation.cpu().numpy(),
67
- '_opacity': gs._opacity.cpu().numpy(),
68
- },
69
- 'mesh': {
70
- 'vertices': mesh.vertices.cpu().numpy(),
71
- 'faces': mesh.faces.cpu().numpy(),
72
- },
73
- 'trial_id': trial_id,
74
- }
75
-
76
-
77
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
78
- gs = Gaussian(
79
- aabb=state['gaussian']['aabb'],
80
- sh_degree=state['gaussian']['sh_degree'],
81
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
82
- scaling_bias=state['gaussian']['scaling_bias'],
83
- opacity_bias=state['gaussian']['opacity_bias'],
84
- scaling_activation=state['gaussian']['scaling_activation'],
85
- )
86
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
87
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
88
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
89
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
90
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
91
-
92
- mesh = edict(
93
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
94
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
95
- )
96
-
97
- return gs, mesh, state['trial_id']
98
 
99
  @spaces.GPU
100
- def text_to_image(prompt: str, seed: int, randomize_seed: bool) -> Image.Image:
101
- """
102
- Generate image from text prompt using Stable Diffusion.
103
- """
104
- if randomize_seed:
105
- seed = np.random.randint(0, MAX_SEED)
106
 
107
- # ํ•œ๊ธ€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์˜์–ด๋กœ ๋ฒˆ์—ญ
108
- english_prompt = translate_korean_prompt(prompt)
 
 
109
 
110
  # ํ”„๋กฌํ”„ํŠธ ํ˜•์‹ ๊ฐ•์ œ
111
- formatted_prompt = f"wbgmsst, 3D, {english_prompt}, white background"
112
 
113
- generator = torch.Generator(device="cuda").manual_seed(seed)
114
- image = text2img_pipeline(formatted_prompt, generator=generator).images[0]
115
- return image
116
-
117
-
118
- @spaces.GPU
119
- def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
120
- """
121
- Convert an image to a 3D model.
122
- Args:
123
- trial_id (str): The uuid of the trial.
124
- seed (int): The random seed.
125
- randomize_seed (bool): Whether to randomize the seed.
126
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
127
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
128
- slat_guidance_strength (float): The guidance strength for structured latent generation.
129
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
130
- Returns:
131
- dict: The information of the generated 3D model.
132
- str: The path to the video of the 3D model.
133
- """
134
- if randomize_seed:
135
- seed = np.random.randint(0, MAX_SEED)
136
- outputs = pipeline.run(
137
- Image.open(f"{TMP_DIR}/{trial_id}.png"),
138
- seed=seed,
139
- formats=["gaussian", "mesh"],
140
- preprocess_image=False,
141
- sparse_structure_sampler_params={
142
- "steps": ss_sampling_steps,
143
- "cfg_strength": ss_guidance_strength,
144
- },
145
- slat_sampler_params={
146
- "steps": slat_sampling_steps,
147
- "cfg_strength": slat_guidance_strength,
148
- },
149
- )
150
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
151
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
152
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
153
- trial_id = uuid.uuid4()
154
- video_path = f"{TMP_DIR}/{trial_id}.mp4"
155
- os.makedirs(os.path.dirname(video_path), exist_ok=True)
156
- imageio.mimsave(video_path, video, fps=15)
157
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
158
- return state, video_path
159
-
160
-
161
- @spaces.GPU
162
- def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
163
- """
164
- Extract a GLB file from the 3D model.
165
- Args:
166
- state (dict): The state of the generated 3D model.
167
- mesh_simplify (float): The mesh simplification factor.
168
- texture_size (int): The texture resolution.
169
- Returns:
170
- str: The path to the extracted GLB file.
171
- """
172
- gs, mesh, trial_id = unpack_state(state)
173
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
174
- glb_path = f"{TMP_DIR}/{trial_id}.glb"
175
- glb.export(glb_path)
176
- return glb_path, glb_path
177
-
178
-
179
- def activate_button() -> gr.Button:
180
- return gr.Button(interactive=True)
181
-
182
-
183
- def deactivate_button() -> gr.Button:
184
- return gr.Button(interactive=False)
185
-
186
-
187
- css = """
188
- footer {
189
- visibility: hidden;
190
- }
191
- """
192
-
193
 
194
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
195
- gr.Markdown("""
196
- ## Craft3D""")
197
 
198
  with gr.Row():
199
  with gr.Column():
200
- # Text to Image ๋ถ€๋ถ„ ์ถ”๊ฐ€
201
- text_prompt = gr.Textbox(label="Text Prompt", placeholder="Enter your text prompt here...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  generate_image_btn = gr.Button("Generate Image")
203
 
204
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
205
 
206
- with gr.Accordion(label="Generation Settings", open=False):
207
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
208
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
209
- gr.Markdown("Stage 1: Sparse Structure Generation")
210
- with gr.Row():
211
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
212
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
213
- gr.Markdown("Stage 2: Structured Latent Generation")
214
- with gr.Row():
215
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
216
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
217
 
218
- generate_btn = gr.Button("Generate")
219
 
220
- with gr.Accordion(label="GLB Extraction Settings", open=False):
221
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
222
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
223
 
@@ -231,81 +174,48 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
231
  trial_id = gr.Textbox(visible=False)
232
  output_buf = gr.State()
233
 
234
- # Example images at the bottom of the page
235
- with gr.Row():
236
- examples = gr.Examples(
237
- examples=[
238
- f'assets/example_image/{image}'
239
- for image in os.listdir("assets/example_image")
240
- ],
241
- inputs=[image_prompt],
242
- fn=preprocess_image,
243
- outputs=[trial_id, image_prompt],
244
- run_on_click=True,
245
- examples_per_page=64,
246
- )
247
-
248
  # Handlers
249
  generate_image_btn.click(
250
  text_to_image,
251
- inputs=[text_prompt, seed, randomize_seed],
252
- outputs=[image_prompt],
253
  ).then(
254
  preprocess_image,
255
  inputs=[image_prompt],
256
- outputs=[trial_id, image_prompt],
257
  )
258
 
259
- image_prompt.upload(
260
- preprocess_image,
261
- inputs=[image_prompt],
262
- outputs=[trial_id, image_prompt],
263
- )
264
- image_prompt.clear(
265
- lambda: '',
266
- outputs=[trial_id],
267
- )
268
-
269
- generate_btn.click(
270
- image_to_3d,
271
- inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
272
- outputs=[output_buf, video_output],
273
- ).then(
274
- activate_button,
275
- outputs=[extract_glb_btn],
276
- )
277
-
278
- video_output.clear(
279
- deactivate_button,
280
- outputs=[extract_glb_btn],
281
- )
282
-
283
- extract_glb_btn.click(
284
- extract_glb,
285
- inputs=[output_buf, mesh_simplify, texture_size],
286
- outputs=[model_output, download_glb],
287
- ).then(
288
- activate_button,
289
- outputs=[download_glb],
290
- )
291
 
292
- model_output.clear(
293
- deactivate_button,
294
- outputs=[download_glb],
295
- )
296
-
297
-
298
- # Launch the Gradio app
299
  if __name__ == "__main__":
300
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
 
 
 
 
301
  pipeline.cuda()
302
 
303
- # Stable Diffusion pipeline ์ถ”๊ฐ€
304
- text2img_pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
305
- text2img_pipeline.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  try:
308
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
309
  except:
310
  pass
311
- demo.launch()
 
 
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
 
 
4
  import os
5
+ import time
6
+ from os import path
7
+ import shutil
8
+ from datetime import datetime
9
+ from safetensors.torch import load_file
10
+ from huggingface_hub import hf_hub_download
11
  import torch
12
  import numpy as np
13
  import imageio
 
17
  from trellis.pipelines import TrellisImageTo3DPipeline
18
  from trellis.representations import Gaussian, MeshExtractResult
19
  from trellis.utils import render_utils, postprocessing_utils
20
+ from diffusers import FluxPipeline
21
+ from transformers import pipeline
22
 
23
+ # Hugging Face ํ† ํฐ ์„ค์ •
24
+ HF_TOKEN = os.getenv("HF_TOKEN")
25
+ if HF_TOKEN is None:
26
+ raise ValueError("HF_TOKEN environment variable is not set")
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  TMP_DIR = "/tmp/Trellis-demo"
 
30
  os.makedirs(TMP_DIR, exist_ok=True)
31
 
32
+ # Setup and initialization code
33
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
34
+ PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
35
+ gallery_path = path.join(PERSISTENT_DIR, "gallery")
36
+
37
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
38
+ os.environ["HF_HUB_CACHE"] = cache_path
39
+ os.environ["HF_HOME"] = cache_path
40
+ os.environ['SPCONV_ALGO'] = 'native'
41
+
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
 
44
  # ๋ฒˆ์—ญ๊ธฐ ์ดˆ๊ธฐํ™”
45
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
46
 
47
+ class timer:
48
+ def __init__(self, method_name="timed process"):
49
+ self.method = method_name
50
+ def __enter__(self):
51
+ self.start = time.time()
52
+ print(f"{self.method} starts")
53
+ def __exit__(self, exc_type, exc_val, exc_tb):
54
+ end = time.time()
55
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
 
 
 
56
 
57
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
 
 
 
 
 
 
 
 
58
  trial_id = str(uuid.uuid4())
59
  processed_image = pipeline.preprocess_image(image)
60
  processed_image.save(f"{TMP_DIR}/{trial_id}.png")
61
  return trial_id, processed_image
62
 
63
+ [์ด์ „ ์ฝ”๋“œ์˜ ๋‚˜๋จธ์ง€ ํ•จ์ˆ˜๋“ค: pack_state, unpack_state, image_to_3d, extract_glb, activate_button, deactivate_button์€ ๊ทธ๋Œ€๋กœ ์œ ์ง€]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  @spaces.GPU
66
+ def text_to_image(prompt: str, height: int, width: int, steps: int, scales: float, seed: int) -> Image.Image:
67
+ # ํ•œ๊ธ€ ๊ฐ์ง€ ๋ฐ ๋ฒˆ์—ญ
68
+ def contains_korean(text):
69
+ return any(ord('๊ฐ€') <= ord(c) <= ord('ํžฃ') for c in text)
 
 
70
 
71
+ # ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ
72
+ if contains_korean(prompt):
73
+ translated = translator(prompt)[0]['translation_text']
74
+ prompt = translated
75
 
76
  # ํ”„๋กฌํ”„ํŠธ ํ˜•์‹ ๊ฐ•์ œ
77
+ formatted_prompt = f"wbgmsst, 3D, {prompt}, white background"
78
 
79
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
80
+ try:
81
+ generated_image = pipe(
82
+ prompt=[formatted_prompt],
83
+ generator=torch.Generator().manual_seed(int(seed)),
84
+ num_inference_steps=int(steps),
85
+ guidance_scale=float(scales),
86
+ height=int(height),
87
+ width=int(width),
88
+ max_sequence_length=256
89
+ ).images[0]
90
+
91
+ trial_id = str(uuid.uuid4())
92
+ generated_image.save(f"{TMP_DIR}/{trial_id}.png")
93
+ return generated_image
94
+
95
+ except Exception as e:
96
+ print(f"Error in image generation: {str(e)}")
97
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ # Gradio Interface
100
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
101
+ gr.Markdown("""## Craft3D""")
102
 
103
  with gr.Row():
104
  with gr.Column():
105
+ text_prompt = gr.Textbox(
106
+ label="Text Prompt",
107
+ placeholder="Describe what you want to create...",
108
+ lines=3
109
+ )
110
+
111
+ with gr.Accordion("Image Generation Settings", open=False):
112
+ with gr.Row():
113
+ height = gr.Slider(
114
+ label="Height",
115
+ minimum=256,
116
+ maximum=1152,
117
+ step=64,
118
+ value=1024
119
+ )
120
+ width = gr.Slider(
121
+ label="Width",
122
+ minimum=256,
123
+ maximum=1152,
124
+ step=64,
125
+ value=1024
126
+ )
127
+
128
+ with gr.Row():
129
+ steps = gr.Slider(
130
+ label="Inference Steps",
131
+ minimum=6,
132
+ maximum=25,
133
+ step=1,
134
+ value=8
135
+ )
136
+ scales = gr.Slider(
137
+ label="Guidance Scale",
138
+ minimum=0.0,
139
+ maximum=5.0,
140
+ step=0.1,
141
+ value=3.5
142
+ )
143
+
144
+ seed = gr.Number(
145
+ label="Seed",
146
+ value=lambda: torch.randint(0, MAX_SEED, (1,)).item(),
147
+ precision=0
148
+ )
149
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
150
+
151
  generate_image_btn = gr.Button("Generate Image")
152
 
153
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
154
 
155
+ with gr.Accordion("3D Generation Settings", open=False):
156
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Structure Guidance Strength", value=7.5, step=0.1)
157
+ ss_sampling_steps = gr.Slider(1, 50, label="Structure Sampling Steps", value=12, step=1)
158
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Latent Guidance Strength", value=3.0, step=0.1)
159
+ slat_sampling_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12, step=1)
 
 
 
 
 
 
160
 
161
+ generate_3d_btn = gr.Button("Generate 3D")
162
 
163
+ with gr.Accordion("GLB Extraction Settings", open=False):
164
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
165
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
166
 
 
174
  trial_id = gr.Textbox(visible=False)
175
  output_buf = gr.State()
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  # Handlers
178
  generate_image_btn.click(
179
  text_to_image,
180
+ inputs=[text_prompt, height, width, steps, scales, seed],
181
+ outputs=[image_prompt]
182
  ).then(
183
  preprocess_image,
184
  inputs=[image_prompt],
185
+ outputs=[trial_id, image_prompt]
186
  )
187
 
188
+ [์ด์ „ ์ฝ”๋“œ์˜ ๋‚˜๋จธ์ง€ ํ•ธ๋“ค๋Ÿฌ๋“ค์€ ๊ทธ๋Œ€๋กœ ์œ ์ง€]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
 
 
 
 
 
 
 
190
  if __name__ == "__main__":
191
+ # 3D ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ
192
+ pipeline = TrellisImageTo3DPipeline.from_pretrained(
193
+ "JeffreyXiang/TRELLIS-image-large",
194
+ use_auth_token=HF_TOKEN
195
+ )
196
  pipeline.cuda()
197
 
198
+ # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ
199
+ pipe = FluxPipeline.from_pretrained(
200
+ "black-forest-labs/FLUX.1-dev",
201
+ torch_dtype=torch.bfloat16,
202
+ use_auth_token=HF_TOKEN
203
+ )
204
+
205
+ # Hyper-SD LoRA ๋กœ๋“œ
206
+ pipe.load_lora_weights(
207
+ hf_hub_download(
208
+ "ByteDance/Hyper-SD",
209
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors",
210
+ use_auth_token=HF_TOKEN
211
+ )
212
+ )
213
+ pipe.fuse_lora(lora_scale=0.125)
214
+ pipe.to(device="cuda", dtype=torch.bfloat16)
215
 
216
  try:
217
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
218
  except:
219
  pass
220
+
221
+ demo.launch(allowed_paths=[PERSISTENT_DIR])