|
import numpy as np
|
|
import PIL.Image
|
|
from controlnet_aux.util import HWC3
|
|
from transformers import pipeline
|
|
|
|
from cv_utils import resize_image
|
|
|
|
|
|
class DepthEstimator:
|
|
def __init__(self):
|
|
self.model = pipeline("depth-estimation")
|
|
|
|
def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
|
|
detect_resolution = kwargs.pop("detect_resolution", 512)
|
|
image_resolution = kwargs.pop("image_resolution", 512)
|
|
image = np.array(image)
|
|
image = HWC3(image)
|
|
image = resize_image(image, resolution=detect_resolution)
|
|
image = PIL.Image.fromarray(image)
|
|
image = self.model(image)
|
|
image = image["depth"]
|
|
image = np.array(image)
|
|
image = HWC3(image)
|
|
image = resize_image(image, resolution=image_resolution)
|
|
return PIL.Image.fromarray(image)
|
|
|