Masrkai commited on
Commit
768bc7d
·
verified ·
1 Parent(s): 1a0596e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +33 -14
model.py CHANGED
@@ -1,20 +1,39 @@
1
- # model.py
2
  import torch
3
- from diffusers import ShapEPipeline
4
- from diffusers.utils import export_to_gif
 
 
 
5
 
6
- # Load pipeline once to avoid reloading with each request
7
  def load_pipeline():
8
- ckpt_id = "openai/shap-e"
9
- pipe = ShapEPipeline.from_pretrained(ckpt_id, torch_dtype=torch.float32, trust_remote_code=True).to("cpu")
 
 
 
10
  return pipe
11
 
12
- # Generate images and export to GIF
13
- def generate_3d_gif(pipe, prompt, guidance_scale=10.0, num_inference_steps=32, size=256):
14
- images = pipe(
15
- prompt=prompt,
16
- guidance_scale=guidance_scale,
17
- num_inference_steps=num_inference_steps,
18
- ).images
19
- gif_path = export_to_gif(images, "generated_3d.gif")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return gif_path
 
 
1
  import torch
2
+ from diffusers import DiffusionPipeline
3
+ import trimesh
4
+ import numpy as np
5
+ from PIL import Image
6
+ from io import BytesIO
7
 
 
8
  def load_pipeline():
9
+ """
10
+ Load the stable-zero123 model pipeline from Hugging Face.
11
+ """
12
+ ckpt_id = "stabilityai/stable-zero123"
13
+ pipe = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch.float32).to("cpu")
14
  return pipe
15
 
16
+ def generate_3d_model(pipe, prompt, output_path="output.obj", guidance_scale=7.5, num_inference_steps=32):
17
+ """
18
+ Generate a 3D model from the prompt and save it in a Blender-compatible format (.obj).
19
+ """
20
+ # Generate the model output
21
+ outputs = pipe(prompt=prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps)
22
+
23
+ # Extract mesh data if the output structure allows
24
+ vertices = outputs["vertices"][0].detach().cpu().numpy()
25
+ faces = outputs["faces"][0].detach().cpu().numpy()
26
+
27
+ # Create and save the mesh using trimesh
28
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=True)
29
+ mesh.export(output_path)
30
+ return output_path
31
+
32
+ def convert_to_gif(images, gif_path="output.gif"):
33
+ """
34
+ Convert a list of images into a GIF.
35
+ """
36
+ images[0].save(
37
+ gif_path, save_all=True, append_images=images[1:], loop=0, duration=100
38
+ )
39
  return gif_path