# File size: 2,338 Bytes
# 5edd223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import os

from huggingface_hub import snapshot_download
from tqdm import tqdm

from model.cloth_masker import AutoMasker


def parse_args(argv=None):
    """Parse command-line options for the agnostic-mask preprocessing script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for programmatic use and tests).

    Returns:
        argparse.Namespace with ``data_root_path``, ``repo_path`` and
        ``local_rank``. If the ``LOCAL_RANK`` environment variable is set to
        something other than -1, it overrides ``--local_rank``.
    """
    parser = argparse.ArgumentParser(description="Simple example of Preprocess Agnostic Mask")
    parser.add_argument(
        "--data_root_path",
        type=str,
        required=True,
        help="Path to the dataset to evaluate.",
    )
    parser.add_argument(
        "--repo_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path or repo name of CatVTON. "
        ),
    )
    # Must be declared: the LOCAL_RANK override below reads and assigns
    # ``args.local_rank``; without this option that raised AttributeError.
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed runs; overridden by the LOCAL_RANK env var.",
    )
    args = parser.parse_args(argv)

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1:
        args.local_rank = env_local_rank

    return args

def _read_person_images(pair_txt):
    """Return the first whitespace-separated field of every non-blank line of ``pair_txt``."""
    person_imgs = []
    with open(pair_txt, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()  # tolerant of tabs / repeated spaces
            if fields:  # skip blank lines (e.g. a trailing newline)
                person_imgs.append(fields[0])
    return person_imgs


def _mask_filename(person_img):
    """Return the PNG filename used to store the mask for ``person_img``."""
    # splitext only swaps the final extension, unlike str.replace('.jpg', ...)
    # which rewrote every occurrence and left non-.jpg names unchanged.
    return os.path.splitext(person_img)[0] + ".png"


def main(args):
    """Generate agnostic masks for the test pairs of each dataset subset.

    For each subset folder (``upper_body``, ``lower_body``, ``dresses``), the
    person images listed in ``test_pairs_paired.txt`` are passed through
    ``AutoMasker``. The resulting masks are saved as PNGs under
    ``<subset>/agnostic_masks``. Masks that already exist are skipped, so an
    interrupted run can be resumed.

    Args:
        args: Namespace with ``data_root_path`` and ``repo_path``.
            ``repo_path`` is rebound to the local checkpoint directory.

    Raises:
        FileNotFoundError: if a subset folder or its pair file is missing.
    """
    # ``--repo_path`` may be a local checkout or a Hub repo ID; only
    # download when it is not an existing directory.
    if not os.path.isdir(args.repo_path):
        args.repo_path = snapshot_download(repo_id=args.repo_path)

    automasker = AutoMasker(
        densepose_ckpt=os.path.join(args.repo_path, "DensePose"),
        schp_ckpt=os.path.join(args.repo_path, "SCHP"),
        device='cpu',  # the original offered device='cuda' as a commented-out alternative
    )

    # Subset folder -> cloth type argument passed to AutoMasker.
    cloth_types = {'upper_body': 'upper', 'lower_body': 'lower', 'dresses': 'overall'}
    for sub_folder, cloth_type in cloth_types.items():
        sub_dir = os.path.join(args.data_root_path, sub_folder)
        # Explicit raises instead of assert: asserts vanish under ``python -O``.
        if not os.path.isdir(sub_dir):
            raise FileNotFoundError(f"Folder {sub_folder} does not exist.")
        pair_txt = os.path.join(sub_dir, 'test_pairs_paired.txt')
        if not os.path.isfile(pair_txt):
            raise FileNotFoundError(f"File {pair_txt} does not exist.")

        person_imgs = _read_person_images(pair_txt)
        output_dir = os.path.join(sub_dir, 'agnostic_masks')
        os.makedirs(output_dir, exist_ok=True)

        for person_img in tqdm(person_imgs, desc=f"Processing {sub_folder}"):
            mask_path = os.path.join(output_dir, _mask_filename(person_img))
            if os.path.exists(mask_path):
                continue
            mask = automasker(
                os.path.join(sub_dir, 'images', person_img),
                cloth_type
            )['mask']
            mask.save(mask_path)

if __name__ == "__main__":
    # Script entry point: parse the CLI and hand the options straight to main().
    main(parse_args())