hysts HF staff commited on
Commit
59b7443
·
1 Parent(s): ca926cd
Files changed (8) hide show
  1. .pre-commit-config.yaml +59 -35
  2. .style.yapf +0 -5
  3. .vscode/settings.json +30 -0
  4. README.md +2 -2
  5. app.py +33 -38
  6. model.py +17 -22
  7. requirements.txt +4 -3
  8. style.css +6 -6
.pre-commit-config.yaml CHANGED
@@ -1,36 +1,60 @@
1
- exclude: ^projected_gan
2
  repos:
3
- - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
- hooks:
6
- - id: check-executables-have-shebangs
7
- - id: check-json
8
- - id: check-merge-conflict
9
- - id: check-shebang-scripts-are-executable
10
- - id: check-toml
11
- - id: check-yaml
12
- - id: double-quote-string-fixer
13
- - id: end-of-file-fixer
14
- - id: mixed-line-ending
15
- args: ['--fix=lf']
16
- - id: requirements-txt-fixer
17
- - id: trailing-whitespace
18
- - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
- hooks:
21
- - id: docformatter
22
- args: ['--in-place']
23
- - repo: https://github.com/pycqa/isort
24
- rev: 5.12.0
25
- hooks:
26
- - id: isort
27
- - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
- hooks:
30
- - id: mypy
31
- args: ['--ignore-missing-imports']
32
- - repo: https://github.com/google/yapf
33
- rev: v0.32.0
34
- hooks:
35
- - id: yapf
36
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.8.0
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.2.0
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.7.1
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": "explicit"
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true,
26
+ "notebook.formatOnSave.enabled": true,
27
+ "notebook.codeActionsOnSave": {
28
+ "source.organizeImports": "explicit"
29
+ }
30
+ }
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🌍
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
- python_version: 3.9.16
9
  app_file: app.py
10
  pinned: false
11
  suggested_hardware: t4-small
 
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.20.0
8
+ python_version: 3.9.18
9
  app_file: app.py
10
  pinned: false
11
  suggested_hardware: t4-small
app.py CHANGED
@@ -7,68 +7,63 @@ import numpy as np
7
 
8
  from model import Model
9
 
10
- DESCRIPTION = '# [Projected GAN](https://github.com/autonomousvision/projected_gan)'
11
 
12
 
13
  def get_sample_image_url(name: str) -> str:
14
- sample_image_dir = 'https://huggingface.co/spaces/hysts/projected_gan/resolve/main/samples'
15
- return f'{sample_image_dir}/{name}.jpg'
16
 
17
 
18
  def get_sample_image_markdown(name: str) -> str:
19
  url = get_sample_image_url(name)
20
- return f'''
21
  - size: 256x256
22
  - seed: 0-99
23
  - truncation: 0.7
24
- ![sample images]({url})'''
25
 
26
 
27
  model = Model()
28
 
29
- with gr.Blocks(css='style.css') as demo:
30
  gr.Markdown(DESCRIPTION)
31
 
32
  with gr.Tabs():
33
- with gr.TabItem('App'):
34
  with gr.Row():
35
  with gr.Column():
36
- model_name = gr.Dropdown(label='Model',
37
- choices=model.MODEL_NAMES,
38
- value=model.MODEL_NAMES[8])
39
- seed = gr.Slider(label='Seed',
40
- minimum=0,
41
- maximum=np.iinfo(np.uint32).max,
42
- step=1,
43
- value=0)
44
- psi = gr.Slider(label='Truncation psi',
45
- minimum=0,
46
- maximum=2,
47
- step=0.05,
48
- value=0.7)
49
- run_button = gr.Button('Run')
50
  with gr.Column():
51
- result = gr.Image(label='Result', elem_id='result')
52
 
53
- with gr.TabItem('Sample Images'):
54
  with gr.Row():
55
- model_name2 = gr.Dropdown(label='Model',
56
- choices=model.MODEL_NAMES,
57
- value=model.MODEL_NAMES[0])
58
  with gr.Row():
59
  text = get_sample_image_markdown(model_name2.value)
60
  sample_images = gr.Markdown(text)
61
 
62
- model_name.change(fn=model.set_model, inputs=model_name)
63
- run_button.click(fn=model.set_model_and_generate_image,
64
- inputs=[
65
- model_name,
66
- seed,
67
- psi,
68
- ],
69
- outputs=result)
70
- model_name2.change(fn=get_sample_image_markdown,
71
- inputs=model_name2,
72
- outputs=sample_images)
 
 
 
 
 
 
73
 
74
- demo.queue(max_size=10).launch()
 
 
7
 
8
  from model import Model
9
 
10
+ DESCRIPTION = "# [Projected GAN](https://github.com/autonomousvision/projected_gan)"
11
 
12
 
13
  def get_sample_image_url(name: str) -> str:
14
+ sample_image_dir = "https://huggingface.co/spaces/hysts/projected_gan/resolve/main/samples"
15
+ return f"{sample_image_dir}/{name}.jpg"
16
 
17
 
18
  def get_sample_image_markdown(name: str) -> str:
19
  url = get_sample_image_url(name)
20
+ return f"""
21
  - size: 256x256
22
  - seed: 0-99
23
  - truncation: 0.7
24
+ ![sample images]({url})"""
25
 
26
 
27
  model = Model()
28
 
29
+ with gr.Blocks(css="style.css") as demo:
30
  gr.Markdown(DESCRIPTION)
31
 
32
  with gr.Tabs():
33
+ with gr.TabItem("App"):
34
  with gr.Row():
35
  with gr.Column():
36
+ model_name = gr.Dropdown(label="Model", choices=model.MODEL_NAMES, value=model.MODEL_NAMES[8])
37
+ seed = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.uint32).max, step=1, value=0)
38
+ psi = gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7)
39
+ run_button = gr.Button()
 
 
 
 
 
 
 
 
 
 
40
  with gr.Column():
41
+ result = gr.Image(label="Result")
42
 
43
+ with gr.TabItem("Sample Images"):
44
  with gr.Row():
45
+ model_name2 = gr.Dropdown(label="Model", choices=model.MODEL_NAMES, value=model.MODEL_NAMES[0])
 
 
46
  with gr.Row():
47
  text = get_sample_image_markdown(model_name2.value)
48
  sample_images = gr.Markdown(text)
49
 
50
+ run_button.click(
51
+ fn=model.set_model_and_generate_image,
52
+ inputs=[
53
+ model_name,
54
+ seed,
55
+ psi,
56
+ ],
57
+ outputs=result,
58
+ api_name="run",
59
+ )
60
+ model_name2.change(
61
+ fn=get_sample_image_markdown,
62
+ inputs=model_name2,
63
+ outputs=sample_images,
64
+ queue=False,
65
+ api_name=False,
66
+ )
67
 
68
+ if __name__ == "__main__":
69
+ demo.queue(max_size=10).launch()
model.py CHANGED
@@ -10,36 +10,34 @@ import torch.nn as nn
10
  from huggingface_hub import hf_hub_download
11
 
12
  current_dir = pathlib.Path(__file__).parent
13
- submodule_dir = current_dir / 'projected_gan'
14
  sys.path.insert(0, submodule_dir.as_posix())
15
 
16
 
17
  class Model:
18
 
19
  MODEL_NAMES = [
20
- 'art_painting',
21
- 'church',
22
- 'bedroom',
23
- 'cityscapes',
24
- 'clevr',
25
- 'ffhq',
26
- 'flowers',
27
- 'landscape',
28
- 'pokemon',
29
  ]
30
 
31
  def __init__(self):
32
- self.device = torch.device(
33
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
34
  self._download_all_models()
35
  self.model_name = self.MODEL_NAMES[3]
36
  self.model = self._load_model(self.model_name)
37
 
38
  def _load_model(self, model_name: str) -> nn.Module:
39
- path = hf_hub_download('public-data/projected_gan',
40
- f'models/{model_name}.pkl')
41
- with open(path, 'rb') as f:
42
- model = pickle.load(f)['G_ema']
43
  model.eval()
44
  model.to(self.device)
45
  return model
@@ -60,13 +58,11 @@ class Model:
60
  return torch.from_numpy(z).float().to(self.device)
61
 
62
  def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
63
- tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
64
- torch.uint8)
65
  return tensor.cpu().numpy()
66
 
67
  @torch.inference_mode()
68
- def generate(self, z: torch.Tensor, label: torch.Tensor,
69
- truncation_psi: float) -> torch.Tensor:
70
  return self.model(z, label, truncation_psi=truncation_psi)
71
 
72
  def generate_image(self, seed: int, truncation_psi: float) -> np.ndarray:
@@ -77,7 +73,6 @@ class Model:
77
  out = self.postprocess(out)
78
  return out[0]
79
 
80
- def set_model_and_generate_image(self, model_name: str, seed: int,
81
- truncation_psi: float) -> np.ndarray:
82
  self.set_model(model_name)
83
  return self.generate_image(seed, truncation_psi)
 
10
  from huggingface_hub import hf_hub_download
11
 
12
  current_dir = pathlib.Path(__file__).parent
13
+ submodule_dir = current_dir / "projected_gan"
14
  sys.path.insert(0, submodule_dir.as_posix())
15
 
16
 
17
  class Model:
18
 
19
  MODEL_NAMES = [
20
+ "art_painting",
21
+ "church",
22
+ "bedroom",
23
+ "cityscapes",
24
+ "clevr",
25
+ "ffhq",
26
+ "flowers",
27
+ "landscape",
28
+ "pokemon",
29
  ]
30
 
31
  def __init__(self):
32
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
33
  self._download_all_models()
34
  self.model_name = self.MODEL_NAMES[3]
35
  self.model = self._load_model(self.model_name)
36
 
37
  def _load_model(self, model_name: str) -> nn.Module:
38
+ path = hf_hub_download("public-data/projected_gan", f"models/{model_name}.pkl")
39
+ with open(path, "rb") as f:
40
+ model = pickle.load(f)["G_ema"]
 
41
  model.eval()
42
  model.to(self.device)
43
  return model
 
58
  return torch.from_numpy(z).float().to(self.device)
59
 
60
  def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
61
+ tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
 
62
  return tensor.cpu().numpy()
63
 
64
  @torch.inference_mode()
65
+ def generate(self, z: torch.Tensor, label: torch.Tensor, truncation_psi: float) -> torch.Tensor:
 
66
  return self.model(z, label, truncation_psi=truncation_psi)
67
 
68
  def generate_image(self, seed: int, truncation_psi: float) -> np.ndarray:
 
73
  out = self.postprocess(out)
74
  return out[0]
75
 
76
+ def set_model_and_generate_image(self, model_name: str, seed: int, truncation_psi: float) -> np.ndarray:
 
77
  self.set_model(model_name)
78
  return self.generate_image(seed, truncation_psi)
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- numpy==1.23.5
2
- Pillow==10.0.0
3
- scipy==1.10.1
 
4
  torch==1.10.2
5
  torchvision==0.11.3
 
1
+ gradio==4.20.0
2
+ numpy==1.26.4
3
+ Pillow==10.2.0
4
+ scipy==1.12.0
5
  torch==1.10.2
6
  torchvision==0.11.3
style.css CHANGED
@@ -1,11 +1,11 @@
1
  h1 {
2
  text-align: center;
3
- }
4
- div#result {
5
- max-width: 600px;
6
- max-height: 600px;
7
- }
8
- img#visitor-badge {
9
  display: block;
 
 
 
10
  margin: auto;
 
 
 
11
  }
 
1
  h1 {
2
  text-align: center;
 
 
 
 
 
 
3
  display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
  margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
  }