|
from __future__ import annotations |
|
|
|
from functools import cached_property |
|
|
|
from diffusers import ( |
|
StableDiffusionControlNetInpaintPipeline, |
|
StableDiffusionControlNetPipeline, |
|
StableDiffusionInpaintPipeline, |
|
StableDiffusionPipeline, |
|
) |
|
|
|
from asdff.base import AdPipelineBase |
|
|
|
|
|
class AdPipeline(AdPipelineBase, StableDiffusionPipeline):
    """Stable Diffusion pipeline combined with ``AdPipelineBase``.

    This class supplies two members: ``txt2img_class``, the plain
    ``StableDiffusionPipeline`` class, and ``inpaint_pipeline``, an
    inpainting pipeline built from this pipeline's own components.
    How ``AdPipelineBase`` consumes them is defined in ``asdff.base``
    (not visible here).
    """

    @property
    def txt2img_class(self):
        """Return the diffusers pipeline class used for text-to-image."""
        return StableDiffusionPipeline

    @cached_property
    def inpaint_pipeline(self):
        """Return a ``StableDiffusionInpaintPipeline`` sharing this pipeline's parts.

        The VAE, text encoder, tokenizer, UNet, scheduler, safety checker
        and feature extractor objects are passed by reference. The
        ``requires_safety_checker`` flag is copied from ``self.config``.
        The result is built on first access and then cached on the
        instance by ``functools.cached_property``.

        NOTE(review): because of that caching, a component reassigned on
        ``self`` after the first access (e.g. ``self.scheduler = ...``) is
        not picked up by the already-built inpaint pipeline.
        """
        component_names = (
            "vae",
            "text_encoder",
            "tokenizer",
            "unet",
            "scheduler",
            "safety_checker",
            "feature_extractor",
        )
        # Look the components up in the same order the keyword arguments
        # were originally listed.
        components = {name: getattr(self, name) for name in component_names}
        # The safety-checker requirement is a config value, not a component.
        return StableDiffusionInpaintPipeline(
            **components,
            requires_safety_checker=self.config.requires_safety_checker,
        )
|
|
|
|
|
class AdCnPipeline(AdPipelineBase, StableDiffusionControlNetPipeline):
    """ControlNet Stable Diffusion pipeline combined with ``AdPipelineBase``.

    This class supplies ``txt2img_class``, the
    ``StableDiffusionControlNetPipeline`` class, and ``inpaint_pipeline``,
    a ControlNet inpainting pipeline built from this pipeline's own
    components. How ``AdPipelineBase`` consumes them is defined in
    ``asdff.base`` (not visible here).
    """

    @property
    def txt2img_class(self):
        """Return the diffusers pipeline class used for text-to-image."""
        return StableDiffusionControlNetPipeline

    @cached_property
    def inpaint_pipeline(self):
        """Return a ``StableDiffusionControlNetInpaintPipeline`` sharing this pipeline's parts.

        The VAE, text encoder, tokenizer, UNet, ControlNet, scheduler,
        safety checker and feature extractor objects are passed by
        reference. The ``requires_safety_checker`` flag is copied from
        ``self.config``. The result is built on first access and then
        cached on the instance by ``functools.cached_property``.

        NOTE(review): because of that caching, a component reassigned on
        ``self`` after the first access (e.g. ``self.scheduler = ...``) is
        not picked up by the already-built inpaint pipeline.
        """
        shared = dict(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            controlnet=self.controlnet,
            scheduler=self.scheduler,
            safety_checker=self.safety_checker,
            feature_extractor=self.feature_extractor,
        )
        # Read the config flag after the components, matching the
        # original access order.
        needs_checker = self.config.requires_safety_checker
        return StableDiffusionControlNetInpaintPipeline(
            requires_safety_checker=needs_checker,
            **shared,
        )
|
|