yucornetto committed
Commit dada74e
1 Parent(s): cd8845c

Upload 20 files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/random_vis_l32.png filter=lfs diff=lfs merge=lfs -text
+ assets/recon_w_model_size_num_token.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,8 @@
- ---
  title: TiTok
- emoji: 😻
+ emoji: 🏆
  colorFrom: indigo
- colorTo: blue
+ colorTo: pink
  sdk: gradio
- sdk_version: 4.36.1
+ sdk_version: 4.36.0
  app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ pinned: false
app.py ADDED
@@ -0,0 +1,92 @@
+ # Reference: https://huggingface.co/spaces/FoundationVision/LlamaGen/blob/main/app.py
+ from PIL import Image
+ import gradio as gr
+ from imagenet_classes import imagenet_idx2classname
+ from huggingface_hub import hf_hub_download
+ import torch
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ import time
+ import argparse
+ import demo_util
+ import os
+
+ device = "cuda"
+
+ model2ckpt = {
+     "TiTok-L-32": ("tokenizer_titok_l32.bin", "generator_titok_l32.bin"),
+ }
+
+ # Fetch the TiTok-L-32 checkpoints from Google Drive on first launch.
+ if not os.path.exists("tokenizer_titok_l32.bin"):
+     os.system("gdown 1I_m2Vm4JgQsa7bZVORj-nVhP8fgQLngd")
+ if not os.path.exists("generator_titok_l32.bin"):
+     os.system("gdown 1IgqZ_vwGIj2ZWOPuCzilxeQ2UrMVY93l")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
+ parser.add_argument("--guidance_scale", type=float, default=3.5)
+ parser.add_argument("--randomize_temperature", type=float, default=1.0)
+ parser.add_argument("--num_sample_steps", type=int, default=8)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
+ args = parser.parse_args()
+
+ config = demo_util.get_config("configs/titok_l32.yaml")
+ print(config)
+ titok_tokenizer = demo_util.get_titok_tokenizer(config)
+ print(titok_tokenizer)
+ titok_generator = demo_util.get_titok_generator(config)
+ print(titok_generator)
+
+ titok_tokenizer = titok_tokenizer.to(device)
+ titok_generator = titok_generator.to(device)
+
+
+ def demo_infer(guidance_scale, randomize_temperature, num_sample_steps,
+                class_label, seed):
+     n = 4
+     class_labels = [class_label for _ in range(n)]
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     t1 = time.time()
+     generated_image = demo_util.sample_fn(
+         generator=titok_generator,
+         tokenizer=titok_tokenizer,
+         labels=class_labels,
+         guidance_scale=guidance_scale,
+         randomize_temperature=randomize_temperature,
+         num_sample_steps=num_sample_steps,
+         device=device
+     )
+     sampling_time = time.time() - t1
+     print(f"generation takes about {sampling_time:.2f} seconds.")
+     samples = [Image.fromarray(sample) for sample in generated_image]
+     return samples
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1 style='text-align: center'>An Image is Worth 32 Tokens for Reconstruction and Generation</h1>")
+
+     with gr.Tabs():
+         with gr.TabItem('Generate'):
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         i1k_class = gr.Dropdown(
+                             list(imagenet_idx2classname.values()),
+                             value='macaw',
+                             type="index", label='ImageNet-1K Class'
+                         )
+                     guidance_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=3.5, label='Classifier-free Guidance Scale')
+                     randomize_temperature = gr.Slider(minimum=0., maximum=10.0, step=0.1, value=1.0, label='randomize_temperature')
+                     num_sample_steps = gr.Slider(minimum=1, maximum=32, step=1, value=8, label='num_sample_steps')
+                     seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
+                     button = gr.Button("Generate", variant="primary")
+                 with gr.Column():
+                     output = gr.Gallery(label='Generated Images', height=700)
+             # Route the click through the Gradio wrapper defined above; sample_fn itself
+             # expects model handles, not the raw UI values passed as inputs here.
+             button.click(demo_infer, inputs=[
+                 guidance_scale, randomize_temperature, num_sample_steps,
+                 i1k_class, seed],
+                 outputs=[output])
+ demo.queue()
+ demo.launch(debug=True)
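
For quick local testing outside the Gradio UI, the same pipeline can be driven headlessly. The sketch below is illustrative and not part of this upload; it assumes the two checkpoint files have already been downloaded next to the script and that a CUDA device is available (class index 88 is simply the "macaw" entry used as the dropdown default).

# Minimal headless sketch (assumptions: checkpoints present, CUDA available; not part of this commit).
from PIL import Image
import demo_util

config = demo_util.get_config("configs/titok_l32.yaml")
tokenizer = demo_util.get_titok_tokenizer(config).to("cuda")
generator = demo_util.get_titok_generator(config).to("cuda")

# 88 = "macaw" in imagenet_idx2classname; any index in [0, 999] works.
images = demo_util.sample_fn(
    generator=generator,
    tokenizer=tokenizer,
    labels=[88, 88, 88, 88],
    guidance_scale=3.5,
    randomize_temperature=1.0,
    num_sample_steps=8,
    device="cuda",
)
for i, img in enumerate(images):
    Image.fromarray(img).save(f"sample_{i}.png")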
assets/ILSVRC2012_val_00008636.png ADDED
assets/ILSVRC2012_val_00010240.png ADDED
assets/random_vis_l32.png ADDED

Git LFS Details

  • SHA256: ff40d0274f7d6656791e4fc72afbf0d46b0a3975803d6184a46baac0ab80438e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.27 MB
assets/recon_w_model_size_num_token.png ADDED

Git LFS Details

  • SHA256: 8e5fe53bb8aa64fe918a33de92ac2d965d46871298eeec6fcd2a4a00f1b75386
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
assets/speed_vs_perf.png ADDED
assets/titok_teaser.png ADDED
configs/titok_l32.yaml ADDED
@@ -0,0 +1,29 @@
+ experiment:
+     tokenizer_checkpoint: "tokenizer_titok_l32.bin"
+     generator_checkpoint: "generator_titok_l32.bin"
+
+ model:
+     vq_model:
+         codebook_size: 4096
+         token_size: 12
+         use_l2_norm: True
+         commitment_cost: 0.25
+         # vit arch
+         vit_enc_model_size: "large"
+         vit_dec_model_size: "large"
+         vit_enc_patch_size: 16
+         vit_dec_patch_size: 16
+         num_latent_tokens: 32
+
+     generator:
+         dropout: 0.1
+         attn_drop: 0.1
+         num_steps: 8
+         mask_schedule_strategy: "arccos"
+         class_label_dropout: 0.1
+         image_seq_len: ${model.vq_model.num_latent_tokens}
+         condition_num_classes: 1000
+
+ dataset:
+     preprocessing:
+         crop_size: 256
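
Note that image_seq_len is an OmegaConf interpolation, so the generator's sequence length always tracks the tokenizer's num_latent_tokens. A small sketch of how the file is consumed, mirroring demo_util.get_config (the asserted values are assumptions that follow directly from the fields above):

# Hedged sketch of loading this config with OmegaConf.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/titok_l32.yaml")
# ${model.vq_model.num_latent_tokens} resolves lazily on access, so both fields agree.
assert config.model.generator.image_seq_len == config.model.vq_model.num_latent_tokens == 32
print(config.experiment.tokenizer_checkpoint)  # tokenizer_titok_l32.bin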
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
demo_util.py ADDED
@@ -0,0 +1,81 @@
+ """Demo file for sampling images from TiTok.
+
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+
+ import torch
+
+ from omegaconf import OmegaConf
+ from modeling.titok import TiTok
+ from modeling.maskgit import ImageBert
+
+
+ def get_config_cli():
+     cli_conf = OmegaConf.from_cli()
+
+     yaml_conf = OmegaConf.load(cli_conf.config)
+     conf = OmegaConf.merge(yaml_conf, cli_conf)
+
+     return conf
+
+ def get_config(config_path):
+     conf = OmegaConf.load(config_path)
+     return conf
+
+ def get_titok_tokenizer(config):
+     tokenizer = TiTok(config)
+     tokenizer.load_state_dict(torch.load(config.experiment.tokenizer_checkpoint))
+     tokenizer.eval()
+     tokenizer.requires_grad_(False)
+     return tokenizer
+
+ def get_titok_generator(config):
+     generator = ImageBert(config)
+     generator.load_state_dict(torch.load(config.experiment.generator_checkpoint))
+     generator.eval()
+     generator.requires_grad_(False)
+     return generator
+
+ @torch.no_grad()
+ def sample_fn(generator,
+               tokenizer,
+               labels=None,
+               guidance_scale=3.0,
+               randomize_temperature=2.0,
+               num_sample_steps=8,
+               device="cuda"):
+     generator.eval()
+     tokenizer.eval()
+     if labels is None:
+         # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
+         labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, torch.randint(0, 999, size=(1,))]
+
+     labels = torch.LongTensor(labels).to(device)
+
+     generated_tokens = generator.generate(
+         condition=labels,
+         guidance_scale=guidance_scale,
+         randomize_temperature=randomize_temperature,
+         num_sample_steps=num_sample_steps)
+
+     generated_image = tokenizer.decode_tokens(
+         generated_tokens.view(generated_tokens.shape[0], -1)
+     )
+
+     generated_image = torch.clamp(generated_image, 0.0, 1.0)
+     generated_image = (generated_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
+
+     return generated_image
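
Unlike app.py, get_config_cli takes its arguments in OmegaConf's key=value dotlist form rather than argparse flags. A hypothetical driver built on it (the script name and override key are illustrative, not part of this upload) would be invoked and written roughly as:

# python my_sample_script.py config=configs/titok_l32.yaml model.generator.num_steps=16
# OmegaConf.from_cli() turns each key=value argument into config entries; `config=` names the
# YAML to load, and any other dotted key overrides the corresponding YAML field after the merge.
import demo_util

conf = demo_util.get_config_cli()               # YAML merged with CLI overrides
tokenizer = demo_util.get_titok_tokenizer(conf)
generator = demo_util.get_titok_generator(conf)
images = demo_util.sample_fn(generator.to("cuda"), tokenizer.to("cuda"), labels=[88] * 4)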
imagenet_classes.py ADDED
@@ -0,0 +1,1001 @@
1
+ imagenet_idx2classname = {
2
+ 0: 'tench, Tinca tinca',
3
+ 1: 'goldfish, Carassius auratus',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
5
+ 3: 'tiger shark, Galeocerdo cuvieri',
6
+ 4: 'hammerhead, hammerhead shark',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo',
8
+ 6: 'stingray',
9
+ 7: 'cock',
10
+ 8: 'hen',
11
+ 9: 'ostrich, Struthio camelus',
12
+ 10: 'brambling, Fringilla montifringilla',
13
+ 11: 'goldfinch, Carduelis carduelis',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus',
15
+ 13: 'junco, snowbird',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
17
+ 15: 'robin, American robin, Turdus migratorius',
18
+ 16: 'bulbul',
19
+ 17: 'jay',
20
+ 18: 'magpie',
21
+ 19: 'chickadee',
22
+ 20: 'water ouzel, dipper',
23
+ 21: 'kite',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
25
+ 23: 'vulture',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
27
+ 25: 'European fire salamander, Salamandra salamandra',
28
+ 26: 'common newt, Triturus vulgaris',
29
+ 27: 'eft',
30
+ 28: 'spotted salamander, Ambystoma maculatum',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
32
+ 30: 'bullfrog, Rana catesbeiana',
33
+ 31: 'tree frog, tree-frog',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
37
+ 35: 'mud turtle',
38
+ 36: 'terrapin',
39
+ 37: 'box turtle, box tortoise',
40
+ 38: 'banded gecko',
41
+ 39: 'common iguana, iguana, Iguana iguana',
42
+ 40: 'American chameleon, anole, Anolis carolinensis',
43
+ 41: 'whiptail, whiptail lizard',
44
+ 42: 'agama',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi',
46
+ 44: 'alligator lizard',
47
+ 45: 'Gila monster, Heloderma suspectum',
48
+ 46: 'green lizard, Lacerta viridis',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
52
+ 50: 'American alligator, Alligator mississipiensis',
53
+ 51: 'triceratops',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake',
56
+ 54: 'hognose snake, puff adder, sand viper',
57
+ 55: 'green snake, grass snake',
58
+ 56: 'king snake, kingsnake',
59
+ 57: 'garter snake, grass snake',
60
+ 58: 'water snake',
61
+ 59: 'vine snake',
62
+ 60: 'night snake, Hypsiglena torquata',
63
+ 61: 'boa constrictor, Constrictor constrictor',
64
+ 62: 'rock python, rock snake, Python sebae',
65
+ 63: 'Indian cobra, Naja naja',
66
+ 64: 'green mamba',
67
+ 65: 'sea snake',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
71
+ 69: 'trilobite',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
73
+ 71: 'scorpion',
74
+ 72: 'black and gold garden spider, Argiope aurantia',
75
+ 73: 'barn spider, Araneus cavaticus',
76
+ 74: 'garden spider, Aranea diademata',
77
+ 75: 'black widow, Latrodectus mactans',
78
+ 76: 'tarantula',
79
+ 77: 'wolf spider, hunting spider',
80
+ 78: 'tick',
81
+ 79: 'centipede',
82
+ 80: 'black grouse',
83
+ 81: 'ptarmigan',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
86
+ 84: 'peacock',
87
+ 85: 'quail',
88
+ 86: 'partridge',
89
+ 87: 'African grey, African gray, Psittacus erithacus',
90
+ 88: 'macaw',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
92
+ 90: 'lorikeet',
93
+ 91: 'coucal',
94
+ 92: 'bee eater',
95
+ 93: 'hornbill',
96
+ 94: 'hummingbird',
97
+ 95: 'jacamar',
98
+ 96: 'toucan',
99
+ 97: 'drake',
100
+ 98: 'red-breasted merganser, Mergus serrator',
101
+ 99: 'goose',
102
+ 100: 'black swan, Cygnus atratus',
103
+ 101: 'tusker',
104
+ 102: 'echidna, spiny anteater, anteater',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
106
+ 104: 'wallaby, brush kangaroo',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
108
+ 106: 'wombat',
109
+ 107: 'jellyfish',
110
+ 108: 'sea anemone, anemone',
111
+ 109: 'brain coral',
112
+ 110: 'flatworm, platyhelminth',
113
+ 111: 'nematode, nematode worm, roundworm',
114
+ 112: 'conch',
115
+ 113: 'snail',
116
+ 114: 'slug',
117
+ 115: 'sea slug, nudibranch',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
120
+ 118: 'Dungeness crab, Cancer magister',
121
+ 119: 'rock crab, Cancer irroratus',
122
+ 120: 'fiddler crab',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
127
+ 125: 'hermit crab',
128
+ 126: 'isopod',
129
+ 127: 'white stork, Ciconia ciconia',
130
+ 128: 'black stork, Ciconia nigra',
131
+ 129: 'spoonbill',
132
+ 130: 'flamingo',
133
+ 131: 'little blue heron, Egretta caerulea',
134
+ 132: 'American egret, great white heron, Egretta albus',
135
+ 133: 'bittern',
136
+ 134: 'crane',
137
+ 135: 'limpkin, Aramus pictus',
138
+ 136: 'European gallinule, Porphyrio porphyrio',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
140
+ 138: 'bustard',
141
+ 139: 'ruddy turnstone, Arenaria interpres',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
143
+ 141: 'redshank, Tringa totanus',
144
+ 142: 'dowitcher',
145
+ 143: 'oystercatcher, oyster catcher',
146
+ 144: 'pelican',
147
+ 145: 'king penguin, Aptenodytes patagonica',
148
+ 146: 'albatross, mollymawk',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
151
+ 149: 'dugong, Dugong dugon',
152
+ 150: 'sea lion',
153
+ 151: 'Chihuahua',
154
+ 152: 'Japanese spaniel',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese',
156
+ 154: 'Pekinese, Pekingese, Peke',
157
+ 155: 'Shih-Tzu',
158
+ 156: 'Blenheim spaniel',
159
+ 157: 'papillon',
160
+ 158: 'toy terrier',
161
+ 159: 'Rhodesian ridgeback',
162
+ 160: 'Afghan hound, Afghan',
163
+ 161: 'basset, basset hound',
164
+ 162: 'beagle',
165
+ 163: 'bloodhound, sleuthhound',
166
+ 164: 'bluetick',
167
+ 165: 'black-and-tan coonhound',
168
+ 166: 'Walker hound, Walker foxhound',
169
+ 167: 'English foxhound',
170
+ 168: 'redbone',
171
+ 169: 'borzoi, Russian wolfhound',
172
+ 170: 'Irish wolfhound',
173
+ 171: 'Italian greyhound',
174
+ 172: 'whippet',
175
+ 173: 'Ibizan hound, Ibizan Podenco',
176
+ 174: 'Norwegian elkhound, elkhound',
177
+ 175: 'otterhound, otter hound',
178
+ 176: 'Saluki, gazelle hound',
179
+ 177: 'Scottish deerhound, deerhound',
180
+ 178: 'Weimaraner',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
183
+ 181: 'Bedlington terrier',
184
+ 182: 'Border terrier',
185
+ 183: 'Kerry blue terrier',
186
+ 184: 'Irish terrier',
187
+ 185: 'Norfolk terrier',
188
+ 186: 'Norwich terrier',
189
+ 187: 'Yorkshire terrier',
190
+ 188: 'wire-haired fox terrier',
191
+ 189: 'Lakeland terrier',
192
+ 190: 'Sealyham terrier, Sealyham',
193
+ 191: 'Airedale, Airedale terrier',
194
+ 192: 'cairn, cairn terrier',
195
+ 193: 'Australian terrier',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
197
+ 195: 'Boston bull, Boston terrier',
198
+ 196: 'miniature schnauzer',
199
+ 197: 'giant schnauzer',
200
+ 198: 'standard schnauzer',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
202
+ 200: 'Tibetan terrier, chrysanthemum dog',
203
+ 201: 'silky terrier, Sydney silky',
204
+ 202: 'soft-coated wheaten terrier',
205
+ 203: 'West Highland white terrier',
206
+ 204: 'Lhasa, Lhasa apso',
207
+ 205: 'flat-coated retriever',
208
+ 206: 'curly-coated retriever',
209
+ 207: 'golden retriever',
210
+ 208: 'Labrador retriever',
211
+ 209: 'Chesapeake Bay retriever',
212
+ 210: 'German short-haired pointer',
213
+ 211: 'vizsla, Hungarian pointer',
214
+ 212: 'English setter',
215
+ 213: 'Irish setter, red setter',
216
+ 214: 'Gordon setter',
217
+ 215: 'Brittany spaniel',
218
+ 216: 'clumber, clumber spaniel',
219
+ 217: 'English springer, English springer spaniel',
220
+ 218: 'Welsh springer spaniel',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
222
+ 220: 'Sussex spaniel',
223
+ 221: 'Irish water spaniel',
224
+ 222: 'kuvasz',
225
+ 223: 'schipperke',
226
+ 224: 'groenendael',
227
+ 225: 'malinois',
228
+ 226: 'briard',
229
+ 227: 'kelpie',
230
+ 228: 'komondor',
231
+ 229: 'Old English sheepdog, bobtail',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
233
+ 231: 'collie',
234
+ 232: 'Border collie',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
236
+ 234: 'Rottweiler',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
238
+ 236: 'Doberman, Doberman pinscher',
239
+ 237: 'miniature pinscher',
240
+ 238: 'Greater Swiss Mountain dog',
241
+ 239: 'Bernese mountain dog',
242
+ 240: 'Appenzeller',
243
+ 241: 'EntleBucher',
244
+ 242: 'boxer',
245
+ 243: 'bull mastiff',
246
+ 244: 'Tibetan mastiff',
247
+ 245: 'French bulldog',
248
+ 246: 'Great Dane',
249
+ 247: 'Saint Bernard, St Bernard',
250
+ 248: 'Eskimo dog, husky',
251
+ 249: 'malamute, malemute, Alaskan malamute',
252
+ 250: 'Siberian husky',
253
+ 251: 'dalmatian, coach dog, carriage dog',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
255
+ 253: 'basenji',
256
+ 254: 'pug, pug-dog',
257
+ 255: 'Leonberg',
258
+ 256: 'Newfoundland, Newfoundland dog',
259
+ 257: 'Great Pyrenees',
260
+ 258: 'Samoyed, Samoyede',
261
+ 259: 'Pomeranian',
262
+ 260: 'chow, chow chow',
263
+ 261: 'keeshond',
264
+ 262: 'Brabancon griffon',
265
+ 263: 'Pembroke, Pembroke Welsh corgi',
266
+ 264: 'Cardigan, Cardigan Welsh corgi',
267
+ 265: 'toy poodle',
268
+ 266: 'miniature poodle',
269
+ 267: 'standard poodle',
270
+ 268: 'Mexican hairless',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo',
276
+ 274: 'dhole, Cuon alpinus',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
278
+ 276: 'hyena, hyaena',
279
+ 277: 'red fox, Vulpes vulpes',
280
+ 278: 'kit fox, Vulpes macrotis',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
283
+ 281: 'tabby, tabby cat',
284
+ 282: 'tiger cat',
285
+ 283: 'Persian cat',
286
+ 284: 'Siamese cat, Siamese',
287
+ 285: 'Egyptian cat',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
289
+ 287: 'lynx, catamount',
290
+ 288: 'leopard, Panthera pardus',
291
+ 289: 'snow leopard, ounce, Panthera uncia',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
293
+ 291: 'lion, king of beasts, Panthera leo',
294
+ 292: 'tiger, Panthera tigris',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus',
296
+ 294: 'brown bear, bruin, Ursus arctos',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
300
+ 298: 'mongoose',
301
+ 299: 'meerkat, mierkat',
302
+ 300: 'tiger beetle',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
304
+ 302: 'ground beetle, carabid beetle',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
306
+ 304: 'leaf beetle, chrysomelid',
307
+ 305: 'dung beetle',
308
+ 306: 'rhinoceros beetle',
309
+ 307: 'weevil',
310
+ 308: 'fly',
311
+ 309: 'bee',
312
+ 310: 'ant, emmet, pismire',
313
+ 311: 'grasshopper, hopper',
314
+ 312: 'cricket',
315
+ 313: 'walking stick, walkingstick, stick insect',
316
+ 314: 'cockroach, roach',
317
+ 315: 'mantis, mantid',
318
+ 316: 'cicada, cicala',
319
+ 317: 'leafhopper',
320
+ 318: 'lacewing, lacewing fly',
321
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ 320: 'damselfly',
323
+ 321: 'admiral',
324
+ 322: 'ringlet, ringlet butterfly',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
326
+ 324: 'cabbage butterfly',
327
+ 325: 'sulphur butterfly, sulfur butterfly',
328
+ 326: 'lycaenid, lycaenid butterfly',
329
+ 327: 'starfish, sea star',
330
+ 328: 'sea urchin',
331
+ 329: 'sea cucumber, holothurian',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
333
+ 331: 'hare',
334
+ 332: 'Angora, Angora rabbit',
335
+ 333: 'hamster',
336
+ 334: 'porcupine, hedgehog',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
338
+ 336: 'marmot',
339
+ 337: 'beaver',
340
+ 338: 'guinea pig, Cavia cobaya',
341
+ 339: 'sorrel',
342
+ 340: 'zebra',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
344
+ 342: 'wild boar, boar, Sus scrofa',
345
+ 343: 'warthog',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
347
+ 345: 'ox',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
349
+ 347: 'bison',
350
+ 348: 'ram, tup',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
352
+ 350: 'ibex, Capra ibex',
353
+ 351: 'hartebeest',
354
+ 352: 'impala, Aepyceros melampus',
355
+ 353: 'gazelle',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
357
+ 355: 'llama',
358
+ 356: 'weasel',
359
+ 357: 'mink',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
362
+ 360: 'otter',
363
+ 361: 'skunk, polecat, wood pussy',
364
+ 362: 'badger',
365
+ 363: 'armadillo',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
368
+ 366: 'gorilla, Gorilla gorilla',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes',
370
+ 368: 'gibbon, Hylobates lar',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
372
+ 370: 'guenon, guenon monkey',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas',
374
+ 372: 'baboon',
375
+ 373: 'macaque',
376
+ 374: 'langur',
377
+ 375: 'colobus, colobus monkey',
378
+ 376: 'proboscis monkey, Nasalis larvatus',
379
+ 377: 'marmoset',
380
+ 378: 'capuchin, ringtail, Cebus capucinus',
381
+ 379: 'howler monkey, howler',
382
+ 380: 'titi, titi monkey',
383
+ 381: 'spider monkey, Ateles geoffroyi',
384
+ 382: 'squirrel monkey, Saimiri sciureus',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
387
+ 385: 'Indian elephant, Elephas maximus',
388
+ 386: 'African elephant, Loxodonta africana',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
391
+ 389: 'barracouta, snoek',
392
+ 390: 'eel',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
394
+ 392: 'rock beauty, Holocanthus tricolor',
395
+ 393: 'anemone fish',
396
+ 394: 'sturgeon',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
398
+ 396: 'lionfish',
399
+ 397: 'puffer, pufferfish, blowfish, globefish',
400
+ 398: 'abacus',
401
+ 399: 'abaya',
402
+ 400: "academic gown, academic robe, judge's robe",
403
+ 401: 'accordion, piano accordion, squeeze box',
404
+ 402: 'acoustic guitar',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
406
+ 404: 'airliner',
407
+ 405: 'airship, dirigible',
408
+ 406: 'altar',
409
+ 407: 'ambulance',
410
+ 408: 'amphibian, amphibious vehicle',
411
+ 409: 'analog clock',
412
+ 410: 'apiary, bee house',
413
+ 411: 'apron',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
415
+ 413: 'assault rifle, assault gun',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
417
+ 415: 'bakery, bakeshop, bakehouse',
418
+ 416: 'balance beam, beam',
419
+ 417: 'balloon',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
421
+ 419: 'Band Aid',
422
+ 420: 'banjo',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail',
424
+ 422: 'barbell',
425
+ 423: 'barber chair',
426
+ 424: 'barbershop',
427
+ 425: 'barn',
428
+ 426: 'barometer',
429
+ 427: 'barrel, cask',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
431
+ 429: 'baseball',
432
+ 430: 'basketball',
433
+ 431: 'bassinet',
434
+ 432: 'bassoon',
435
+ 433: 'bathing cap, swimming cap',
436
+ 434: 'bath towel',
437
+ 435: 'bathtub, bathing tub, bath, tub',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
439
+ 437: 'beacon, lighthouse, beacon light, pharos',
440
+ 438: 'beaker',
441
+ 439: 'bearskin, busby, shako',
442
+ 440: 'beer bottle',
443
+ 441: 'beer glass',
444
+ 442: 'bell cote, bell cot',
445
+ 443: 'bib',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
447
+ 445: 'bikini, two-piece',
448
+ 446: 'binder, ring-binder',
449
+ 447: 'binoculars, field glasses, opera glasses',
450
+ 448: 'birdhouse',
451
+ 449: 'boathouse',
452
+ 450: 'bobsled, bobsleigh, bob',
453
+ 451: 'bolo tie, bolo, bola tie, bola',
454
+ 452: 'bonnet, poke bonnet',
455
+ 453: 'bookcase',
456
+ 454: 'bookshop, bookstore, bookstall',
457
+ 455: 'bottlecap',
458
+ 456: 'bow',
459
+ 457: 'bow tie, bow-tie, bowtie',
460
+ 458: 'brass, memorial tablet, plaque',
461
+ 459: 'brassiere, bra, bandeau',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
463
+ 461: 'breastplate, aegis, egis',
464
+ 462: 'broom',
465
+ 463: 'bucket, pail',
466
+ 464: 'buckle',
467
+ 465: 'bulletproof vest',
468
+ 466: 'bullet train, bullet',
469
+ 467: 'butcher shop, meat market',
470
+ 468: 'cab, hack, taxi, taxicab',
471
+ 469: 'caldron, cauldron',
472
+ 470: 'candle, taper, wax light',
473
+ 471: 'cannon',
474
+ 472: 'canoe',
475
+ 473: 'can opener, tin opener',
476
+ 474: 'cardigan',
477
+ 475: 'car mirror',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
479
+ 477: "carpenter's kit, tool kit",
480
+ 478: 'carton',
481
+ 479: 'car wheel',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
483
+ 481: 'cassette',
484
+ 482: 'cassette player',
485
+ 483: 'castle',
486
+ 484: 'catamaran',
487
+ 485: 'CD player',
488
+ 486: 'cello, violoncello',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
490
+ 488: 'chain',
491
+ 489: 'chainlink fence',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
493
+ 491: 'chain saw, chainsaw',
494
+ 492: 'chest',
495
+ 493: 'chiffonier, commode',
496
+ 494: 'chime, bell, gong',
497
+ 495: 'china cabinet, china closet',
498
+ 496: 'Christmas stocking',
499
+ 497: 'church, church building',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
501
+ 499: 'cleaver, meat cleaver, chopper',
502
+ 500: 'cliff dwelling',
503
+ 501: 'cloak',
504
+ 502: 'clog, geta, patten, sabot',
505
+ 503: 'cocktail shaker',
506
+ 504: 'coffee mug',
507
+ 505: 'coffeepot',
508
+ 506: 'coil, spiral, volute, whorl, helix',
509
+ 507: 'combination lock',
510
+ 508: 'computer keyboard, keypad',
511
+ 509: 'confectionery, confectionary, candy store',
512
+ 510: 'container ship, containership, container vessel',
513
+ 511: 'convertible',
514
+ 512: 'corkscrew, bottle screw',
515
+ 513: 'cornet, horn, trumpet, trump',
516
+ 514: 'cowboy boot',
517
+ 515: 'cowboy hat, ten-gallon hat',
518
+ 516: 'cradle',
519
+ 517: 'crane',
520
+ 518: 'crash helmet',
521
+ 519: 'crate',
522
+ 520: 'crib, cot',
523
+ 521: 'Crock Pot',
524
+ 522: 'croquet ball',
525
+ 523: 'crutch',
526
+ 524: 'cuirass',
527
+ 525: 'dam, dike, dyke',
528
+ 526: 'desk',
529
+ 527: 'desktop computer',
530
+ 528: 'dial telephone, dial phone',
531
+ 529: 'diaper, nappy, napkin',
532
+ 530: 'digital clock',
533
+ 531: 'digital watch',
534
+ 532: 'dining table, board',
535
+ 533: 'dishrag, dishcloth',
536
+ 534: 'dishwasher, dish washer, dishwashing machine',
537
+ 535: 'disk brake, disc brake',
538
+ 536: 'dock, dockage, docking facility',
539
+ 537: 'dogsled, dog sled, dog sleigh',
540
+ 538: 'dome',
541
+ 539: 'doormat, welcome mat',
542
+ 540: 'drilling platform, offshore rig',
543
+ 541: 'drum, membranophone, tympan',
544
+ 542: 'drumstick',
545
+ 543: 'dumbbell',
546
+ 544: 'Dutch oven',
547
+ 545: 'electric fan, blower',
548
+ 546: 'electric guitar',
549
+ 547: 'electric locomotive',
550
+ 548: 'entertainment center',
551
+ 549: 'envelope',
552
+ 550: 'espresso maker',
553
+ 551: 'face powder',
554
+ 552: 'feather boa, boa',
555
+ 553: 'file, file cabinet, filing cabinet',
556
+ 554: 'fireboat',
557
+ 555: 'fire engine, fire truck',
558
+ 556: 'fire screen, fireguard',
559
+ 557: 'flagpole, flagstaff',
560
+ 558: 'flute, transverse flute',
561
+ 559: 'folding chair',
562
+ 560: 'football helmet',
563
+ 561: 'forklift',
564
+ 562: 'fountain',
565
+ 563: 'fountain pen',
566
+ 564: 'four-poster',
567
+ 565: 'freight car',
568
+ 566: 'French horn, horn',
569
+ 567: 'frying pan, frypan, skillet',
570
+ 568: 'fur coat',
571
+ 569: 'garbage truck, dustcart',
572
+ 570: 'gasmask, respirator, gas helmet',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
574
+ 572: 'goblet',
575
+ 573: 'go-kart',
576
+ 574: 'golf ball',
577
+ 575: 'golfcart, golf cart',
578
+ 576: 'gondola',
579
+ 577: 'gong, tam-tam',
580
+ 578: 'gown',
581
+ 579: 'grand piano, grand',
582
+ 580: 'greenhouse, nursery, glasshouse',
583
+ 581: 'grille, radiator grille',
584
+ 582: 'grocery store, grocery, food market, market',
585
+ 583: 'guillotine',
586
+ 584: 'hair slide',
587
+ 585: 'hair spray',
588
+ 586: 'half track',
589
+ 587: 'hammer',
590
+ 588: 'hamper',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
592
+ 590: 'hand-held computer, hand-held microcomputer',
593
+ 591: 'handkerchief, hankie, hanky, hankey',
594
+ 592: 'hard disc, hard disk, fixed disk',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp',
596
+ 594: 'harp',
597
+ 595: 'harvester, reaper',
598
+ 596: 'hatchet',
599
+ 597: 'holster',
600
+ 598: 'home theater, home theatre',
601
+ 599: 'honeycomb',
602
+ 600: 'hook, claw',
603
+ 601: 'hoopskirt, crinoline',
604
+ 602: 'horizontal bar, high bar',
605
+ 603: 'horse cart, horse-cart',
606
+ 604: 'hourglass',
607
+ 605: 'iPod',
608
+ 606: 'iron, smoothing iron',
609
+ 607: "jack-o'-lantern",
610
+ 608: 'jean, blue jean, denim',
611
+ 609: 'jeep, landrover',
612
+ 610: 'jersey, T-shirt, tee shirt',
613
+ 611: 'jigsaw puzzle',
614
+ 612: 'jinrikisha, ricksha, rickshaw',
615
+ 613: 'joystick',
616
+ 614: 'kimono',
617
+ 615: 'knee pad',
618
+ 616: 'knot',
619
+ 617: 'lab coat, laboratory coat',
620
+ 618: 'ladle',
621
+ 619: 'lampshade, lamp shade',
622
+ 620: 'laptop, laptop computer',
623
+ 621: 'lawn mower, mower',
624
+ 622: 'lens cap, lens cover',
625
+ 623: 'letter opener, paper knife, paperknife',
626
+ 624: 'library',
627
+ 625: 'lifeboat',
628
+ 626: 'lighter, light, igniter, ignitor',
629
+ 627: 'limousine, limo',
630
+ 628: 'liner, ocean liner',
631
+ 629: 'lipstick, lip rouge',
632
+ 630: 'Loafer',
633
+ 631: 'lotion',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
635
+ 633: "loupe, jeweler's loupe",
636
+ 634: 'lumbermill, sawmill',
637
+ 635: 'magnetic compass',
638
+ 636: 'mailbag, postbag',
639
+ 637: 'mailbox, letter box',
640
+ 638: 'maillot',
641
+ 639: 'maillot, tank suit',
642
+ 640: 'manhole cover',
643
+ 641: 'maraca',
644
+ 642: 'marimba, xylophone',
645
+ 643: 'mask',
646
+ 644: 'matchstick',
647
+ 645: 'maypole',
648
+ 646: 'maze, labyrinth',
649
+ 647: 'measuring cup',
650
+ 648: 'medicine chest, medicine cabinet',
651
+ 649: 'megalith, megalithic structure',
652
+ 650: 'microphone, mike',
653
+ 651: 'microwave, microwave oven',
654
+ 652: 'military uniform',
655
+ 653: 'milk can',
656
+ 654: 'minibus',
657
+ 655: 'miniskirt, mini',
658
+ 656: 'minivan',
659
+ 657: 'missile',
660
+ 658: 'mitten',
661
+ 659: 'mixing bowl',
662
+ 660: 'mobile home, manufactured home',
663
+ 661: 'Model T',
664
+ 662: 'modem',
665
+ 663: 'monastery',
666
+ 664: 'monitor',
667
+ 665: 'moped',
668
+ 666: 'mortar',
669
+ 667: 'mortarboard',
670
+ 668: 'mosque',
671
+ 669: 'mosquito net',
672
+ 670: 'motor scooter, scooter',
673
+ 671: 'mountain bike, all-terrain bike, off-roader',
674
+ 672: 'mountain tent',
675
+ 673: 'mouse, computer mouse',
676
+ 674: 'mousetrap',
677
+ 675: 'moving van',
678
+ 676: 'muzzle',
679
+ 677: 'nail',
680
+ 678: 'neck brace',
681
+ 679: 'necklace',
682
+ 680: 'nipple',
683
+ 681: 'notebook, notebook computer',
684
+ 682: 'obelisk',
685
+ 683: 'oboe, hautboy, hautbois',
686
+ 684: 'ocarina, sweet potato',
687
+ 685: 'odometer, hodometer, mileometer, milometer',
688
+ 686: 'oil filter',
689
+ 687: 'organ, pipe organ',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
691
+ 689: 'overskirt',
692
+ 690: 'oxcart',
693
+ 691: 'oxygen mask',
694
+ 692: 'packet',
695
+ 693: 'paddle, boat paddle',
696
+ 694: 'paddlewheel, paddle wheel',
697
+ 695: 'padlock',
698
+ 696: 'paintbrush',
699
+ 697: "pajama, pyjama, pj's, jammies",
700
+ 698: 'palace',
701
+ 699: 'panpipe, pandean pipe, syrinx',
702
+ 700: 'paper towel',
703
+ 701: 'parachute, chute',
704
+ 702: 'parallel bars, bars',
705
+ 703: 'park bench',
706
+ 704: 'parking meter',
707
+ 705: 'passenger car, coach, carriage',
708
+ 706: 'patio, terrace',
709
+ 707: 'pay-phone, pay-station',
710
+ 708: 'pedestal, plinth, footstall',
711
+ 709: 'pencil box, pencil case',
712
+ 710: 'pencil sharpener',
713
+ 711: 'perfume, essence',
714
+ 712: 'Petri dish',
715
+ 713: 'photocopier',
716
+ 714: 'pick, plectrum, plectron',
717
+ 715: 'pickelhaube',
718
+ 716: 'picket fence, paling',
719
+ 717: 'pickup, pickup truck',
720
+ 718: 'pier',
721
+ 719: 'piggy bank, penny bank',
722
+ 720: 'pill bottle',
723
+ 721: 'pillow',
724
+ 722: 'ping-pong ball',
725
+ 723: 'pinwheel',
726
+ 724: 'pirate, pirate ship',
727
+ 725: 'pitcher, ewer',
728
+ 726: "plane, carpenter's plane, woodworking plane",
729
+ 727: 'planetarium',
730
+ 728: 'plastic bag',
731
+ 729: 'plate rack',
732
+ 730: 'plow, plough',
733
+ 731: "plunger, plumber's helper",
734
+ 732: 'Polaroid camera, Polaroid Land camera',
735
+ 733: 'pole',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
737
+ 735: 'poncho',
738
+ 736: 'pool table, billiard table, snooker table',
739
+ 737: 'pop bottle, soda bottle',
740
+ 738: 'pot, flowerpot',
741
+ 739: "potter's wheel",
742
+ 740: 'power drill',
743
+ 741: 'prayer rug, prayer mat',
744
+ 742: 'printer',
745
+ 743: 'prison, prison house',
746
+ 744: 'projectile, missile',
747
+ 745: 'projector',
748
+ 746: 'puck, hockey puck',
749
+ 747: 'punching bag, punch bag, punching ball, punchball',
750
+ 748: 'purse',
751
+ 749: 'quill, quill pen',
752
+ 750: 'quilt, comforter, comfort, puff',
753
+ 751: 'racer, race car, racing car',
754
+ 752: 'racket, racquet',
755
+ 753: 'radiator',
756
+ 754: 'radio, wireless',
757
+ 755: 'radio telescope, radio reflector',
758
+ 756: 'rain barrel',
759
+ 757: 'recreational vehicle, RV, R.V.',
760
+ 758: 'reel',
761
+ 759: 'reflex camera',
762
+ 760: 'refrigerator, icebox',
763
+ 761: 'remote control, remote',
764
+ 762: 'restaurant, eating house, eating place, eatery',
765
+ 763: 'revolver, six-gun, six-shooter',
766
+ 764: 'rifle',
767
+ 765: 'rocking chair, rocker',
768
+ 766: 'rotisserie',
769
+ 767: 'rubber eraser, rubber, pencil eraser',
770
+ 768: 'rugby ball',
771
+ 769: 'rule, ruler',
772
+ 770: 'running shoe',
773
+ 771: 'safe',
774
+ 772: 'safety pin',
775
+ 773: 'saltshaker, salt shaker',
776
+ 774: 'sandal',
777
+ 775: 'sarong',
778
+ 776: 'sax, saxophone',
779
+ 777: 'scabbard',
780
+ 778: 'scale, weighing machine',
781
+ 779: 'school bus',
782
+ 780: 'schooner',
783
+ 781: 'scoreboard',
784
+ 782: 'screen, CRT screen',
785
+ 783: 'screw',
786
+ 784: 'screwdriver',
787
+ 785: 'seat belt, seatbelt',
788
+ 786: 'sewing machine',
789
+ 787: 'shield, buckler',
790
+ 788: 'shoe shop, shoe-shop, shoe store',
791
+ 789: 'shoji',
792
+ 790: 'shopping basket',
793
+ 791: 'shopping cart',
794
+ 792: 'shovel',
795
+ 793: 'shower cap',
796
+ 794: 'shower curtain',
797
+ 795: 'ski',
798
+ 796: 'ski mask',
799
+ 797: 'sleeping bag',
800
+ 798: 'slide rule, slipstick',
801
+ 799: 'sliding door',
802
+ 800: 'slot, one-armed bandit',
803
+ 801: 'snorkel',
804
+ 802: 'snowmobile',
805
+ 803: 'snowplow, snowplough',
806
+ 804: 'soap dispenser',
807
+ 805: 'soccer ball',
808
+ 806: 'sock',
809
+ 807: 'solar dish, solar collector, solar furnace',
810
+ 808: 'sombrero',
811
+ 809: 'soup bowl',
812
+ 810: 'space bar',
813
+ 811: 'space heater',
814
+ 812: 'space shuttle',
815
+ 813: 'spatula',
816
+ 814: 'speedboat',
817
+ 815: "spider web, spider's web",
818
+ 816: 'spindle',
819
+ 817: 'sports car, sport car',
820
+ 818: 'spotlight, spot',
821
+ 819: 'stage',
822
+ 820: 'steam locomotive',
823
+ 821: 'steel arch bridge',
824
+ 822: 'steel drum',
825
+ 823: 'stethoscope',
826
+ 824: 'stole',
827
+ 825: 'stone wall',
828
+ 826: 'stopwatch, stop watch',
829
+ 827: 'stove',
830
+ 828: 'strainer',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
832
+ 830: 'stretcher',
833
+ 831: 'studio couch, day bed',
834
+ 832: 'stupa, tope',
835
+ 833: 'submarine, pigboat, sub, U-boat',
836
+ 834: 'suit, suit of clothes',
837
+ 835: 'sundial',
838
+ 836: 'sunglass',
839
+ 837: 'sunglasses, dark glasses, shades',
840
+ 838: 'sunscreen, sunblock, sun blocker',
841
+ 839: 'suspension bridge',
842
+ 840: 'swab, swob, mop',
843
+ 841: 'sweatshirt',
844
+ 842: 'swimming trunks, bathing trunks',
845
+ 843: 'swing',
846
+ 844: 'switch, electric switch, electrical switch',
847
+ 845: 'syringe',
848
+ 846: 'table lamp',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
850
+ 848: 'tape player',
851
+ 849: 'teapot',
852
+ 850: 'teddy, teddy bear',
853
+ 851: 'television, television system',
854
+ 852: 'tennis ball',
855
+ 853: 'thatch, thatched roof',
856
+ 854: 'theater curtain, theatre curtain',
857
+ 855: 'thimble',
858
+ 856: 'thresher, thrasher, threshing machine',
859
+ 857: 'throne',
860
+ 858: 'tile roof',
861
+ 859: 'toaster',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
863
+ 861: 'toilet seat',
864
+ 862: 'torch',
865
+ 863: 'totem pole',
866
+ 864: 'tow truck, tow car, wrecker',
867
+ 865: 'toyshop',
868
+ 866: 'tractor',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
870
+ 868: 'tray',
871
+ 869: 'trench coat',
872
+ 870: 'tricycle, trike, velocipede',
873
+ 871: 'trimaran',
874
+ 872: 'tripod',
875
+ 873: 'triumphal arch',
876
+ 874: 'trolleybus, trolley coach, trackless trolley',
877
+ 875: 'trombone',
878
+ 876: 'tub, vat',
879
+ 877: 'turnstile',
880
+ 878: 'typewriter keyboard',
881
+ 879: 'umbrella',
882
+ 880: 'unicycle, monocycle',
883
+ 881: 'upright, upright piano',
884
+ 882: 'vacuum, vacuum cleaner',
885
+ 883: 'vase',
886
+ 884: 'vault',
887
+ 885: 'velvet',
888
+ 886: 'vending machine',
889
+ 887: 'vestment',
890
+ 888: 'viaduct',
891
+ 889: 'violin, fiddle',
892
+ 890: 'volleyball',
893
+ 891: 'waffle iron',
894
+ 892: 'wall clock',
895
+ 893: 'wallet, billfold, notecase, pocketbook',
896
+ 894: 'wardrobe, closet, press',
897
+ 895: 'warplane, military plane',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
899
+ 897: 'washer, automatic washer, washing machine',
900
+ 898: 'water bottle',
901
+ 899: 'water jug',
902
+ 900: 'water tower',
903
+ 901: 'whiskey jug',
904
+ 902: 'whistle',
905
+ 903: 'wig',
906
+ 904: 'window screen',
907
+ 905: 'window shade',
908
+ 906: 'Windsor tie',
909
+ 907: 'wine bottle',
910
+ 908: 'wing',
911
+ 909: 'wok',
912
+ 910: 'wooden spoon',
913
+ 911: 'wool, woolen, woollen',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
915
+ 913: 'wreck',
916
+ 914: 'yawl',
917
+ 915: 'yurt',
918
+ 916: 'web site, website, internet site, site',
919
+ 917: 'comic book',
920
+ 918: 'crossword puzzle, crossword',
921
+ 919: 'street sign',
922
+ 920: 'traffic light, traffic signal, stoplight',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
924
+ 922: 'menu',
925
+ 923: 'plate',
926
+ 924: 'guacamole',
927
+ 925: 'consomme',
928
+ 926: 'hot pot, hotpot',
929
+ 927: 'trifle',
930
+ 928: 'ice cream, icecream',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle',
932
+ 930: 'French loaf',
933
+ 931: 'bagel, beigel',
934
+ 932: 'pretzel',
935
+ 933: 'cheeseburger',
936
+ 934: 'hotdog, hot dog, red hot',
937
+ 935: 'mashed potato',
938
+ 936: 'head cabbage',
939
+ 937: 'broccoli',
940
+ 938: 'cauliflower',
941
+ 939: 'zucchini, courgette',
942
+ 940: 'spaghetti squash',
943
+ 941: 'acorn squash',
944
+ 942: 'butternut squash',
945
+ 943: 'cucumber, cuke',
946
+ 944: 'artichoke, globe artichoke',
947
+ 945: 'bell pepper',
948
+ 946: 'cardoon',
949
+ 947: 'mushroom',
950
+ 948: 'Granny Smith',
951
+ 949: 'strawberry',
952
+ 950: 'orange',
953
+ 951: 'lemon',
954
+ 952: 'fig',
955
+ 953: 'pineapple, ananas',
956
+ 954: 'banana',
957
+ 955: 'jackfruit, jak, jack',
958
+ 956: 'custard apple',
959
+ 957: 'pomegranate',
960
+ 958: 'hay',
961
+ 959: 'carbonara',
962
+ 960: 'chocolate sauce, chocolate syrup',
963
+ 961: 'dough',
964
+ 962: 'meat loaf, meatloaf',
965
+ 963: 'pizza, pizza pie',
966
+ 964: 'potpie',
967
+ 965: 'burrito',
968
+ 966: 'red wine',
969
+ 967: 'espresso',
970
+ 968: 'cup',
971
+ 969: 'eggnog',
972
+ 970: 'alp',
973
+ 971: 'bubble',
974
+ 972: 'cliff, drop, drop-off',
975
+ 973: 'coral reef',
976
+ 974: 'geyser',
977
+ 975: 'lakeside, lakeshore',
978
+ 976: 'promontory, headland, head, foreland',
979
+ 977: 'sandbar, sand bar',
980
+ 978: 'seashore, coast, seacoast, sea-coast',
981
+ 979: 'valley, vale',
982
+ 980: 'volcano',
983
+ 981: 'ballplayer, baseball player',
984
+ 982: 'groom, bridegroom',
985
+ 983: 'scuba diver',
986
+ 984: 'rapeseed',
987
+ 985: 'daisy',
988
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ 987: 'corn',
990
+ 988: 'acorn',
991
+ 989: 'hip, rose hip, rosehip',
992
+ 990: 'buckeye, horse chestnut, conker',
993
+ 991: 'coral fungus',
994
+ 992: 'agaric',
995
+ 993: 'gyromitra',
996
+ 994: 'stinkhorn, carrion fungus',
997
+ 995: 'earthstar',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
999
+ 997: 'bolete',
1000
+ 998: 'ear, spike, capitulum',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
modeling/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """
2
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
modeling/blocks.py ADDED
@@ -0,0 +1,224 @@
1
+ """Building blocks for TiTok.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ Reference:
18
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from collections import OrderedDict
24
+
25
+
26
+ class ResidualAttentionBlock(nn.Module):
27
+ def __init__(
28
+ self,
29
+ d_model,
30
+ n_head,
31
+ mlp_ratio = 4.0,
32
+ act_layer = nn.GELU,
33
+ norm_layer = nn.LayerNorm
34
+ ):
35
+ super().__init__()
36
+
37
+ self.ln_1 = norm_layer(d_model)
38
+ self.attn = nn.MultiheadAttention(d_model, n_head)
39
+ self.mlp_ratio = mlp_ratio
40
+ # optionally we can disable the FFN
41
+ if mlp_ratio > 0:
42
+ self.ln_2 = norm_layer(d_model)
43
+ mlp_width = int(d_model * mlp_ratio)
44
+ self.mlp = nn.Sequential(OrderedDict([
45
+ ("c_fc", nn.Linear(d_model, mlp_width)),
46
+ ("gelu", act_layer()),
47
+ ("c_proj", nn.Linear(mlp_width, d_model))
48
+ ]))
49
+
50
+ def attention(
51
+ self,
52
+ x: torch.Tensor
53
+ ):
54
+ return self.attn(x, x, x, need_weights=False)[0]
55
+
56
+ def forward(
57
+ self,
58
+ x: torch.Tensor,
59
+ ):
60
+ attn_output = self.attention(x=self.ln_1(x))
61
+ x = x + attn_output
62
+ if self.mlp_ratio > 0:
63
+ x = x + self.mlp(self.ln_2(x))
64
+ return x
65
+
66
+
67
+ def _expand_token(token, batch_size: int):
68
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
69
+
70
+
71
+ class TiTokEncoder(nn.Module):
72
+ def __init__(self, config):
73
+ super().__init__()
74
+ self.config = config
75
+ self.image_size = config.dataset.preprocessing.crop_size
76
+ self.patch_size = config.model.vq_model.vit_enc_patch_size
77
+ self.grid_size = self.image_size // self.patch_size
78
+ self.model_size = config.model.vq_model.vit_enc_model_size
79
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
80
+ self.token_size = config.model.vq_model.token_size
81
+
82
+ self.width = {
83
+ "small": 512,
84
+ "base": 768,
85
+ "large": 1024,
86
+ }[self.model_size]
87
+ self.num_layers = {
88
+ "small": 8,
89
+ "base": 12,
90
+ "large": 24,
91
+ }[self.model_size]
92
+ self.num_heads = {
93
+ "small": 8,
94
+ "base": 12,
95
+ "large": 16,
96
+ }[self.model_size]
97
+
98
+ self.patch_embed = nn.Conv2d(
99
+ in_channels=3, out_channels=self.width,
100
+ kernel_size=self.patch_size, stride=self.patch_size, bias=True)
101
+
102
+ scale = self.width ** -0.5
103
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
104
+ self.positional_embedding = nn.Parameter(
105
+ scale * torch.randn(self.grid_size ** 2 + 1, self.width))
106
+ self.latent_token_positional_embedding = nn.Parameter(
107
+ scale * torch.randn(self.num_latent_tokens, self.width))
108
+ self.ln_pre = nn.LayerNorm(self.width)
109
+ self.transformer = nn.ModuleList()
110
+ for i in range(self.num_layers):
111
+ self.transformer.append(ResidualAttentionBlock(
112
+ self.width, self.num_heads, mlp_ratio=4.0
113
+ ))
114
+ self.ln_post = nn.LayerNorm(self.width)
115
+ self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
116
+
117
+ def forward(self, pixel_values, latent_tokens):
118
+ batch_size = pixel_values.shape[0]
119
+ x = pixel_values
120
+ x = self.patch_embed(x)
121
+ x = x.reshape(x.shape[0], x.shape[1], -1)
122
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
123
+ # class embeddings and positional embeddings
124
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
125
+ x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width]
126
+
127
+
128
+ latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
129
+ latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
130
+ x = torch.cat([x, latent_tokens], dim=1)
131
+
132
+ x = self.ln_pre(x)
133
+ x = x.permute(1, 0, 2) # NLD -> LND
134
+ for i in range(self.num_layers):
135
+ x = self.transformer[i](x)
136
+ x = x.permute(1, 0, 2) # LND -> NLD
137
+
138
+ latent_tokens = x[:, 1+self.grid_size**2:]
139
+ latent_tokens = self.ln_post(latent_tokens)
140
+ # fake 2D shape
141
+ latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
142
+ latent_tokens = self.conv_out(latent_tokens)
143
+ latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
144
+ return latent_tokens
145
+
146
+
147
+ class TiTokDecoder(nn.Module):
148
+ def __init__(self, config):
149
+ super().__init__()
150
+ self.config = config
151
+ self.image_size = config.dataset.preprocessing.crop_size
152
+ self.patch_size = config.model.vq_model.vit_dec_patch_size
153
+ self.grid_size = self.image_size // self.patch_size
154
+ self.model_size = config.model.vq_model.vit_dec_model_size
155
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
156
+ self.token_size = config.model.vq_model.token_size
157
+ self.width = {
158
+ "small": 512,
159
+ "base": 768,
160
+ "large": 1024,
161
+ }[self.model_size]
162
+ self.num_layers = {
163
+ "small": 8,
164
+ "base": 12,
165
+ "large": 24,
166
+ }[self.model_size]
167
+ self.num_heads = {
168
+ "small": 8,
169
+ "base": 12,
170
+ "large": 16,
171
+ }[self.model_size]
172
+
173
+ self.decoder_embed = nn.Linear(
174
+ self.token_size, self.width, bias=True)
175
+ scale = self.width ** -0.5
176
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
177
+ self.positional_embedding = nn.Parameter(
178
+ scale * torch.randn(self.grid_size ** 2 + 1, self.width))
179
+ # add mask token and query pos embed
180
+ self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
181
+ self.latent_token_positional_embedding = nn.Parameter(
182
+ scale * torch.randn(self.num_latent_tokens, self.width))
183
+ self.ln_pre = nn.LayerNorm(self.width)
184
+ self.transformer = nn.ModuleList()
185
+ for i in range(self.num_layers):
186
+ self.transformer.append(ResidualAttentionBlock(
187
+ self.width, self.num_heads, mlp_ratio=4.0
188
+ ))
189
+ self.ln_post = nn.LayerNorm(self.width)
190
+
191
+ self.ffn = nn.Sequential(
192
+ nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
193
+ nn.Tanh(),
194
+ nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
195
+ )
196
+ self.conv_out = nn.Identity()
197
+
198
+ def forward(self, z_quantized):
199
+ N, C, H, W = z_quantized.shape
200
+ assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
201
+ x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
202
+ x = self.decoder_embed(x)
203
+
204
+ batchsize, seq_len, _ = x.shape
205
+
206
+ mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
207
+ mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
208
+ mask_tokens], dim=1)
209
+ mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
210
+ x = x + self.latent_token_positional_embedding[:seq_len]
211
+ x = torch.cat([mask_tokens, x], dim=1)
212
+
213
+ x = self.ln_pre(x)
214
+ x = x.permute(1, 0, 2) # NLD -> LND
215
+ for i in range(self.num_layers):
216
+ x = self.transformer[i](x)
217
+ x = x.permute(1, 0, 2) # LND -> NLD
218
+ x = x[:, 1:1+self.grid_size**2] # remove cls embed
219
+ x = self.ln_post(x)
220
+ # N L D -> N D H W
221
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
222
+ x = self.ffn(x.contiguous())
223
+ x = self.conv_out(x)
224
+ return x
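
A minimal sketch of the residual attention block defined above, run on random data; the sizes below are illustrative assumptions, chosen only to show the sequence-first (L, N, D) layout that TiTokEncoder and TiTokDecoder feed into it:

    import torch
    from modeling.blocks import ResidualAttentionBlock

    block = ResidualAttentionBlock(d_model=768, n_head=12, mlp_ratio=4.0)
    x = torch.randn(197, 2, 768)   # (sequence, batch, width); nn.MultiheadAttention defaults to sequence-first
    y = block(x)
    print(y.shape)                 # torch.Size([197, 2, 768])
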
modeling/maskgit.py ADDED
@@ -0,0 +1,138 @@
1
+ """This file contains implementation for MaskGIT model.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ Reference:
18
+ https://github.com/huggingface/open-muse
19
+ https://github.com/baaivision/MUSE-Pytorch
20
+ """
21
+
22
+ import torch
23
+ from torch import nn
24
+ import numpy as np
25
+ import math
26
+ import torch.utils.checkpoint
27
+ from transformers import BertConfig, BertModel
28
+
29
+
30
+ class ImageBert(nn.Module):
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ self.config = config
34
+ self.target_codebook_size = config.model.vq_model.codebook_size
35
+ self.condition_num_classes = config.model.generator.condition_num_classes
36
+ self.image_seq_len = config.model.generator.image_seq_len
37
+ self.mask_token_id = self.target_codebook_size
38
+
39
+ self.model = BertModel(BertConfig(
40
+ vocab_size=self.target_codebook_size + self.condition_num_classes + 2,
41
+ hidden_size=768,
42
+ num_hidden_layers=24,
43
+ num_attention_heads=16,
44
+ intermediate_size=3072,
45
+ hidden_act='gelu',
46
+ hidden_dropout_prob=config.model.generator.dropout,
47
+ attention_probs_dropout_prob=config.model.generator.attn_drop,
48
+ max_position_embeddings=config.model.generator.image_seq_len + 1,
49
+ initializer_range=0.02,
50
+ layer_norm_eps=1e-12,
51
+ pad_token_id=None,
52
+ position_embedding_type="absolute",
53
+ use_cache=True
54
+ ), add_pooling_layer=False)
55
+ self.model.lm_head = nn.Linear(768, self.target_codebook_size, bias=True)
56
+
57
+ self.model.post_init()
58
+
59
+ def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
60
+ # Token space:
61
+ # [0, codebook_size - 1] : those are the learned quantized image tokens
62
+ # codebook_size : the mask token used to mask image tokens
63
+ # [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
64
+ # codebook_size + 1 + nclass : the class drop label
65
+ drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
66
+ # Shift the classes
67
+ condition = condition + self.target_codebook_size + 1 # [0, 999] -> [codebook_size + 1, codebook_size + 1000]
68
+ condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
69
+ # prepend condition token
70
+ if input_ids is not None:
71
+ input_ids = torch.cat([condition.view(condition.shape[0], -1),
72
+ input_ids.view(input_ids.shape[0], -1),], dim=1)
73
+ else:
74
+ # at least there should be masked token
75
+ raise NotImplementedError
76
+ model_output = self.model(input_ids=input_ids)
77
+ model_output = model_output[0]
78
+ return self.model.lm_head(model_output[:, 1:]) # remove cond
79
+
80
+ # ref: https://github.com/baaivision/MUSE-Pytorch/blob/master/libs/muse.py#L40
81
+ @torch.no_grad()
82
+ def generate(self,
83
+ condition,
84
+ guidance_scale=3.0,
85
+ randomize_temperature=4.5,
86
+ num_sample_steps=8):
87
+ device = condition.device
88
+ ids = torch.full((condition.shape[0], self.image_seq_len),
89
+ self.mask_token_id, device=device)
90
+ cfg_scale = guidance_scale
91
+
92
+ for step in range(num_sample_steps):
93
+ ratio = 1. * (step + 1) / num_sample_steps
94
+ annealed_temp = randomize_temperature * (1.0 - ratio)
95
+ is_mask = (ids == self.mask_token_id)
96
+ if cfg_scale != 0:
97
+ cond_logits = self.forward(
98
+ ids, condition, cond_drop_prob=0.0
99
+ )
100
+ uncond_logits = self.forward(
101
+ ids, condition, cond_drop_prob=1.0
102
+ )
103
+ logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale
104
+ else:
105
+ logits = self.forward(
106
+ ids, condition, cond_drop_prob=0.0
107
+ )
108
+ # Add gumbel noise
109
+ def log(t, eps=1e-20):
110
+ return torch.log(t.clamp(min=eps))
111
+ def gumbel_noise(t):
112
+ noise = torch.zeros_like(t).uniform_(0, 1)
113
+ return -log(-log(noise))
114
+ def add_gumbel_noise(t, temperature):
115
+ return t + temperature * gumbel_noise(t)
116
+
117
+ sampled_ids = add_gumbel_noise(logits, annealed_temp).argmax(dim=-1)
118
+ sampled_logits = torch.squeeze(
119
+ torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
120
+ sampled_ids = torch.where(is_mask, sampled_ids, ids)
121
+ sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()
122
+ # masking
123
+ mask_ratio = np.arccos(ratio) / (math.pi * 0.5)
124
+
125
+ mask_len = torch.Tensor([np.floor(self.image_seq_len * mask_ratio)]).to(device)
126
+ mask_len = torch.maximum(torch.Tensor([1]).to(device),
127
+ torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1,
128
+ mask_len))[0].squeeze()
129
+ confidence = add_gumbel_noise(sampled_logits, annealed_temp)
130
+ sorted_confidence, _ = torch.sort(confidence, axis=-1)
131
+ cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
132
+ masking = (confidence <= cut_off)
133
+ if step == num_sample_steps - 1:
134
+ ids = sampled_ids
135
+ else:
136
+ ids = torch.where(masking, self.mask_token_id, sampled_ids)
137
+
138
+ return ids
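
A minimal sampling sketch for the generator above (not part of the uploaded files). The config keys and values here are assumptions inferred from the attribute accesses in ImageBert.__init__; the repository's released config files are the authoritative source, and with randomly initialized weights the sampled tokens are of course meaningless:

    import torch
    from omegaconf import OmegaConf
    from modeling.maskgit import ImageBert

    config = OmegaConf.create({"model": {
        "vq_model": {"codebook_size": 4096},
        "generator": {"condition_num_classes": 1000, "image_seq_len": 32,
                      "dropout": 0.1, "attn_drop": 0.1}}})
    generator = ImageBert(config).eval()
    class_labels = torch.tensor([207, 360])  # two ImageNet class ids
    tokens = generator.generate(class_labels, guidance_scale=3.0,
                                randomize_temperature=4.5, num_sample_steps=8)
    print(tokens.shape)  # torch.Size([2, 32]) -- one token sequence per class label
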
modeling/maskgit_vqgan.py ADDED
@@ -0,0 +1,362 @@
1
+ """This file contains code for MaskGIT-VQGAN.
2
+
3
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5
+
6
+ Reference:
7
+ https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py
8
+ """
9
+ # Copyright 2023 Google LLC and The HuggingFace Inc. team.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ r"""MaskGIT Tokenizer based on VQGAN.
24
+
25
+ This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841]
26
+ with several modifications. The non-local layers are removed from VQGAN for
27
+ faster speed.
28
+ """
29
+
30
+ import math
31
+
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from torch import nn
35
+
36
+
37
+ # Conv2D with same padding
38
+ class Conv2dSame(nn.Conv2d):
39
+ def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
40
+ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ ih, iw = x.size()[-2:]
44
+
45
+ pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
46
+ pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
47
+
48
+ if pad_h > 0 or pad_w > 0:
49
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
50
+ return super().forward(x)
51
+
52
+
53
+ class ResnetBlock(nn.Module):
54
+ def __init__(
55
+ self,
56
+ in_channels: int,
57
+ out_channels: int = None,
58
+ dropout_prob: float = 0.0,
59
+ ):
60
+ super().__init__()
61
+
62
+ self.in_channels = in_channels
63
+ self.out_channels = out_channels
64
+ self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
65
+
66
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
67
+ self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False)
68
+
69
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
70
+ self.dropout = nn.Dropout(dropout_prob)
71
+ self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False)
72
+
73
+ if self.in_channels != self.out_channels_:
74
+ self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False)
75
+
76
+ def forward(self, hidden_states):
77
+ residual = hidden_states
78
+ hidden_states = self.norm1(hidden_states)
79
+ hidden_states = F.silu(hidden_states)
80
+ hidden_states = self.conv1(hidden_states)
81
+
82
+ hidden_states = self.norm2(hidden_states)
83
+ hidden_states = F.silu(hidden_states)
84
+ hidden_states = self.dropout(hidden_states)
85
+ hidden_states = self.conv2(hidden_states)
86
+
87
+ if self.in_channels != self.out_channels_:
88
+ residual = self.nin_shortcut(hidden_states)
89
+
90
+ return hidden_states + residual
91
+
92
+
93
+ class DownsamplingBlock(nn.Module):
94
+ def __init__(self, config, block_idx: int):
95
+ super().__init__()
96
+
97
+ self.config = config
98
+ self.block_idx = block_idx
99
+
100
+ in_channel_mult = (1,) + tuple(self.config.channel_mult)
101
+ block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
102
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
103
+
104
+ res_blocks = nn.ModuleList()
105
+ for _ in range(self.config.num_res_blocks):
106
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
107
+ block_in = block_out
108
+ self.block = res_blocks
109
+
110
+ self.downsample = self.block_idx != self.config.num_resolutions - 1
111
+
112
+ def forward(self, hidden_states):
113
+ for res_block in self.block:
114
+ hidden_states = res_block(hidden_states)
115
+
116
+ if self.downsample:
117
+ hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2)
118
+
119
+ return hidden_states
120
+
121
+
122
+ class UpsamplingBlock(nn.Module):
123
+ def __init__(self, config, block_idx: int):
124
+ super().__init__()
125
+
126
+ self.config = config
127
+ self.block_idx = block_idx
128
+
129
+ if self.block_idx == self.config.num_resolutions - 1:
130
+ block_in = self.config.hidden_channels * self.config.channel_mult[-1]
131
+ else:
132
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]
133
+
134
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
135
+
136
+ res_blocks = []
137
+ for _ in range(self.config.num_res_blocks):
138
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
139
+ block_in = block_out
140
+ self.block = nn.ModuleList(res_blocks)
141
+
142
+ self.add_upsample = self.block_idx != 0
143
+ if self.add_upsample:
144
+ self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3)
145
+
146
+ def forward(self, hidden_states):
147
+ for res_block in self.block:
148
+ hidden_states = res_block(hidden_states)
149
+
150
+ if self.add_upsample:
151
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
152
+ hidden_states = self.upsample_conv(hidden_states)
153
+
154
+ return hidden_states
155
+
156
+
157
+ class Encoder(nn.Module):
158
+ def __init__(self, config):
159
+ super().__init__()
160
+ self.config = config
161
+ # downsampling
162
+ self.conv_in = Conv2dSame(self.config.num_channels, self.config.hidden_channels, kernel_size=3, bias=False)
163
+
164
+ downsample_blocks = []
165
+ for i_level in range(self.config.num_resolutions):
166
+ downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level))
167
+ self.down = nn.ModuleList(downsample_blocks)
168
+
169
+ # middle
170
+ mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
171
+ res_blocks = nn.ModuleList()
172
+ for _ in range(self.config.num_res_blocks):
173
+ res_blocks.append(ResnetBlock(mid_channels, mid_channels, dropout_prob=self.config.dropout))
174
+ self.mid = res_blocks
175
+
176
+ # end
177
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
178
+ self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1)
179
+
180
+ def forward(self, pixel_values):
181
+ # downsampling
182
+ hidden_states = self.conv_in(pixel_values)
183
+ for block in self.down:
184
+ hidden_states = block(hidden_states)
185
+
186
+ # middle
187
+ for block in self.mid:
188
+ hidden_states = block(hidden_states)
189
+
190
+ # end
191
+ hidden_states = self.norm_out(hidden_states)
192
+ hidden_states = F.silu(hidden_states)
193
+ hidden_states = self.conv_out(hidden_states)
194
+ return hidden_states
195
+
196
+
197
+ class Decoder(nn.Module):
198
+ def __init__(self, config):
199
+ super().__init__()
200
+
201
+ self.config = config
202
+
203
+ # compute in_channel_mult, block_in and curr_res at lowest res
204
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
205
+ curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
206
+ self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
207
+
208
+ # z to block_in
209
+ self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3)
210
+
211
+ # middle
212
+ res_blocks = nn.ModuleList()
213
+ for _ in range(self.config.num_res_blocks):
214
+ res_blocks.append(ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout))
215
+ self.mid = res_blocks
216
+
217
+ # upsampling
218
+ upsample_blocks = []
219
+ for i_level in reversed(range(self.config.num_resolutions)):
220
+ upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level))
221
+ self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order
222
+
223
+ # end
224
+ block_out = self.config.hidden_channels * self.config.channel_mult[0]
225
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
226
+ self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3)
227
+
228
+ def forward(self, hidden_states):
229
+ # z to block_in
230
+ hidden_states = self.conv_in(hidden_states)
231
+
232
+ # middle
233
+ for block in self.mid:
234
+ hidden_states = block(hidden_states)
235
+
236
+ # upsampling
237
+ for block in reversed(self.up):
238
+ hidden_states = block(hidden_states)
239
+
240
+ # end
241
+ hidden_states = self.norm_out(hidden_states)
242
+ hidden_states = F.silu(hidden_states)
243
+ hidden_states = self.conv_out(hidden_states)
244
+
245
+ return hidden_states
246
+
247
+
248
+ class VectorQuantizer(nn.Module):
249
+ """
250
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
251
+ Discretization bottleneck part of the VQ-VAE.
252
+ """
253
+
254
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost):
255
+ r"""
256
+ Args:
257
+ num_embeddings: number of vectors in the quantized space.
258
+ embedding_dim: dimensionality of the tensors in the quantized space.
259
+ Inputs to the modules must be in this format as well.
260
+ commitment_cost: scalar which controls the weighting of the loss terms
261
+ (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
262
+ """
263
+ super().__init__()
264
+
265
+ self.num_embeddings = num_embeddings
266
+ self.embedding_dim = embedding_dim
267
+ self.commitment_cost = commitment_cost
268
+
269
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
270
+ self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
271
+
272
+ def forward(self, hidden_states, return_loss=False):
273
+ """
274
+ Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the
275
+ closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width)
276
+ quantization pipeline:
277
+ 1. get encoder input (B,C,H,W)
278
+ 2. flatten input to (B*H*W,C)
279
+ """
280
+ # reshape z -> (batch, height, width, channel) and flatten
281
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
282
+
283
+ distances = self.compute_distances(hidden_states)
284
+ min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
285
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
286
+ min_encodings.scatter_(1, min_encoding_indices, 1)
287
+
288
+ # get quantized latent vectors
289
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)
290
+
291
+ # reshape to (batch, num_tokens)
292
+ min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
293
+
294
+ # compute loss for embedding
295
+ loss = None
296
+ if return_loss:
297
+ loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
298
+ (z_q - hidden_states.detach()) ** 2
299
+ )
300
+ # preserve gradients
301
+ z_q = hidden_states + (z_q - hidden_states).detach()
302
+
303
+ # reshape back to match original input shape
304
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
305
+
306
+ return z_q, min_encoding_indices, loss
307
+
308
+ def compute_distances(self, hidden_states):
309
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
310
+ hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim))
311
+ emb_weights = self.embedding.weight.t()
312
+
313
+ inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True)
314
+ codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
315
+ distances = torch.addmm(
316
+ inputs_norm_sq + codebook_t_norm_sq,
317
+ hidden_states_flattended,
318
+ emb_weights,
319
+ alpha=-2.0,
320
+ )
321
+ return distances
322
+
323
+ def get_codebook_entry(self, indices):
324
+ # indices are expected to be of shape (batch, num_tokens)
325
+ # get quantized latent vectors
326
+ if len(indices.shape) == 2:
327
+ batch, num_tokens = indices.shape
328
+ z_q = self.embedding(indices)
329
+ z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
330
+ elif len(indices.shape) == 3:
331
+ batch, height, width = indices.shape
332
+ indices = indices.view(batch, -1)
333
+ z_q = self.embedding(indices)
334
+ z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2)
335
+ else:
336
+ print(indices.shape)
337
+ raise NotImplementedError
338
+ return z_q
339
+
340
+ # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
341
+ def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
342
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel)
343
+ distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings)
344
+
345
+ soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings)
346
+ if stochastic:
347
+ code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
348
+ else:
349
+ code = distances.argmin(dim=-1) # (batch * height * width)
350
+
351
+ code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
352
+ batch, num_tokens = code.shape
353
+ soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings)
354
+ return soft_code, code
355
+
356
+ def get_code(self, hidden_states):
357
+ # reshape z -> (batch, height, width, channel)
358
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
359
+ distances = self.compute_distances(hidden_states)
360
+ indices = torch.argmin(distances, axis=1).unsqueeze(1)
361
+ indices = indices.reshape(hidden_states.shape[0], -1)
362
+ return indices
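
A minimal sketch of the pixel decoder defined above, using the same OmegaConf settings that modeling/titok.py passes in; the input latent here is random and only meant to show the expected shapes:

    import torch
    from omegaconf import OmegaConf
    from modeling.maskgit_vqgan import Decoder

    config = OmegaConf.create({
        "channel_mult": [1, 1, 2, 2, 4], "num_resolutions": 5, "dropout": 0.0,
        "hidden_channels": 128, "num_channels": 3, "num_res_blocks": 2,
        "resolution": 256, "z_channels": 256})
    decoder = Decoder(config).eval()
    z = torch.randn(1, 256, 16, 16)   # (batch, z_channels, 16, 16) latent grid
    with torch.no_grad():
        image = decoder(z)
    print(image.shape)                # torch.Size([1, 3, 256, 256])
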
modeling/quantizer.py ADDED
@@ -0,0 +1,92 @@
1
+ """Vector quantizer.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ Reference:
18
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
19
+ https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py
20
+ """
21
+ from typing import Mapping, Text, Tuple
22
+
23
+ import torch
24
+ from einops import rearrange
25
+ from torch.cuda.amp import autocast
26
+
27
+ class VectorQuantizer(torch.nn.Module):
28
+ def __init__(self,
29
+ codebook_size: int = 1024,
30
+ token_size: int = 256,
31
+ commitment_cost: float = 0.25,
32
+ use_l2_norm: bool = False,
33
+ ):
34
+ super().__init__()
35
+ self.commitment_cost = commitment_cost
36
+
37
+ self.embedding = torch.nn.Embedding(codebook_size, token_size)
38
+ self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
39
+ self.use_l2_norm = use_l2_norm
40
+
41
+ # Ensure quantization is performed using f32
42
+ @autocast(enabled=False)
43
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
44
+ z = z.float()
45
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
46
+ z_flattened = rearrange(z, 'b h w c -> (b h w) c')
47
+
48
+ if self.use_l2_norm:
49
+ z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
50
+ embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
51
+ else:
52
+ embedding = self.embedding.weight
53
+ d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
54
+ torch.sum(embedding**2, dim=1) - 2 * \
55
+ torch.einsum('bd,dn->bn', z_flattened, embedding.T)
56
+
57
+ min_encoding_indices = torch.argmin(d, dim=1) # num_ele
58
+ z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
59
+
60
+ if self.use_l2_norm:
61
+ z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
62
+ z = torch.nn.functional.normalize(z, dim=-1)
63
+
64
+ # compute loss for embedding
65
+ commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2)
66
+ codebook_loss = torch.mean((z_quantized - z.detach()) **2)
67
+
68
+ loss = commitment_loss + codebook_loss
69
+
70
+ # preserve gradients
71
+ z_quantized = z + (z_quantized - z).detach()
72
+
73
+ # reshape back to match original input shape
74
+ z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
75
+
76
+ result_dict = dict(
77
+ quantizer_loss=loss,
78
+ commitment_loss=commitment_loss,
79
+ codebook_loss=codebook_loss,
80
+ min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
81
+ )
82
+
83
+ return z_quantized, result_dict
84
+
85
+ def get_codebook_entry(self, indices):
86
+ if len(indices.shape) == 1:
87
+ z_quantized = self.embedding(indices)
88
+ elif len(indices.shape) == 2:
89
+ z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
90
+ else:
91
+ raise NotImplementedError
92
+ return z_quantized
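
A minimal sketch of this quantizer on random data (not part of the uploaded files); the codebook_size and token_size values are placeholders, since the real values come from the model config:

    import torch
    from modeling.quantizer import VectorQuantizer

    quantizer = VectorQuantizer(codebook_size=4096, token_size=12, use_l2_norm=True)
    z = torch.randn(2, 12, 1, 32)   # (batch, token_size, 1, num_latent_tokens), as produced by TiTokEncoder
    z_quantized, result_dict = quantizer(z)
    print(z_quantized.shape)                            # torch.Size([2, 12, 1, 32])
    print(result_dict["min_encoding_indices"].shape)    # torch.Size([2, 1, 32])
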
modeling/titok.py ADDED
@@ -0,0 +1,97 @@
1
+ """This file contains the model definition of TiTok.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ """
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from einops import rearrange
21
+
22
+ from .blocks import TiTokEncoder, TiTokDecoder
23
+ from .quantizer import VectorQuantizer
24
+ from .maskgit_vqgan import Decoder as Pixel_Decoder
25
+ from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
26
+ from omegaconf import OmegaConf
27
+
28
+ class TiTok(nn.Module):
29
+ def __init__(self, config):
30
+ super().__init__()
31
+ self.config = config
32
+ self.encoder = TiTokEncoder(config)
33
+ self.decoder = TiTokDecoder(config)
34
+
35
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
36
+ scale = self.encoder.width ** -0.5
37
+ self.latent_tokens = nn.Parameter(
38
+ scale * torch.randn(self.num_latent_tokens, self.encoder.width))
39
+
40
+ self.apply(self._init_weights)
41
+
42
+ self.quantize = VectorQuantizer(
43
+ codebook_size=config.model.vq_model.codebook_size,
44
+ token_size=config.model.vq_model.token_size,
45
+ commitment_cost=config.model.vq_model.commitment_cost,
46
+ use_l2_norm=config.model.vq_model.use_l2_norm,)
47
+
48
+ self.pixel_quantize = Pixel_Quantizer(
49
+ num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
50
+ self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
51
+ {"channel_mult": [1, 1, 2, 2, 4],
52
+ "num_resolutions": 5,
53
+ "dropout": 0.0,
54
+ "hidden_channels": 128,
55
+ "num_channels": 3,
56
+ "num_res_blocks": 2,
57
+ "resolution": 256,
58
+ "z_channels": 256}))
59
+
60
+ def _init_weights(self, module):
61
+ """ Initialize the weights.
62
+ :param:
63
+ module -> torch.nn.Module: module to initialize
64
+ """
65
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d):
66
+ module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
67
+ if module.bias is not None:
68
+ module.bias.data.zero_()
69
+ elif isinstance(module, nn.Embedding):
70
+ module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
71
+ elif isinstance(module, nn.LayerNorm):
72
+ module.bias.data.zero_()
73
+ module.weight.data.fill_(1.0)
74
+
75
+ def encode(self, x):
76
+ z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
77
+ z_quantized, result_dict = self.quantize(z)
78
+ return z_quantized, result_dict
79
+
80
+ def decode(self, z_quantized):
81
+ decoded_latent = self.decoder(z_quantized)
82
+ quantized_states = torch.einsum(
83
+ 'nchw,cd->ndhw', decoded_latent.softmax(1),
84
+ self.pixel_quantize.embedding.weight)
85
+ decoded = self.pixel_decoder(quantized_states)
86
+ return decoded
87
+
88
+ def decode_tokens(self, tokens):
89
+ tokens = tokens.squeeze(1)
90
+ batch, seq_len = tokens.shape # B x N
91
+ z_quantized = self.quantize.get_codebook_entry(
92
+ tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
93
+ if self.quantize.use_l2_norm:
94
+ z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
95
+ z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
96
+ decoded = self.decode(z_quantized)
97
+ return decoded
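
An end-to-end usage sketch for TiTok (not part of the uploaded files). The config values below are assumptions chosen to be self-consistent with the code in this file; the released YAML configs are the authoritative source. With randomly initialized weights this only exercises the tensor shapes; pretrained checkpoints are needed for meaningful reconstructions:

    import torch
    from omegaconf import OmegaConf
    from modeling.titok import TiTok

    config = OmegaConf.create({
        "dataset": {"preprocessing": {"crop_size": 256}},
        "model": {"vq_model": {
            "vit_enc_patch_size": 16, "vit_dec_patch_size": 16,
            "vit_enc_model_size": "base", "vit_dec_model_size": "base",
            "num_latent_tokens": 32, "token_size": 12,
            "codebook_size": 4096, "commitment_cost": 0.25, "use_l2_norm": True}}})

    titok = TiTok(config).eval()
    image = torch.rand(1, 3, 256, 256)                  # pixel values in [0, 1]
    with torch.no_grad():
        z_quantized, result_dict = titok.encode(image)
        tokens = result_dict["min_encoding_indices"]    # (1, 1, 32) discrete token ids
        reconstruction = titok.decode_tokens(tokens)    # (1, 3, 256, 256)
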
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch>=2.0.0
2
+ torchvision
3
+ omegaconf
4
+ transformers
5
+ timm
6
+ open_clip_torch
7
+ einops
8
+ scipy
9
+ pillow
10
+ accelerate
11
+ gdown
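
The dependencies listed above can be installed in one step from the repository root, for example:

    pip install -r requirements.txt
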