# Source: CM2000112 / carvekit/web/utils/init_utils.py
# Uploaded by jayparmr using huggingface_hub (revision a3d6c18).
from os import getenv
from typing import Union
from loguru import logger
from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig
from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator
def _env_flag(name: str, default) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Integer strings keep their historical meaning (``"0"`` is False, any other
    integer is True). The words ``true/false``, ``yes/no`` and ``on/off``
    (case-insensitive) are accepted as well.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset; converted with
            ``bool(int(default))`` exactly as before.

    Returns:
        The parsed flag.

    Raises:
        ValueError: If the variable is set to an unrecognised value.
    """
    raw = getenv(name)
    if raw is None:
        return bool(int(default))
    value = raw.strip().lower()
    if value in ("true", "yes", "on"):
        return True
    if value in ("false", "no", "off"):
        return False
    try:
        return bool(int(value))
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer or one of "
            f"true/false, yes/no, on/off; got {raw!r}"
        ) from None


def _env_list(name: str, default) -> "list[str]":
    """Read a comma-separated list from the environment variable ``name``.

    Each item is stripped of surrounding whitespace and empty items are
    dropped, so ``"a, b,"`` yields ``["a", "b"]`` instead of
    ``["a", " b", ""]``.

    Args:
        name: Environment variable to read.
        default: Returned unchanged when the variable is unset.
    """
    raw = getenv(name)
    if raw is None:
        return default
    return [item.strip() for item in raw.split(",") if item.strip()]


def init_config() -> WebAPIConfig:
    """Build the Web API configuration from ``CARVEKIT_*`` environment variables.

    Every setting falls back to the matching field of a default-constructed
    ``WebAPIConfig`` when its environment variable is not set.

    Returns:
        The assembled ``WebAPIConfig``.

    Raises:
        ValueError: If a numeric or boolean variable holds an unparsable value.
    """
    default_config = WebAPIConfig()
    default_ml = default_config.ml
    default_auth = default_config.auth

    config = WebAPIConfig(
        port=int(getenv("CARVEKIT_PORT", default_config.port)),
        host=getenv("CARVEKIT_HOST", default_config.host),
        ml=MLConfig(
            segmentation_network=getenv(
                "CARVEKIT_SEGMENTATION_NETWORK", default_ml.segmentation_network
            ),
            preprocessing_method=getenv(
                "CARVEKIT_PREPROCESSING_METHOD", default_ml.preprocessing_method
            ),
            postprocessing_method=getenv(
                "CARVEKIT_POSTPROCESSING_METHOD", default_ml.postprocessing_method
            ),
            device=getenv("CARVEKIT_DEVICE", default_ml.device),
            batch_size_seg=int(
                getenv("CARVEKIT_BATCH_SIZE_SEG", default_ml.batch_size_seg)
            ),
            batch_size_matting=int(
                getenv("CARVEKIT_BATCH_SIZE_MATTING", default_ml.batch_size_matting)
            ),
            seg_mask_size=int(
                getenv("CARVEKIT_SEG_MASK_SIZE", default_ml.seg_mask_size)
            ),
            matting_mask_size=int(
                getenv("CARVEKIT_MATTING_MASK_SIZE", default_ml.matting_mask_size)
            ),
            fp16=_env_flag("CARVEKIT_FP16", default_ml.fp16),
            trimap_prob_threshold=int(
                getenv(
                    "CARVEKIT_TRIMAP_PROB_THRESHOLD",
                    default_ml.trimap_prob_threshold,
                )
            ),
            trimap_dilation=int(
                getenv("CARVEKIT_TRIMAP_DILATION", default_ml.trimap_dilation)
            ),
            trimap_erosion=int(
                getenv("CARVEKIT_TRIMAP_EROSION", default_ml.trimap_erosion)
            ),
        ),
        auth=AuthConfig(
            auth=_env_flag("CARVEKIT_AUTH_ENABLE", default_auth.auth),
            admin_token=getenv("CARVEKIT_ADMIN_TOKEN", default_auth.admin_token),
            allowed_tokens=_env_list(
                "CARVEKIT_ALLOWED_TOKENS", default_auth.allowed_tokens
            ),
        ),
    )
    # NOTE(review): the admin token is written to the log in plaintext, and the
    # debug dump below includes the auth section too; confirm this is intended.
    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
    logger.debug(f"Running Web API with this config: {config.json()}")
    return config
def _init_segmentation(config: "MLConfig"):
    """Instantiate the segmentation network named by ``config.segmentation_network``.

    Unknown names fall back to ``TracerUniversalB7`` (as before) but now log
    a warning so configuration typos are visible.
    """
    networks = {
        "u2net": U2NET,
        "deeplabv3": DeepLabV3,
        "basnet": BASNET,
        "tracer_b7": TracerUniversalB7,
    }
    network_cls = networks.get(config.segmentation_network)
    if network_cls is None:
        logger.warning(
            f"Unknown segmentation network {config.segmentation_network!r}; "
            "falling back to 'tracer_b7'"
        )
        network_cls = TracerUniversalB7
    return network_cls(
        device=config.device,
        batch_size=config.batch_size_seg,
        input_image_size=config.seg_mask_size,
        fp16=config.fp16,
    )


def _init_preprocessing(config: "MLConfig"):
    """Create the pre-processing pipeline, or return ``None`` when disabled."""
    method = config.preprocessing_method
    if method == "stub":
        return PreprocessingStub()
    if method != "none":
        logger.warning(
            f"Unknown preprocessing method {method!r}; pre-processing is disabled"
        )
    return None


def _init_postprocessing(config: "MLConfig"):
    """Create the FBA matting post-processing pipeline, or ``None`` when disabled."""
    method = config.postprocessing_method
    if method == "fba":
        fba = FBAMatting(
            device=config.device,
            batch_size=config.batch_size_matting,
            input_tensor_size=config.matting_mask_size,
            fp16=config.fp16,
        )
        # Config field names differ from TrimapGenerator's parameter names:
        # trimap_dilation -> kernel_size, trimap_erosion -> erosion_iters.
        trimap_generator = TrimapGenerator(
            prob_threshold=config.trimap_prob_threshold,
            kernel_size=config.trimap_dilation,
            erosion_iters=config.trimap_erosion,
        )
        return MattingMethod(
            device=config.device, matting_module=fba, trimap_generator=trimap_generator
        )
    if method != "none":
        logger.warning(
            f"Unknown postprocessing method {method!r}; post-processing is disabled"
        )
    return None


def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
    """Build the background-removal ``Interface`` described by a config.

    Args:
        config: Either a full ``WebAPIConfig`` (its ``.ml`` section is used) or
            an ``MLConfig`` directly.

    Returns:
        An ``Interface`` wired with the selected segmentation network and the
        optional pre-/post-processing pipelines (``None`` when disabled).
    """
    if isinstance(config, WebAPIConfig):
        config = config.ml
    return Interface(
        pre_pipe=_init_preprocessing(config),
        post_pipe=_init_postprocessing(config),
        seg_pipe=_init_segmentation(config),
        device=config.device,
    )