# Fashable-Tryon / app_api.py
# Uploaded by abubakar123456 ("Upload 18 files", commit 5edd223, verified).
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FastAPI service exposing CatVTON virtual try-on at POST /process_images/.

__author__ = 'Ahmad Abdulnasir Shuaib <[email protected]>'
__homepage__ = 'https://ahmadabdulnasir.com.ng'
__copyright__ = 'Copyright (c) 2024, salafi'
__version__ = "0.01t"
"""
import io
import os
import threading
import uuid
from datetime import datetime

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
# from pydantic import BaseSettings
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict
from PIL import Image
import torch
import numpy as np
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download

from model.cloth_masker import AutoMasker
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
class Settings(BaseSettings):
    """Service configuration.

    Each field may be overridden by an environment variable or by an entry in
    the ``.env`` file (loaded via ``model_config`` below).
    """

    # Base inpainting checkpoint, passed to CatVTONPipeline as `base_ckpt`.
    # Alternative upstream value: "runwayml/stable-diffusion-inpainting".
    base_model_path: str = "Abhilashvj/stable-diffusion-inpainting-copy"
    # Hub repo id downloaded with snapshot_download(); holds the attention
    # weights plus the DensePose/SCHP masker checkpoints.
    # Alternative upstream value: "zhengchong/CatVTON".
    resume_path: str = "abubakar123456/CatVTON"
    # Directory under which generated try-on images are saved.
    output_dir: str = "resource/demo/output"
    # Target size (pixels) that person/cloth images are resized to.
    width: int = 768
    height: int = 1024
    # Forwarded to the pipeline as `use_tf32`.
    allow_tf32: bool = True
    # Passed to init_weight_dtype() to pick the model weight dtype.
    mixed_precision: str = "bf16"

    # Pydantic v2 style configuration; replaces the deprecated inner
    # `class Config: env_file = ".env"`.
    model_config = SettingsConfigDict(env_file=".env")
settings = Settings()
app = FastAPI()

# All models are loaded once, at import time, and run on the CPU.
_DEVICE = 'cpu'

# Fetch the CatVTON checkpoint repository from the Hub. It supplies both the
# pipeline's attention weights and the masker's DensePose/SCHP checkpoints.
repo_path = snapshot_download(repo_id=settings.resume_path)

_weight_dtype = init_weight_dtype(settings.mixed_precision)
pipeline = CatVTONPipeline(
    base_ckpt=settings.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=_weight_dtype,
    use_tf32=settings.allow_tf32,
    device=_DEVICE,
)

# Post-processor for masks: binarized, grayscale, not normalized.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

_densepose_dir = os.path.join(repo_path, "DensePose")
_schp_dir = os.path.join(repo_path, "SCHP")
automasker = AutoMasker(
    densepose_ckpt=_densepose_dir,
    schp_ckpt=_schp_dir,
    device=_DEVICE,
)
# Guards the module-level `automasker` and `pipeline`, which every request
# shares. Their thread-safety is not established here, so try-on jobs run one
# at a time, matching the original behaviour, where inference blocked the
# event loop and was therefore implicitly serialized.
_inference_lock = threading.Lock()


async def _read_upload_image(upload: UploadFile) -> Image.Image:
    """Read an uploaded file and decode it into an RGB PIL image.

    Raises:
        HTTPException: 400 if the payload cannot be decoded as an image.
    """
    from fastapi import HTTPException

    data = await upload.read()
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown format) and truncated-image errors
        # are OSError subclasses. Report them as client errors, not as a 500.
        raise HTTPException(
            status_code=400,
            detail=f"Uploaded file '{upload.filename}' is not a valid image.",
        ) from err


def _run_tryon(
    person_img: Image.Image,
    cloth_img: Image.Image,
    cloth_type: str,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
) -> str:
    """Mask the person, run the try-on pipeline, save a PNG and return its path.

    This call is blocking and CPU-heavy; it is meant to run in a worker thread.
    """
    person_img = resize_and_crop(person_img, (settings.width, settings.height))
    cloth_img = resize_and_padding(cloth_img, (settings.width, settings.height))

    with _inference_lock:
        mask = automasker(person_img, cloth_type)['mask']
        mask = mask_processor.blur(mask, blur_factor=9)
        # seed == -1 disables seeding: generator=None is passed to the pipeline.
        generator = torch.Generator(device='cpu').manual_seed(seed) if seed != -1 else None
        result_image = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]

    # Results are bucketed per day. The random suffix keeps two results
    # produced in the same second from overwriting each other.
    now = datetime.now()
    result_save_path = os.path.join(
        settings.output_dir,
        now.strftime("%Y%m%d"),
        f"{now.strftime('%H%M%S')}_{uuid.uuid4().hex[:8]}.png",
    )
    os.makedirs(os.path.dirname(result_save_path), exist_ok=True)
    result_image.save(result_save_path)
    return result_save_path


@app.post("/process_images/")
async def process_images(
    person_image: UploadFile = File(...),
    cloth_image: UploadFile = File(...),
    cloth_type: str = Form(...),
    num_inference_steps: int = Form(50),
    guidance_scale: float = Form(2.5),
    seed: int = Form(42)
):
    """Dress the person in `person_image` with `cloth_image`; respond with a PNG.

    Form fields:
        person_image: photo of the person (any format PIL can decode).
        cloth_image: photo of the garment.
        cloth_type: forwarded to ``automasker`` to choose the masked region.
            NOTE(review): not validated here; accepted values are defined by
            AutoMasker.
        num_inference_steps: pipeline step count; must be >= 1.
        guidance_scale: forwarded to the pipeline.
        seed: RNG seed for reproducibility; -1 disables seeding.

    Raises:
        HTTPException: 400 for undecodable images or num_inference_steps < 1.
    """
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    if num_inference_steps < 1:
        raise HTTPException(
            status_code=400, detail="num_inference_steps must be >= 1."
        )

    person_img = await _read_upload_image(person_image)
    cloth_img = await _read_upload_image(cloth_image)

    # Masking and diffusion are CPU-bound and slow. Running them directly in
    # this coroutine would block the event loop and stall all other requests.
    result_save_path = await run_in_threadpool(
        _run_tryon,
        person_img,
        cloth_img,
        cloth_type,
        num_inference_steps,
        guidance_scale,
        seed,
    )
    return FileResponse(result_save_path, media_type="image/png")
def boot(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. The default, "0.0.0.0", listens on all
            interfaces.
        port: TCP port to listen on.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    boot()