|
from typing import List, Union |
|
|
|
import torch |
|
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline |
|
|
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.util.commons import disable_safety_checker, download_image |
|
from internals.util.config import ( |
|
get_hf_cache_dir, |
|
get_hf_token, |
|
get_is_sdxl, |
|
get_inpaint_model_path, |
|
get_model_dir, |
|
) |
|
|
|
|
|
class InPainter(AbstractPipeline):
    """Inpainting pipeline wrapper around the diffusers inpaint pipelines.

    The concrete pipeline class is ``StableDiffusionXLInpaintPipeline`` when
    ``get_is_sdxl()`` is true and ``StableDiffusionInpaintPipeline`` otherwise.
    The pipeline is either built from the components of a base pipeline
    registered through ``init`` or loaded from ``get_inpaint_model_path()``.
    """

    __loaded = False
    # Base pipeline registered via ``init``; ``None`` until one is provided.
    __base = None

    def init(self, pipeline: AbstractPipeline):
        """Register a base pipeline whose components ``load`` may reuse."""
        self.__base = pipeline

    def load(self):
        """Prepare ``self.pipe`` once; later calls return immediately.

        When a base pipeline was registered and the inpaint model path equals
        the model directory, the base pipeline's components are reused through
        ``create``. Otherwise the pipeline is loaded with ``from_pretrained``
        in float16 and moved to CUDA.
        """
        if self.__loaded:
            return

        # ``hasattr(self, "__base")`` can never be true here: the attribute is
        # stored under its name-mangled form, while a string literal is not
        # mangled. Compare against the ``None`` class default instead so the
        # component-sharing path is actually reachable.
        if self.__base is not None and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            self.__loaded = True
            return

        pipeline_cls = self.__pipeline_class()
        self.pipe = pipeline_cls.from_pretrained(
            get_inpaint_model_path(),
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
            use_auth_token=get_hf_token(),
        ).to("cuda")

        disable_safety_checker(self.pipe)
        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build ``self.pipe`` from the components of ``pipeline.pipe``.

        The new pipeline is moved to CUDA, passed to
        ``disable_safety_checker`` and then patched with the memory options
        applied by ``__patch``.
        """
        pipeline_cls = self.__pipeline_class()
        self.pipe = pipeline_cls(**pipeline.pipe.components).to("cuda")

        disable_safety_checker(self.pipe)
        self.__patch()

    @staticmethod
    def __pipeline_class():
        """Return the diffusers inpaint pipeline class for the configured model."""
        if get_is_sdxl():
            return StableDiffusionXLInpaintPipeline
        return StableDiffusionInpaintPipeline

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``."""
        # NOTE(review): block nesting was reconstructed from flattened source.
        # VAE tiling/slicing are assumed to be SDXL-only and xformers attention
        # to apply to both model families -- confirm against the original file.
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Run inpainting and return the ``images`` of the pipeline output.

        Args:
            image_url: Location of the input image, fetched with
                ``download_image``.
            mask_image_url: Location of the mask image, fetched with
                ``download_image``.
            width: Target width; both images are resized to it and it is
                passed to the pipeline.
            height: Target height; both images are resized to it and it is
                passed to the pipeline.
            seed: Value passed to ``torch.manual_seed`` before generation.
            prompt: Prompt or list of prompts forwarded to the pipeline.
            negative_prompt: Negative prompt(s) forwarded to the pipeline.
            num_inference_steps: Step count forwarded to the pipeline.
            **kwargs: Extra pipeline arguments. On a key collision these
                override the arguments assembled here.

        Returns:
            The ``images`` attribute of the pipeline call result.
        """
        # Seeds torch's global RNG, so this affects other torch consumers too.
        torch.manual_seed(seed)

        target_size = (width, height)
        input_img = download_image(image_url).resize(target_size)
        mask_img = download_image(mask_image_url).resize(target_size)

        # Caller-supplied kwargs are unpacked last so they take precedence.
        call_kwargs = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            **kwargs,
        }
        return self.pipe(**call_kwargs).images
|
|