sayakpaul HF staff hysts HF staff committed on
Commit
59febcc
·
0 Parent(s):

Duplicate from AttendAndExcite/Attend-and-Excite

Browse files

Co-authored-by: hysts <[email protected]>

Files changed (11) hide show
  1. .gitattributes +34 -0
  2. .gitignore +163 -0
  3. .gitmodules +0 -0
  4. .pre-commit-config.yaml +37 -0
  5. .style.yapf +5 -0
  6. .vscode/settings.json +18 -0
  7. README.md +15 -0
  8. app.py +215 -0
  9. model.py +67 -0
  10. requirements.txt +5 -0
  11. style.css +3 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_cached_examples/
2
+
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
.gitmodules ADDED
File without changes
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
.vscode/settings.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "python.linting.enabled": true,
3
+ "python.linting.flake8Enabled": true,
4
+ "python.linting.pylintEnabled": false,
5
+ "python.linting.lintOnSave": true,
6
+ "python.formatting.provider": "yapf",
7
+ "python.formatting.yapfArgs": [
8
+ "--style={based_on_style: pep8, indent_width: 4, blank_line_before_nested_class_or_def: false, spaces_before_comment: 2, split_before_logical_operator: true}"
9
+ ],
10
+ "[python]": {
11
+ "editor.formatOnType": true,
12
+ "editor.codeActionsOnSave": {
13
+ "source.organizeImports": true
14
+ }
15
+ },
16
+ "editor.formatOnSave": true,
17
+ "files.insertFinalNewline": true
18
+ }
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Attend And Excite
3
+ emoji: 💻
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.34.0
8
+ python_version: 3.10.11
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ duplicated_from: AttendAndExcite/Attend-and-Excite
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import PIL.Image
9
+
10
+ from model import Model
11
+
12
# Markdown rendered at the top of the demo page.
DESCRIPTION = '''# Attend-and-Excite
This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).
Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.
Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
'''

# Single shared Model instance; constructing it loads both diffusion
# pipelines (see model.py), so this happens once at import time.
model = Model()
19
+
20
+
21
def process_example(
    prompt: str,
    indices_to_alter_str: str,
    seed: int,
    apply_attend_and_excite: bool,
) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
    """Run one gallery example and return (token table, generated image).

    Examples are always rendered with the default sampler settings
    (50 denoising steps, CFG scale 7.5).
    """
    default_num_steps = 50
    default_guidance_scale = 7.5

    return (
        model.get_token_table(prompt),
        model.run(prompt, indices_to_alter_str, seed,
                  apply_attend_and_excite, default_num_steps,
                  default_guidance_scale),
    )
34
+
35
+
36
# Gradio UI. Layout: prompt + controls on the left, result image on the
# right, a gallery of paired (with/without Attend-and-Excite) examples below.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                label='Prompt',
                max_lines=1,
                placeholder=
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background'
            )
            # Helper accordion: lets the user inspect each token's index
            # before filling in the indices field below.
            with gr.Accordion(label='Check token indices', open=False):
                show_token_indices_button = gr.Button('Show token indices')
                token_indices_table = gr.Dataframe(label='Token indices',
                                                   headers=['Index', 'Token'],
                                                   col_count=2)
            # Fixed grammar in the label: "list indices" -> "list of indices".
            token_indices_str = gr.Text(
                label=
                'Token indices (a comma-separated list of indices of the tokens you wish to alter)',
                max_lines=1,
                placeholder='4,16')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=100000,
                             value=0,
                             step=1)
            apply_attend_and_excite = gr.Checkbox(
                label='Apply Attend-and-Excite', value=True)
            num_steps = gr.Slider(label='Number of steps',
                                  minimum=0,
                                  maximum=100,
                                  step=1,
                                  value=50)
            guidance_scale = gr.Slider(label='CFG scale',
                                       minimum=0,
                                       maximum=50,
                                       step=0.1,
                                       value=7.5)
            run_button = gr.Button('Generate')
        with gr.Column():
            result = gr.Image(label='Result')

    with gr.Row():
        # Each prompt appears twice (apply_attend_and_excite True/False) so
        # the effect of the guidance can be compared on the same seed.
        examples = [
            [
                'A mouse and a red car',
                '2,6',
                2098,
                True,
            ],
            [
                'A mouse and a red car',
                '2,6',
                2098,
                False,
            ],
            [
                'A horse and a dog',
                '2,5',
                123,
                True,
            ],
            [
                'A horse and a dog',
                '2,5',
                123,
                False,
            ],
            [
                'A painting of an elephant with glasses',
                '5,7',
                123,
                True,
            ],
            [
                'A painting of an elephant with glasses',
                '5,7',
                123,
                False,
            ],
            [
                'A playful kitten chasing a butterfly in a wildflower meadow',
                '3,6,10',
                123,
                True,
            ],
            [
                'A playful kitten chasing a butterfly in a wildflower meadow',
                '3,6,10',
                123,
                False,
            ],
            [
                'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
                '2,6,15',
                123,
                True,
            ],
            [
                'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
                '2,6,15',
                123,
                False,
            ],
            [
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
                '4,16',
                123,
                True,
            ],
            [
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
                '4,16',
                123,
                False,
            ],
        ]
        # Example caching is opt-in via the CACHE_EXAMPLES env var.
        gr.Examples(examples=examples,
                    inputs=[
                        prompt,
                        token_indices_str,
                        seed,
                        apply_attend_and_excite,
                    ],
                    outputs=[
                        token_indices_table,
                        result,
                    ],
                    fn=process_example,
                    cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
                    examples_per_page=20)

    # Populate the token table without generating an image.
    show_token_indices_button.click(
        fn=model.get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
    )

    inputs = [
        prompt,
        token_indices_str,
        seed,
        apply_attend_and_excite,
        num_steps,
        guidance_scale,
    ]
    # All three triggers first refresh the token table (unqueued, fast),
    # then run generation via model.run.
    prompt.submit(
        fn=model.get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
    ).then(
        fn=model.run,
        inputs=inputs,
        outputs=result,
    )
    token_indices_str.submit(
        fn=model.get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
    ).then(
        fn=model.run,
        inputs=inputs,
        outputs=result,
    )
    run_button.click(
        fn=model.get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
    ).then(
        fn=model.run,
        inputs=inputs,
        outputs=result,
        api_name='run',  # only the button is exposed over the HTTP API
    )

demo.queue(max_size=10).launch()
model.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import PIL.Image
4
+ import torch
5
+ from diffusers import (StableDiffusionAttendAndExcitePipeline,
6
+ StableDiffusionPipeline)
7
+
8
+
9
class Model:
    """Holds two Stable Diffusion pipelines built from the same checkpoint:
    one with Attend-and-Excite guidance and one plain baseline, so outputs
    can be compared for identical prompts and seeds.
    """

    def __init__(self):
        # Use the first GPU when available, otherwise fall back to CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        model_id = 'CompVis/stable-diffusion-v1-4'
        self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
            model_id)
        self.ax_pipe.to(self.device)
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
        self.sd_pipe.to(self.device)

    def get_token_table(self, prompt: str) -> list[tuple[int, str]]:
        """Return 1-indexed (index, token) pairs for *prompt*.

        The first and last ids emitted by the tokenizer are dropped
        (special begin/end markers), so index 1 is the first prompt token.
        """
        tokens = [
            self.ax_pipe.tokenizer.decode(t)
            for t in self.ax_pipe.tokenizer(prompt)['input_ids']
        ]
        tokens = tokens[1:-1]  # strip special begin/end tokens
        return list(enumerate(tokens, start=1))

    def run(
        self,
        prompt: str,
        indices_to_alter_str: str,
        seed: int = 0,
        apply_attend_and_excite: bool = True,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> PIL.Image.Image:
        """Generate one image for *prompt* and return it.

        Args:
            prompt: Text prompt for generation.
            indices_to_alter_str: Comma-separated integer token indices to
                strengthen; only parsed when ``apply_attend_and_excite``.
            seed: Seeds the torch generator for reproducible sampling.
            apply_attend_and_excite: Choose the Attend-and-Excite pipeline
                (True) or the plain Stable Diffusion pipeline (False).
            num_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            scale_factor: Forwarded to the Attend-and-Excite pipeline.
            thresholds: Forwarded to the Attend-and-Excite pipeline;
                defaults to ``{10: 0.5, 20: 0.8}`` when None.
            max_iter_to_alter: Forwarded to the Attend-and-Excite pipeline.

        Raises:
            ValueError: If ``indices_to_alter_str`` is not a comma-separated
                list of integers.
        """
        # Fix: the original used a mutable dict as a default argument
        # (shared across calls); a None sentinel keeps behavior identical.
        if thresholds is None:
            thresholds = {10: 0.5, 20: 0.8}
        generator = torch.Generator(device=self.device).manual_seed(seed)

        if apply_attend_and_excite:
            try:
                token_indices = list(map(int, indices_to_alter_str.split(',')))
            except ValueError as e:
                # Narrowed from `except Exception` (only int() can raise
                # here) and chained so the parse error stays in the traceback.
                raise ValueError('Invalid token indices.') from e
            out = self.ax_pipe(
                prompt=prompt,
                token_indices=token_indices,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
                max_iter_to_alter=max_iter_to_alter,
                thresholds=thresholds,
                scale_factor=scale_factor,
            )
        else:
            out = self.sd_pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
            )
        return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ accelerate==0.20.3
2
+ diffusers==0.17.0
3
+ Pillow==9.5.0
4
+ torch==2.0.1
5
+ transformers==4.30.1
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }