|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import random |
|
|
|
import numpy as np |
|
import torch |
|
from diffusers import DiffusionPipeline |
|
from diffusers import EulerAncestralDiscreteScheduler |
|
|
|
|
|
class Multiview_Diffusion_Net:
    """Callable wrapper around a diffusers pipeline for multiview generation.

    The pipeline is loaded from ``config.multiview_ckpt_path`` together with
    the custom pipeline code found in the ``hunyuanpaint`` directory next to
    this file's parent directory. Calling an instance feeds a reference image
    plus per-view control images into that pipeline and returns the images it
    produces.
    """

    def __init__(self, config) -> None:
        """Load the pipeline and move it to the configured device.

        Args:
            config: Object exposing ``device`` (target device for the
                pipeline) and ``multiview_ckpt_path`` (checkpoint passed to
                ``DiffusionPipeline.from_pretrained``).
        """
        self.device = config.device
        # Side length, in pixels, every input image is resized to and that is
        # passed to the pipeline as both width and height.
        self.view_size = 512
        multiview_ckpt_path = config.multiview_ckpt_path

        # The custom pipeline lives in ``../hunyuanpaint`` relative to this file.
        current_file_path = os.path.abspath(__file__)
        custom_pipeline_path = os.path.join(os.path.dirname(current_file_path), '..', 'hunyuanpaint')

        pipeline = DiffusionPipeline.from_pretrained(
            multiview_ckpt_path,
            custom_pipeline=custom_pipeline_path,
            torch_dtype=torch.float16,
        )

        # Replace the checkpoint's scheduler with an Euler-ancestral one built
        # from the same config, overriding only the timestep spacing.
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config,
            timestep_spacing='trailing',
        )

        pipeline.set_progress_bar_config(disable=True)
        self.pipeline = pipeline.to(self.device)

    def seed_everything(self, seed):
        """Seed the global RNGs of ``random``, NumPy and PyTorch.

        Also exports the seed through the ``PL_GLOBAL_SEED`` environment
        variable.

        Args:
            seed: Integer seed applied to every generator.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        os.environ["PL_GLOBAL_SEED"] = str(seed)

    @staticmethod
    def _prepare_control_image(image, size):
        """Return ``image`` resized to ``size``; binarize it if its mode is 'L'."""
        resized = image.resize(size)
        if resized.mode == 'L':
            # Single-channel input: values above 1 become 255, the rest 0,
            # and the result is stored as a mode '1' image.
            resized = resized.point(lambda x: 255 if x > 1 else 0, mode='1')
        return resized

    def __call__(self, input_image, control_images, camera_info, *, seed=0, num_inference_steps=30):
        """Run the pipeline and return its generated images.

        Args:
            input_image: Reference image; must provide ``resize((w, h))``
                (presumably a PIL image -- confirm against callers).
            control_images: Sequence of control images. The first half is
                passed to the pipeline as normal images and the second half
                as position images. With an odd count the last image is
                ignored. The sequence itself is NOT modified; resized copies
                are used instead.
            camera_info: Per-view camera information, forwarded to the
                pipeline wrapped in a one-element list.
            seed: Seed for the global RNGs and the pipeline's generator.
                Defaults to 0.
            num_inference_steps: Number of denoising steps requested from
                the pipeline. Defaults to 30.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.seed_everything(seed)

        target_size = (self.view_size, self.view_size)
        input_image = input_image.resize(target_size)

        # Build a fresh list rather than overwriting the caller's entries.
        prepared = [self._prepare_control_image(image, target_size) for image in control_images]

        # First half: normal images; second half: position images. Each is
        # wrapped in an outer one-element list, as is camera_info below.
        num_view = len(prepared) // 2
        normal_image = [prepared[:num_view]]
        position_image = [prepared[num_view:2 * num_view]]

        generator = torch.Generator(device=self.pipeline.device).manual_seed(seed)

        mvd_image = self.pipeline(
            input_image,
            num_inference_steps=num_inference_steps,
            generator=generator,
            width=self.view_size,
            height=self.view_size,
            num_in_batch=num_view,
            camera_info_gen=[camera_info],
            # NOTE(review): the reference camera info is fixed to [[0]]; its
            # meaning is defined by the custom pipeline -- confirm there.
            camera_info_ref=[[0]],
            normal_imgs=normal_image,
            position_imgs=position_image,
        ).images
        return mvd_image
|
|