File size: 3,531 Bytes
5edd223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
__author__ = 'Ahmad Abdulnasir Shuaib <[email protected]>'
__homepage__ = https://ahmadabdulnasir.com.ng
__copyright__ = 'Copyright (c) 2024, salafi'
__version__ = "0.01t"
"""
import io
import os
import threading
import uuid
from datetime import datetime

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from fastapi import FastAPI, File, UploadFile, Form
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from huggingface_hub import snapshot_download
from PIL import Image
# from pydantic import BaseSettings
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict

from model.cloth_masker import AutoMasker
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding

class Settings(BaseSettings):
    """Service configuration; each field can be overridden via env vars or ``.env``."""

    # Base inpainting checkpoint passed to CatVTONPipeline as ``base_ckpt``
    # (alternative kept from the original comment: "runwayml/stable-diffusion-inpainting").
    base_model_path: str = "Abhilashvj/stable-diffusion-inpainting-copy"
    # Hugging Face repo fetched with snapshot_download at startup; its local copy
    # supplies the attention checkpoint and the DensePose / SCHP masker weights
    # (alternative kept from the original comment: "zhengchong/CatVTON").
    resume_path: str = "abubakar123456/CatVTON"
    # Root directory under which result PNGs are written.
    output_dir: str = "resource/demo/output"
    # Target (width, height) handed to resize_and_crop / resize_and_padding.
    width: int = 768
    height: int = 1024
    # Forwarded to CatVTONPipeline as ``use_tf32``.
    allow_tf32: bool = True
    # Forwarded to init_weight_dtype() to select the pipeline weight dtype.
    mixed_precision: str = "bf16"

    # pydantic-settings v2 style config (the inner ``class Config`` form is
    # deprecated). BaseSettings defaults to extra="forbid", which makes any
    # unrelated key in .env (e.g. an API token for another tool) abort startup
    # with a validation error; "ignore" skips such keys instead.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

# Configuration resolved once at import time (environment variables / .env).
settings = Settings()

app = FastAPI()

# Single source for the device both model components are placed on, so the
# pipeline and the masker cannot drift apart.
_MODEL_DEVICE = 'cpu'

# Fetch the checkpoint repository named by settings.resume_path; the returned
# local directory is reused below for the attention weights and masker weights.
repo_path = snapshot_download(repo_id=settings.resume_path)

_weight_dtype = init_weight_dtype(settings.mixed_precision)
pipeline = CatVTONPipeline(
    base_ckpt=settings.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=_weight_dtype,
    use_tf32=settings.allow_tf32,
    device=_MODEL_DEVICE,
)

# Mask post-processor; the request handler uses its blur() on generated masks.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

# Masker checkpoints live in sub-directories of the downloaded repository.
_densepose_dir = os.path.join(repo_path, "DensePose")
_schp_dir = os.path.join(repo_path, "SCHP")
automasker = AutoMasker(
    densepose_ckpt=_densepose_dir,
    schp_ckpt=_schp_dir,
    device=_MODEL_DEVICE,
)

# Serialises calls into the shared module-level `pipeline` / `automasker`.
# Inference now runs in worker threads, and nothing here establishes that those
# objects tolerate concurrent calls (NOTE(review): confirm before relaxing).
_INFERENCE_LOCK = threading.Lock()


def _read_rgb_image(data: bytes, field_name: str) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Args:
        data: Raw bytes of the uploaded file.
        field_name: Form field name, used in the error message.

    Returns:
        The decoded image converted to RGB.

    Raises:
        HTTPException: 400 if the bytes are not a decodable image or exceed
            Pillow's decompression-bomb limit.
    """
    try:
        with Image.open(io.BytesIO(data)) as img:
            # convert() returns a converted copy, so closing `img` afterwards is safe.
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-data errors are OSError subclasses.
        raise HTTPException(
            status_code=400, detail=f"{field_name} is not a valid image"
        ) from err


def _run_tryon(
    person_img: Image.Image,
    cloth_img: Image.Image,
    cloth_type: str,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
) -> Image.Image:
    """Resize the inputs, build the cloth mask and run the pipeline (blocking)."""
    person_img = resize_and_crop(person_img, (settings.width, settings.height))
    cloth_img = resize_and_padding(cloth_img, (settings.width, settings.height))

    with _INFERENCE_LOCK:
        mask = automasker(person_img, cloth_type)['mask']
        mask = mask_processor.blur(mask, blur_factor=9)

        # seed == -1 passes generator=None (presumably unseeded sampling);
        # any other value gives a reproducible CPU generator.
        generator = (
            torch.Generator(device='cpu').manual_seed(seed) if seed != -1 else None
        )

        return pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]


def _save_result(image: Image.Image) -> str:
    """Save `image` as a PNG under <output_dir>/<YYYYMMDD>/ and return its path."""
    now = datetime.now()
    day_dir = os.path.join(settings.output_dir, now.strftime("%Y%m%d"))
    os.makedirs(day_dir, exist_ok=True)
    # A second-resolution timestamp alone lets concurrent requests overwrite
    # (and then serve) each other's results; the random suffix prevents that.
    filename = f"{now.strftime('%H%M%S')}_{uuid.uuid4().hex}.png"
    path = os.path.join(day_dir, filename)
    image.save(path)
    return path


@app.post("/process_images/")
async def process_images(
    person_image: UploadFile = File(...),
    cloth_image: UploadFile = File(...),
    cloth_type: str = Form(...),
    num_inference_steps: int = Form(50),
    guidance_scale: float = Form(2.5),
    seed: int = Form(42)
):
    """Render `cloth_image` onto `person_image` with the CatVTON pipeline.

    Args:
        person_image: Uploaded photo of the person.
        cloth_image: Uploaded garment image.
        cloth_type: Passed to AutoMasker to choose the masked region
            (accepted values are defined by AutoMasker; TODO document them here).
        num_inference_steps: Pipeline denoising steps; must be >= 1.
        guidance_scale: Forwarded to the pipeline.
        seed: Generator seed; -1 disables seeding.

    Returns:
        FileResponse streaming the saved result PNG.

    Raises:
        HTTPException: 400 for an undecodable upload or num_inference_steps < 1.
    """
    if num_inference_steps < 1:
        raise HTTPException(
            status_code=400, detail="num_inference_steps must be >= 1"
        )

    person_img = _read_rgb_image(await person_image.read(), "person_image")
    cloth_img = _read_rgb_image(await cloth_image.read(), "cloth_image")

    # Masking, inference and PNG encoding are synchronous and slow; running them
    # in the threadpool keeps the event loop responsive to other requests.
    result_image = await run_in_threadpool(
        _run_tryon,
        person_img,
        cloth_img,
        cloth_type,
        num_inference_steps,
        guidance_scale,
        seed,
    )
    result_save_path = await run_in_threadpool(_save_result, result_image)

    return FileResponse(result_save_path, media_type="image/png")

    
def boot(host: str = "0.0.0.0", port: int = 8000):
    """Serve `app` with uvicorn.

    Args:
        host: Interface to bind. The default "0.0.0.0" listens on all
            interfaces; pass "127.0.0.1" to keep the service local.
        port: TCP port to listen on.
    """
    # Imported here so that merely importing this module does not require uvicorn.
    import uvicorn
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    boot()