|
|
|
|
|
""" |
|
__author__ = 'Ahmad Abdulnasir Shuaib <[email protected]>' |
|
__homepage__ = 'https://ahmadabdulnasir.com.ng'
|
__copyright__ = 'Copyright (c) 2024, salafi' |
|
__version__ = "0.01t" |
|
""" |
|
import io
import os
import threading
import uuid
from datetime import datetime

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from huggingface_hub import snapshot_download
from PIL import Image
from pydantic_settings import BaseSettings, SettingsConfigDict

from model.cloth_masker import AutoMasker
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
|
|
class Settings(BaseSettings):
    """Runtime configuration for the try-on service.

    Each field can be overridden by an environment variable matching the
    field name or by an entry in a ``.env`` file (pydantic-settings).
    """

    # Base inpainting model passed to CatVTONPipeline as ``base_ckpt``.
    base_model_path: str = "Abhilashvj/stable-diffusion-inpainting-copy"
    # Hugging Face repo id fetched with snapshot_download; the downloaded
    # directory supplies the attention checkpoint and the DensePose/SCHP
    # weights used by the automasker.
    resume_path: str = "abubakar123456/CatVTON"
    # Directory under which result PNGs are written (one sub-folder per day).
    output_dir: str = "resource/demo/output"
    # Target size used when resizing the person and cloth images.
    width: int = 768
    height: int = 1024
    # Forwarded to CatVTONPipeline as ``use_tf32``.
    allow_tf32: bool = True
    # Passed to ``init_weight_dtype`` to choose the pipeline weight dtype.
    mixed_precision: str = "bf16"

    # pydantic v2 style configuration; replaces the deprecated nested
    # ``class Config`` (which triggers a deprecation warning).
    model_config = SettingsConfigDict(env_file=".env")
|
|
|
# Configuration is read once at import time (defaults, overridable via the
# environment or .env through pydantic-settings).
settings = Settings()

# ASGI application; routes are registered on it below.
app = FastAPI()

# Fetch the CatVTON checkpoint repository and get its local directory path.
# This runs at import time, so importing the module may perform network I/O
# and blocks until the snapshot is available.
repo_path = snapshot_download(repo_id=settings.resume_path)

# Try-on diffusion pipeline: base inpainting model plus the "mix" attention
# checkpoint from the downloaded repo, pinned to CPU.
# NOTE(review): TF32 is a CUDA matmul mode; with device='cpu' the use_tf32
# flag likely has no effect — confirm.
pipeline = CatVTONPipeline(
    base_ckpt=settings.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(settings.mixed_precision),
    use_tf32=settings.allow_tf32,
    device='cpu'
)

# Mask post-processor; in this file it is only used for blur() in the handler.
# NOTE(review): vae_scale_factor=8 presumably matches the base model's VAE — confirm.
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
# Clothing-region mask generator built from the DensePose and SCHP checkpoints
# shipped in the same downloaded repository, pinned to CPU.
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cpu'
)
|
|
|
# Serializes use of the shared pipeline/automasker. Inference runs in a worker
# thread so it no longer blocks the event loop; the lock keeps the original
# one-request-at-a-time inference instead of running several CPU diffusion
# jobs on the same models concurrently.
_inference_lock = threading.Lock()


def _read_rgb_image(data: bytes, field_name: str) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Args:
        data: Raw bytes of the uploaded file.
        field_name: Form field name, used in the error message.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image, so a
            bad upload is reported as a client error rather than a 500.
    """
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-image errors are OSError subclasses.
        raise HTTPException(status_code=400, detail=f"{field_name} is not a valid image") from err


def _run_tryon(person_img, cloth_img, cloth_type, num_inference_steps, guidance_scale, seed):
    """Build the clothing mask and run the try-on pipeline; return the first output image.

    Blocking and CPU-heavy: call it through ``run_in_threadpool`` from async code.
    """
    with _inference_lock:
        mask = automasker(person_img, cloth_type)['mask']
        mask = mask_processor.blur(mask, blur_factor=9)
        # seed == -1 disables seeding: generator=None is passed to the pipeline.
        generator = torch.Generator(device='cpu').manual_seed(seed) if seed != -1 else None
        return pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]


def _save_result(image: Image.Image) -> str:
    """Write *image* as PNG under ``settings.output_dir/YYYYMMDD/`` and return its path.

    The file name is ``HHMMSS_<uuid hex>.png``. The timestamp keeps files
    ordered, and the random suffix stops two requests that finish in the same
    second from overwriting each other's result.
    """
    now = datetime.now()
    day_dir = os.path.join(settings.output_dir, now.strftime("%Y%m%d"))
    os.makedirs(day_dir, exist_ok=True)
    path = os.path.join(day_dir, f"{now.strftime('%H%M%S')}_{uuid.uuid4().hex}.png")
    image.save(path)
    return path


@app.post("/process_images/")
async def process_images(
    person_image: UploadFile = File(...),
    cloth_image: UploadFile = File(...),
    cloth_type: str = Form(...),
    num_inference_steps: int = Form(50),
    guidance_scale: float = Form(2.5),
    seed: int = Form(42)
):
    """Dress the person in ``person_image`` with the garment in ``cloth_image``.

    Both uploads are decoded to RGB and resized to ``settings.width`` x
    ``settings.height``. A mask for ``cloth_type`` is built and the try-on
    pipeline runs. The result is saved as a PNG and returned.

    Args:
        person_image: Photo of the person.
        cloth_image: Photo of the garment.
        cloth_type: Garment category forwarded to the automasker.
        num_inference_steps: Diffusion steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: RNG seed; ``-1`` disables seeding.

    Returns:
        FileResponse streaming the saved PNG.

    Raises:
        HTTPException: 400 if either upload is not a decodable image.
    """
    person_img = _read_rgb_image(await person_image.read(), "person_image")
    cloth_img = _read_rgb_image(await cloth_image.read(), "cloth_image")

    target_size = (settings.width, settings.height)
    person_img = resize_and_crop(person_img, target_size)
    cloth_img = resize_and_padding(cloth_img, target_size)

    # Heavy, blocking work goes to the threadpool so other requests (and
    # health checks) are still served while inference runs.
    result_image = await run_in_threadpool(
        _run_tryon, person_img, cloth_img, cloth_type,
        num_inference_steps, guidance_scale, seed,
    )
    result_save_path = await run_in_threadpool(_save_result, result_image)

    return FileResponse(result_save_path, media_type="image/png")
|
|
|
|
|
def boot(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. The default ``"0.0.0.0"`` listens on all interfaces.
        port: TCP port to listen on.
    """
    # Deferred import: uvicorn is only needed when serving directly.
    import uvicorn

    uvicorn.run(app, host=host, port=port)
|
|
|
# Start the server (host 0.0.0.0, port 8000) when this file is run as a script.
if __name__ == "__main__":
    boot()
|
|