hysts HF staff commited on
Commit
fed7f36
Β·
1 Parent(s): 009d114
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^stylegan3
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐨
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.0.11
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -3,149 +3,125 @@
3
  from __future__ import annotations
4
 
5
  import argparse
6
- import functools
7
- import os
8
- import pickle
9
- import sys
10
 
11
  import gradio as gr
12
  import numpy as np
13
- import torch
14
- import torch.nn as nn
15
- from huggingface_hub import hf_hub_download
16
-
17
- sys.path.insert(0, 'stylegan3')
18
-
19
- TITLE = 'Self-Distilled StyleGAN'
20
- DESCRIPTION = '''This is an unofficial demo for models provided in https://github.com/self-distilled-stylegan/self-distilled-internet-photos.
21
-
22
- Expected execution time on Hugging Face Spaces: 2s
23
- '''
24
- SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples'
25
- ARTICLE = f'''## Generated images
26
- - truncation: 0.7
27
- ### Dogs
28
- - size: 1024x1024
29
- - seed: 0-99
30
- ![Dogs]({SAMPLE_IMAGE_DIR}/dogs.jpg)
31
- ### Elephants
32
- - size: 512x512
33
- - seed: 0-99
34
- ![Elephants]({SAMPLE_IMAGE_DIR}/elephants.jpg)
35
- ### Horses
36
- - size: 256x256
37
- - seed: 0-99
38
- ![Horses]({SAMPLE_IMAGE_DIR}/horses.jpg)
39
- ### Bicycles
40
- - size: 256x256
41
- - seed: 0-99
42
- ![Bicycles]({SAMPLE_IMAGE_DIR}/bicycles.jpg)
43
- ### Lions
44
- - size: 512x512
45
- - seed: 0-99
46
- ![Lions]({SAMPLE_IMAGE_DIR}/lions.jpg)
47
- ### Giraffes
48
- - size: 512x512
49
- - seed: 0-99
50
- ![Giraffes]({SAMPLE_IMAGE_DIR}/giraffes.jpg)
51
- ### Parrots
52
- - size: 512x512
53
- - seed: 0-99
54
- ![Parrots]({SAMPLE_IMAGE_DIR}/parrots.jpg)
55
-
56
- <center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.self-distilled-stylegan" alt="visitor badge"/></center>
57
- '''
58
-
59
- TOKEN = os.environ['TOKEN']
60
 
61
 
62
  def parse_args() -> argparse.Namespace:
63
  parser = argparse.ArgumentParser()
64
  parser.add_argument('--device', type=str, default='cpu')
65
  parser.add_argument('--theme', type=str)
66
- parser.add_argument('--live', action='store_true')
67
  parser.add_argument('--share', action='store_true')
68
  parser.add_argument('--port', type=int)
69
  parser.add_argument('--disable-queue',
70
  dest='enable_queue',
71
  action='store_false')
72
- parser.add_argument('--allow-flagging', type=str, default='never')
73
  return parser.parse_args()
74
 
75
 
76
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
77
- return torch.from_numpy(np.random.RandomState(seed).randn(
78
- 1, z_dim)).to(device).float()
79
 
80
 
81
- @torch.inference_mode()
82
- def generate_image(model_name: str, seed: int, truncation_psi: float,
83
- model_dict: dict[str, nn.Module],
84
- device: torch.device) -> np.ndarray:
85
- model = model_dict[model_name]
86
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
 
 
87
 
88
- z = generate_z(model.z_dim, seed, device)
89
- label = torch.zeros([1, model.c_dim], device=device)
90
 
91
- out = model(z, label, truncation_psi=truncation_psi)
92
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
93
- return out[0].cpu().numpy()
94
 
95
 
96
- def load_model(model_name: str, device: torch.device) -> nn.Module:
97
- path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
98
- f'models/{model_name}_pytorch.pkl',
99
- use_auth_token=TOKEN)
100
- with open(path, 'rb') as f:
101
- model = pickle.load(f)['G_ema']
102
- model.eval()
103
- model.to(device)
104
- with torch.inference_mode():
105
- z = torch.zeros((1, model.z_dim)).to(device)
106
- label = torch.zeros([1, model.c_dim], device=device)
107
- model(z, label)
108
- return model
109
 
110
 
111
  def main():
112
  args = parse_args()
113
- device = torch.device(args.device)
114
-
115
- model_names = [
116
- 'dogs_1024',
117
- 'elephants_512',
118
- 'horses_256',
119
- 'bicycles_256',
120
- 'lions_512',
121
- 'giraffes_512',
122
- 'parrots_512',
123
- ]
124
-
125
- model_dict = {name: load_model(name, device) for name in model_names}
126
-
127
- func = functools.partial(generate_image,
128
- model_dict=model_dict,
129
- device=device)
130
- func = functools.update_wrapper(func, generate_image)
131
-
132
- gr.Interface(
133
- func,
134
- [
135
- gr.inputs.Radio(
136
- model_names, type='value', default='dogs_1024', label='Model'),
137
- gr.inputs.Number(default=0, label='Seed'),
138
- gr.inputs.Slider(
139
- 0, 2, step=0.05, default=0.7, label='Truncation psi'),
140
- ],
141
- gr.outputs.Image(type='numpy', label='Output'),
142
- title=TITLE,
143
- description=DESCRIPTION,
144
- article=ARTICLE,
145
- theme=args.theme,
146
- allow_flagging=args.allow_flagging,
147
- live=args.live,
148
- ).launch(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  enable_queue=args.enable_queue,
150
  server_port=args.port,
151
  share=args.share,
 
3
  from __future__ import annotations
4
 
5
  import argparse
 
 
 
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
+
10
+ from model import Model
11
+
12
+ TITLE = '# Self-Distilled StyleGAN'
13
+ DESCRIPTION = '''This is an unofficial demo for [https://github.com/self-distilled-stylegan/self-distilled-internet-photos](https://github.com/self-distilled-stylegan/self-distilled-internet-photos).
14
+
15
+ Expected execution time on Hugging Face Spaces: 2s'''
16
+ FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.self-distilled-stylegan" alt="visitor badge" />'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def parse_args() -> argparse.Namespace:
20
  parser = argparse.ArgumentParser()
21
  parser.add_argument('--device', type=str, default='cpu')
22
  parser.add_argument('--theme', type=str)
 
23
  parser.add_argument('--share', action='store_true')
24
  parser.add_argument('--port', type=int)
25
  parser.add_argument('--disable-queue',
26
  dest='enable_queue',
27
  action='store_false')
 
28
  return parser.parse_args()
29
 
30
 
31
+ def get_sample_image_url(model_name: str) -> str:
32
+ sample_image_dir = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples'
33
+ return f'{sample_image_dir}/{model_name}.jpg'
34
 
35
 
36
+ def get_sample_image_markdown(model_name: str) -> str:
37
+ url = get_sample_image_url(model_name)
38
+ size = model_name.split('_')[-1]
39
+ return f'''
40
+ - size: {size}x{size}
41
+ - seed: 0-99
42
+ - truncation: 0.7
43
+ ![sample images]({url})'''
44
 
 
 
45
 
46
+ def get_cluster_center_image_url(model_name: str) -> str:
47
+ cluster_center_image_dir = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/cluster_center_images'
48
+ return f'{cluster_center_image_dir}/{model_name}.jpg'
49
 
50
 
51
+ def get_cluster_center_image_markdown(model_name: str) -> str:
52
+ url = get_cluster_center_image_url(model_name)
53
+ return f'![cluster center images]({url})'
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  def main():
57
  args = parse_args()
58
+
59
+ model = Model(args.device)
60
+
61
+ with gr.Blocks(theme=args.theme, css='style.css') as demo:
62
+ gr.Markdown(TITLE)
63
+ gr.Markdown(DESCRIPTION)
64
+
65
+ with gr.Tabs():
66
+ with gr.TabItem('App'):
67
+ with gr.Row():
68
+ with gr.Column():
69
+ with gr.Group():
70
+ model_name = gr.Dropdown(
71
+ model.MODEL_NAMES,
72
+ value=model.MODEL_NAMES[0],
73
+ label='Model')
74
+ seed = gr.Slider(0,
75
+ np.iinfo(np.uint32).max,
76
+ value=0,
77
+ step=1,
78
+ label='Seed')
79
+ psi = gr.Slider(0,
80
+ 2,
81
+ step=0.05,
82
+ value=0.7,
83
+ label='Truncation psi')
84
+ multimodal_truncation = gr.Checkbox(
85
+ label='Multi-modal Truncation', value=True)
86
+ run_button = gr.Button('Run')
87
+ with gr.Column():
88
+ result = gr.Image(label='Result', elem_id='result')
89
+ with gr.TabItem('Sample Images'):
90
+ with gr.Row():
91
+ model_name2 = gr.Dropdown(model.MODEL_NAMES,
92
+ value=model.MODEL_NAMES[0],
93
+ label='Model')
94
+ with gr.Row():
95
+ text = get_sample_image_markdown(model_name2.value)
96
+ sample_images = gr.Markdown(text)
97
+ with gr.TabItem('Cluster Center Images'):
98
+ with gr.Row():
99
+ model_name3 = gr.Dropdown(model.MODEL_NAMES,
100
+ value=model.MODEL_NAMES[0],
101
+ label='Model')
102
+ with gr.Row():
103
+ text = get_cluster_center_image_markdown(model_name3.value)
104
+ cluster_center_images = gr.Markdown(value=text)
105
+
106
+ gr.Markdown(FOOTER)
107
+
108
+ model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
109
+ run_button.click(fn=model.set_model_and_generate_image,
110
+ inputs=[
111
+ model_name,
112
+ seed,
113
+ psi,
114
+ multimodal_truncation,
115
+ ],
116
+ outputs=result)
117
+ model_name2.change(fn=get_sample_image_markdown,
118
+ inputs=model_name2,
119
+ outputs=sample_images)
120
+ model_name3.change(fn=get_cluster_center_image_markdown,
121
+ inputs=model_name3,
122
+ outputs=cluster_center_images)
123
+
124
+ demo.launch(
125
  enable_queue=args.enable_queue,
126
  server_port=args.port,
127
  share=args.share,
samples/bicycles.jpg β†’ cluster_center_images/bicycles_256.jpg RENAMED
File without changes
samples/dogs.jpg β†’ cluster_center_images/dogs_1024.jpg RENAMED
File without changes
samples/horses.jpg β†’ cluster_center_images/elephants_512.jpg RENAMED
File without changes
samples/parrots.jpg β†’ cluster_center_images/giraffes_512.jpg RENAMED
File without changes
samples/giraffes.jpg β†’ cluster_center_images/horses_256.jpg RENAMED
File without changes
cluster_center_images/lions_512.jpg ADDED

Git LFS Details

  • SHA256: 1775e96848d8f2b8150a1a2f7ce09fe051566385c20a1aca317e5bd1cd1fbf0a
  • Pointer size: 132 Bytes
  • Size of remote file: 7.75 MB
cluster_center_images/parrots_512.jpg ADDED

Git LFS Details

  • SHA256: ee1f0758a580117cbeccda384441edb896e43e7d8295f56cc137962f4b7df2f9
  • Pointer size: 132 Bytes
  • Size of remote file: 4.12 MB
model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import pickle
6
+ import sys
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ current_dir = pathlib.Path(__file__).parent
14
+ submodule_dir = current_dir / 'stylegan3'
15
+ sys.path.insert(0, submodule_dir.as_posix())
16
+
17
+ HF_TOKEN = os.environ['HF_TOKEN']
18
+
19
+
20
+ class Model:
21
+
22
+ MODEL_NAMES = [
23
+ 'dogs_1024',
24
+ 'elephants_512',
25
+ 'horses_256',
26
+ 'bicycles_256',
27
+ 'lions_512',
28
+ 'giraffes_512',
29
+ 'parrots_512',
30
+ ]
31
+
32
+ def __init__(self, device: str | torch.device):
33
+ self.device = torch.device(device)
34
+ self._download_all_models()
35
+ self._download_all_cluster_centers()
36
+
37
+ self.model_name = self.MODEL_NAMES[0]
38
+ self.model = self._load_model(self.model_name)
39
+ self.cluster_centers = self._load_cluster_centers(self.model_name)
40
+
41
+ def _load_model(self, model_name: str) -> nn.Module:
42
+ path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
43
+ f'models/{model_name}_pytorch.pkl',
44
+ use_auth_token=HF_TOKEN)
45
+ with open(path, 'rb') as f:
46
+ model = pickle.load(f)['G_ema']
47
+ model.eval()
48
+ model.to(self.device)
49
+ return model
50
+
51
+ def _load_cluster_centers(self, model_name: str) -> torch.Tensor:
52
+ path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
53
+ f'cluster_centers/{model_name}.npy',
54
+ use_auth_token=HF_TOKEN)
55
+ centers = np.load(path)
56
+ centers = torch.from_numpy(centers).float().to(self.device)
57
+ return centers
58
+
59
+ def set_model(self, model_name: str) -> None:
60
+ if model_name == self.model_name:
61
+ return
62
+ self.model_name = model_name
63
+ self.model = self._load_model(model_name)
64
+ self.cluster_centers = self._load_cluster_centers(model_name)
65
+
66
+ def _download_all_models(self):
67
+ for name in self.MODEL_NAMES:
68
+ self._load_model(name)
69
+
70
+ def _download_all_cluster_centers(self):
71
+ for name in self.MODEL_NAMES:
72
+ self._load_cluster_centers(name)
73
+
74
+ def generate_z(self, seed: int) -> torch.Tensor:
75
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
76
+ return torch.from_numpy(
77
+ np.random.RandomState(seed).randn(1, self.model.z_dim)).float().to(
78
+ self.device)
79
+
80
+ def compute_w(self, z: torch.Tensor) -> torch.Tensor:
81
+ label = torch.zeros((1, self.model.c_dim), device=self.device)
82
+ w = self.model.mapping(z, label)
83
+ return w
84
+
85
+ def find_nearest_cluster_center(self, w: torch.Tensor) -> int:
86
+ # Here, Euclidean distance is used instead of LPIPS distance
87
+ dist2 = ((self.cluster_centers - w)**2).sum(dim=1)
88
+ return torch.argmin(dist2).item()
89
+
90
+ @staticmethod
91
+ def truncate_w(w_center: torch.Tensor, w: torch.Tensor,
92
+ psi: float) -> torch.Tensor:
93
+ if psi == 1:
94
+ return w
95
+ return w_center.lerp(w, psi)
96
+
97
+ @torch.inference_mode()
98
+ def synthesize(self, w: torch.Tensor) -> torch.Tensor:
99
+ return self.model.synthesis(w)
100
+
101
+ def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
102
+ tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
103
+ torch.uint8)
104
+ return tensor.cpu().numpy()
105
+
106
+ def generate_image(self, seed: int, truncation_psi: float,
107
+ multimodal_truncation: bool) -> np.ndarray:
108
+ z = self.generate_z(seed)
109
+ w = self.compute_w(z)
110
+ if multimodal_truncation:
111
+ cluster_index = self.find_nearest_cluster_center(w[:, 0])
112
+ w0 = self.cluster_centers[cluster_index]
113
+ else:
114
+ w0 = self.model.mapping.w_avg
115
+ new_w = self.truncate_w(w0, w, truncation_psi)
116
+ out = self.synthesize(new_w)
117
+ out = self.postprocess(out)
118
+ return out[0]
119
+
120
+ def set_model_and_generate_image(
121
+ self, model_name: str, seed: int, truncation_psi: float,
122
+ multimodal_truncation: bool) -> np.ndarray:
123
+ self.set_model(model_name)
124
+ return self.generate_image(seed, truncation_psi, multimodal_truncation)
samples/bicycles_256.jpg ADDED

Git LFS Details

  • SHA256: cfbf00881ce820e05ac54411a882ff36ae3218eebf9c1b28d8c4520d60b92c61
  • Pointer size: 132 Bytes
  • Size of remote file: 3.11 MB
samples/{elephants.jpg β†’ dogs_1024.jpg} RENAMED
File without changes
samples/elephants_512.jpg ADDED

Git LFS Details

  • SHA256: c91d4a00e0e85f0e47f48f0bfa043c2ca0d98b8446650305a2f1b7cc44e73069
  • Pointer size: 133 Bytes
  • Size of remote file: 12 MB
samples/giraffes_512.jpg ADDED

Git LFS Details

  • SHA256: 638b0614f8af2cc4e1eeb5aee790169b9512d053b851720eaaba2288118b0fb2
  • Pointer size: 133 Bytes
  • Size of remote file: 10.4 MB
samples/horses_256.jpg ADDED

Git LFS Details

  • SHA256: 5a91c99a45303c1cbbfe6f2e4e28ad5bdb985987ca90a5ff43d6eb04d8b11c33
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
samples/lions.jpg DELETED

Git LFS Details

  • SHA256: 4216e153da49fbff81ef41484f48f1c68c6c1d455cba0a1eed8458aa64dacccc
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
samples/lions_512.jpg ADDED

Git LFS Details

  • SHA256: 6ea7b36ca52b76476be1736292ab05e0a140151cd8353deeb31d131dcbb7f60c
  • Pointer size: 133 Bytes
  • Size of remote file: 11.2 MB
samples/parrots_512.jpg ADDED

Git LFS Details

  • SHA256: 65460c307f57b1e9740c4515f3e56d39b54be9a1438fa7cf4339328edfa9455a
  • Pointer size: 132 Bytes
  • Size of remote file: 6.93 MB
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ div#result {
5
+ max-width: 600px;
6
+ max-height: 600px;
7
+ }
8
+ img#visitor-badge {
9
+ display: block;
10
+ margin: auto;
11
+ }