hysts HF staff committed on
Commit
c4bc1d1
1 Parent(s): df24b2d
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "CelebAMask-HQ"]
2
+ path = CelebAMask-HQ
3
+ url = https://github.com/switchablenorms/CelebAMask-HQ
CelebAMask-HQ ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 3f6c45df5e67130568f9459f24f8d6a5ff836d30
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import sys
10
+ from typing import Callable
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import PIL.Image
15
+ import torch
16
+ import torch.nn as nn
17
+ import torchvision.transforms as T
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ sys.path.insert(0, 'CelebAMask-HQ/face_parsing')
21
+
22
+ from unet import unet
23
+ from utils import generate_label
24
+
25
+ ORIGINAL_REPO_URL = 'https://github.com/switchablenorms/CelebAMask-HQ'
26
+ TITLE = 'CelebAMask-HQ Face Parsing'
27
+ DESCRIPTION = f'This is a demo for the model provided in {ORIGINAL_REPO_URL}.'
28
+ ARTICLE = None
29
+
30
+ TOKEN = os.environ['TOKEN']
31
+
32
+
33
def parse_args() -> argparse.Namespace:
    """Build the CLI parser for the demo and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    # Two plain on/off switches, off by default.
    for switch in ('--live', '--share'):
        parser.add_argument(switch, action='store_true')
    parser.add_argument('--port', type=int)
    # Queueing is enabled by default; this flag turns it off.
    parser.add_argument('--disable-queue', dest='enable_queue', action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()
46
+
47
+
48
@torch.inference_mode()
def predict(image: PIL.Image.Image, model: nn.Module, transform: Callable,
            device: torch.device) -> tuple[np.ndarray, np.ndarray]:
    """Run face parsing on *image*.

    Args:
        image: Input photo; it is resized to 512x512 internally.
        model: The face-parsing network (a ``unet`` in eval mode).
        transform: Preprocessing pipeline applied to the image.
        device: Device the model lives on.

    Returns:
        A pair ``(labels, blended)``: the 512x512 uint8 label
        visualization, and a 50/50 blend of the resized input with
        that label map.
    """
    # NOTE: the original annotation claimed `-> np.ndarray`, but two
    # arrays are returned (one per Gradio output); fixed here.
    data = transform(image)
    data = data.unsqueeze(0).to(device)  # add batch dimension
    out = model(data)
    # generate_label turns the raw network output into a color label
    # map at 512x512 (project helper from CelebAMask-HQ/face_parsing).
    out = generate_label(out, 512)
    out = out[0].cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)

    # Overlay the label map on the resized input at 50% opacity.
    res = np.asarray(image.resize(
        (512, 512))).astype(float) * 0.5 + out.astype(float) * 0.5
    res = np.clip(np.round(res), 0, 255).astype(np.uint8)
    return out, res
62
+
63
+
64
def load_model(device: torch.device) -> nn.Module:
    """Fetch the pretrained face-parsing checkpoint and build the model.

    The weights are downloaded from the Hugging Face Hub using the
    module-level TOKEN for authentication, loaded into a fresh ``unet``,
    and the network is returned in eval mode on *device*.
    """
    weight_path = hf_hub_download('hysts/CelebAMask-HQ-Face-Parsing',
                                  'models/model.pth',
                                  use_auth_token=TOKEN)
    model = unet()
    model.load_state_dict(torch.load(weight_path, map_location='cpu'))
    # eval() and to() both return the module itself, so chaining is safe.
    return model.eval().to(device)
74
+
75
+
76
def main():
    """Entry point: assemble the Gradio face-parsing demo and launch it."""
    gr.close_all()

    args = parse_args()
    device = torch.device(args.device)

    # Model plus preprocessing; NEAREST resampling keeps label edges crisp.
    model = load_model(device)
    transform = T.Compose([
        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Bind the fixed arguments so Gradio only has to supply the image;
    # update_wrapper keeps predict's metadata on the partial.
    fn = functools.update_wrapper(
        functools.partial(predict,
                          model=model,
                          transform=transform,
                          device=device),
        predict)

    examples = [[p.as_posix()]
                for p in sorted(pathlib.Path('images').glob('*.jpg'))]

    demo = gr.Interface(
        fn,
        gr.inputs.Image(type='pil', label='Input'),
        [
            gr.outputs.Image(type='numpy', label='Predicted Labels'),
            gr.outputs.Image(type='numpy', label='Masked'),
        ],
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
images/95UF6LXe-Lo.jpg ADDED

Git LFS Details

  • SHA256: 9ba751a6519822fa683e062ee3a383e748f15b41d4ca87d14c4fa73f9beed845
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
images/ILip77SbmOE.jpg ADDED

Git LFS Details

  • SHA256: 3eed82923bc76a90f067415f148d56239fdfa4a1aca9eef1d459bc6050c9dde8
  • Pointer size: 131 Bytes
  • Size of remote file: 939 kB
images/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ These images are freely-usable ones from [Unsplash](https://unsplash.com/).
2
+
3
+ - https://unsplash.com/photos/rDEOVtE7vOs
4
+ - https://unsplash.com/photos/et_78QkMMQs
5
+ - https://unsplash.com/photos/ILip77SbmOE
6
+ - https://unsplash.com/photos/95UF6LXe-Lo
7
+
images/et_78QkMMQs.jpg ADDED

Git LFS Details

  • SHA256: c63a2e9de5eda3cb28012cfc8e4ba9384daeda8cca7a8989ad90b21a1293cc6f
  • Pointer size: 131 Bytes
  • Size of remote file: 371 kB
images/rDEOVtE7vOs.jpg ADDED

Git LFS Details

  • SHA256: b136bf195fef5599f277a563f0eef79af5301d9352d4ebf82bd7a0a061b7bdc0
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ opencv-python-headless==4.5.5.62
3
+ Pillow==9.0.1
4
+ torch==1.11.0
5
+ torchvision==0.12.0