Spaces:
Running
on
Zero
Running
on
Zero
import importlib
import json
import os
import subprocess
import sys
import tempfile

# Third-party imports keep their original relative order (import order may
# matter here, e.g. for `spaces` on ZeroGPU hardware).
import yaml
import numpy as np
from PIL import Image
import rembg
import torch
import spaces
from huggingface_hub import hf_hub_download

from core.models import DiT_models
from core.diffusion import create_diffusion
from core.utils.dinov2 import Dinov2Model
from core.utils.math_utils import unnormalize_params
# Setup PyTorch: this app only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)
device = torch.device('cuda')

# Supported categories. Entry i of `generators_choices` names the module
# core.assets.<category>; entry i of `factory_names` is the factory class in it.
generators_choices = ["chair", "table", "vase", "basket", "flower", "dandelion"]
factory_names = ["ChairFactory", "TableDiningFactory", "VaseFactory", "BasketBaseFactory", "FlowerFactory", "DandelionFactory"]
generator_path = "./core/assets/"


def _ensure_checkpoint(ckpt_path):
    """Return a local path for `ckpt_path`, fetching it from the Hub if absent.

    When the file does not exist, its basename is downloaded from the
    "TencentARC/DI-PCG" Hub repo into the directory part of `ckpt_path`
    (or the current directory when `ckpt_path` has no directory part).

    Returns:
        `ckpt_path` itself if it already exists, otherwise the path returned
        by `hf_hub_download`.
    """
    if os.path.exists(ckpt_path):
        return ckpt_path
    model_dir, model_name = os.path.split(ckpt_path)
    # os.makedirs("") raises, so only create a directory when there is one.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    # Announce the download before starting it, not after it has finished.
    print("Downloading checkpoint {} from Hugging Face Hub...".format(model_name))
    return hf_hub_download(repo_id="TencentARC/DI-PCG",
                           local_dir=model_dir or ".", filename=model_name)


# Load one procedural generator, demo config and DiT model per category.
# The three lists stay index-aligned with `generators_choices`.
generators, configs, models = [], [], []
for category, factory in zip(generators_choices, factory_names):
    # procedural generator (constructed with argument 0 -- presumably a seed; confirm)
    module = importlib.import_module(f"core.assets.{category}")
    gen = getattr(module, factory)
    generator = gen(0)
    generators.append(generator)

    # per-category demo config
    config_path = f"./configs/demo/{category}_demo.yaml"
    with open(config_path) as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    configs.append(cfg)

    # DiT model whose input size is this generator's parameter count
    latent_size = cfg["num_params"]
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)
    # load a custom DiT checkpoint from train.py, downloading it if missing
    ckpt_path = _ensure_checkpoint(cfg["ckpt_path"])
    print("Loading model from {}".format(ckpt_path))
    # Deserialize onto CPU; load_state_dict copies into the model's device.
    state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if "ema" in state_dict:  # supports checkpoints from train.py
        state_dict = state_dict["ema"]
    model.load_state_dict(state_dict)
    model.eval()
    models.append(model)

# NOTE: built from the config loaded last, so it only matches categories whose
# "num_sampling_steps" equals that config's value.
diffusion = create_diffusion(str(cfg["num_sampling_steps"]))

# DINOv2 image feature extractor whose features condition the diffusion models
feature_model = Dinov2Model()
def check_input_image(input_image):
    """Stop the generation pipeline early when nothing was uploaded.

    Returns normally (None) for any non-None input. For None it raises
    ``gr.Error``, which Gradio shows to the user. The ``.success()`` callbacks
    chained after this check only run when it returns without raising.
    """
    if input_image is not None:
        return
    raise gr.Error("No image uploaded!")
def preprocess(input_image, do_remove_background):
    """Resize the uploaded image to 256x256 and optionally strip its background.

    Args:
        input_image: PIL image from the upload widget.
        do_remove_background: when truthy, run ``rembg.remove`` on the resized
            image.

    Returns:
        The resized PIL image when no background removal is requested;
        otherwise whatever ``rembg.remove`` returns for the resized image as a
        numpy array (presumably an RGBA array -- confirm against rembg).
        No white-background compositing happens here.
    """
    target_size = (256, 256)
    width, height = input_image.size
    if (width, height) != target_size:
        input_image = input_image.resize(target_size)

    if not do_remove_background:
        return input_image
    return rembg.remove(np.array(input_image))
def _composite_on_white(image):
    """Flatten an (H, W, 4) or (H, W, 3) image onto white; return uint8 RGB.

    For 4-channel input, pixels whose alpha is 0 become pure white and every
    other pixel keeps its RGB value unchanged (no alpha blending). 3-channel
    input has no alpha, so every pixel is kept. Previously the last channel
    (blue) was used as the mask here, whitening any pixel with blue == 0.
    """
    pixels = np.asarray(image)
    rgb = pixels[:, :, :3]
    if pixels.shape[2] == 4:
        opaque = pixels[:, :, 3:] > 0
    else:
        opaque = np.ones(rgb.shape[:2] + (1,), dtype=bool)
    return np.where(opaque, rgb, 255).astype(np.uint8)


def _make_temp_path(suffix):
    """Create an empty temp file that survives this call and return its path.

    The OS-level handle is closed immediately. Only the path is needed, and
    leaving it open would leak a file descriptor per request.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def sample(image, seed, category):
    """Predict generator parameters for `image` and build the resulting mesh.

    Args:
        image: condition image (PIL image or array). Presumably RGBA from the
            "Processed Image" component; 3-channel input is treated as fully
            opaque.
        seed: seed for numpy, torch and the mesh-generation script. Converted
            with int(), so integral floats from a number widget also work.
        category: one of `generators_choices`. Anything else raises
            ValueError from the list lookup.

    Returns:
        (mesh_fpath, mesh_fpath): path of the generated .glb file, once per
        output viewer.

    Raises:
        gr.Error: if ./scripts/generate.py exits with a non-zero status.

    NOTE(review): `spaces` is imported by this file but nothing is decorated
    with @spaces.GPU. On ZeroGPU hardware the CUDA work below may need to run
    inside such a function -- confirm.
    """
    # np.random.seed rejects floats, so normalise the seed first.
    seed = int(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # generator, config & model for the selected category
    idx = generators_choices.index(category)
    generator, cfg, model = generators[idx], configs[idx], models[idx]
    # Build the sampler from this category's own config. The module-level
    # `diffusion` only reflects the config that was loaded last.
    category_diffusion = create_diffusion(str(cfg["num_sampling_steps"]))

    # encode the condition image (transparent pixels -> white background)
    img_feat = feature_model.encode_batch_imgs([_composite_on_white(image)], global_feat=False)

    # Sample target params from noise. No classifier-free guidance: the model
    # is conditioned on the image features only.
    latent_size = int(cfg['num_params'])
    z = torch.randn(1, 1, latent_size, device=device)
    model_kwargs = dict(y=img_feat)
    samples = category_diffusion.p_sample_loop(
        model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
    )
    samples = samples[0].squeeze(0).cpu().numpy()

    # map the normalized sample back to the generator's parameter ranges
    params_original = unnormalize_params(samples, generator.params_dict)

    mesh_fpath = _make_temp_path(".glb")
    params_fpath = _make_temp_path(".npy")
    np.save(params_fpath, params_original)
    print(mesh_fpath)
    print(params_fpath)

    # Generate the 3D mesh from the sampled params in a separate process.
    # TODO: this is a hacky way to go through the PCG pipeline, avoiding
    # conflicts with gradio. An argument list avoids shell parsing, and
    # sys.executable runs the same interpreter as this app.
    command = [
        sys.executable, "./scripts/generate.py",
        "--config", f"./configs/demo/{category}_demo.yaml",
        "--output_path", mesh_fpath,
        "--seed", str(seed),
        "--params_path", params_fpath,
    ]
    result = subprocess.run(command)
    if result.returncode != 0:
        # Surface the failure instead of returning an empty .glb file.
        raise gr.Error(f"Mesh generation failed (exit code {result.returncode}).")
    return mesh_fpath, mesh_fpath
import gradio as gr | |
# Markdown rendered above the demo UI. Blank lines separate paragraphs so each
# renders as its own block.
# NOTE(review): the non-ASCII symbols in these strings (e.g. "π€", "βοΈ",
# "π§") look like mis-encoded emoji; restore them from the upstream repo.
_HEADER_ = '''
<h2><b>DI-PCG π€ Gradio Demo</b></h2>

This is the official demo for our technical report <a href='http://arxiv.org/abs/2412.15200' target='_blank'>DI-PCG: Diffusion-based Efficient Inverse Procedural Content Generation for High-quality 3D Asset Creation</a>.

**DI-PCG** is a diffusion model which directly generates a procedural generator's parameters from a single image, resulting in high-quality 3D meshes.

Code: <a href='https://github.com/TencentARC/DI-PCG' target='_blank'>GitHub</a>. Technical report: <a href='http://arxiv.org/abs/2412.15200' target='_blank'>ArXiv</a>.

βοΈβοΈβοΈ**Important Notes:**
- DI-PCG trains a diffusion model for each procedural generator. Currently supported generators are: Chair, Table, Vase, Basket, Flower, Dandelion from <a href="https://github.com/princeton-vl/infinigen">Infinigen</a>.
- The diversity of the generated meshes is strictly bounded by the procedural generators. For out-of-domain shapes, DI-PCG may only provide the closest approximations.
'''

# Markdown rendered below the demo UI. The blank line before "---" makes it a
# horizontal rule instead of turning the line above it into a heading.
# NOTE(review): the LICENSE link target is empty and the contact address was
# lost (scraped as "[email protected]"); fill both in from the upstream repo.
_CITE_ = r"""
If DI-PCG is helpful, please help to β the <a href='https://github.com/TencentARC/DI-PCG' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/DI-PCG?style=social)](https://github.com/TencentARC/DI-PCG)

---

π **Citation**

If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@article{zhao2024dipcg,
  title={DI-PCG: Diffusion-based Efficient Inverse Procedural Content Generation for High-quality 3D Asset Creation},
  author={Zhao, Wang and Cao, Yanpei and Xu, Jiale and Dong, Yuejiang and Shan, Ying},
  journal={arXiv preprint arXiv:2412.15200},
  year={2024}
}
```

π **License**

Please refer to the [LICENSE file]() for details.

π§ **Contact**

If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
"""
def update_examples(category):
    """Return a ``gr.Dataset`` listing the example images for `category`.

    Each sample is a one-element list holding the path of one file in
    examples/<category>, in sorted filename order. This matches the single
    image input the examples feed.
    """
    example_dir = os.path.join("examples", category)
    samples = [[os.path.join(example_dir, img_name)]
               for img_name in sorted(os.listdir(example_dir))]
    return gr.Dataset(samples=samples)
# Category selected when the page first loads. It drives both the radio
# default and the initial example gallery, so the two always agree.
_default_category = generators_choices[0]

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    with gr.Row(variant="panel"):
        # Left column: category, input image, options and examples.
        with gr.Column():
            # select the generator category
            with gr.Row():
                with gr.Group():
                    generator_category = gr.Radio(
                        # Reuse the list the generators/models were loaded from,
                        # so the UI cannot offer a category sample() can't find.
                        choices=list(generators_choices),
                        value=_default_category,
                        label="category",
                    )
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGB",
                    sources='upload',
                    width=256,
                    height=256,
                    type="pil",
                    elem_id="content_image",
                )
                # Shows preprocess() output and is also the image passed to
                # sample(); presumably Gradio converts it to RGBA on input.
                processed_image = gr.Image(
                    label="Processed Image",
                    image_mode="RGBA",
                    width=256,
                    height=256,
                    type="pil",
                    interactive=False,
                )
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=False
                    )
                    sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
            with gr.Row(variant="panel"):
                _example_dir = os.path.join("examples", _default_category)
                examples = gr.Examples(
                    [os.path.join(_example_dir, img_name) for img_name in sorted(os.listdir(_example_dir))],
                    inputs=[input_image],
                    label="Examples",
                    examples_per_page=7,
                )
                # Swap the example gallery whenever another category is picked.
                generator_category.change(update_examples, generator_category, outputs=examples.dataset)
        # Right column: the generated mesh, shown as wireframe and as solid.
        with gr.Column():
            with gr.Row():
                with gr.Tab("Geometry"):
                    output_model_wireframe = gr.Model3D(
                        label="Output Model (Wireframe)",
                        display_mode="wireframe",
                        interactive=False,
                    )
                    output_model_solid = gr.Model3D(
                        label="Output Model (Solid)",
                        interactive=False,
                    )
    gr.Markdown(_CITE_)

    # Validate -> preprocess -> sample. Each `.success()` step runs only if the
    # previous step finished without raising.
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background],
        outputs=[processed_image],
    ).success(
        fn=sample,
        inputs=[processed_image, sample_seed, generator_category],
        outputs=[output_model_wireframe, output_model_solid],
    )

demo.queue()
demo.launch()