Spaces:
Runtime error
Runtime error
Rodneyontherock1067
commited on
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- .gitignore +2 -0
- LICENSE.txt +77 -0
- Notice +233 -0
- README.md +172 -12
- README_zh.md +425 -0
- assets/3dvae.png +0 -0
- assets/WECHAT.md +7 -0
- assets/backbone.png +3 -0
- assets/hunyuanvideo.pdf +3 -0
- assets/logo.png +0 -0
- assets/overall.png +3 -0
- assets/text_encoder.png +3 -0
- assets/wechat.jpg +0 -0
- docker/Dockerfile_xDiT +41 -0
- environment.yml +8 -0
- gradio_server.py +376 -0
- hyvideo/__init__.py +0 -0
- hyvideo/config.py +406 -0
- hyvideo/constants.py +90 -0
- hyvideo/diffusion/__init__.py +2 -0
- hyvideo/diffusion/pipelines/__init__.py +1 -0
- hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +1103 -0
- hyvideo/diffusion/schedulers/__init__.py +1 -0
- hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py +257 -0
- hyvideo/inference.py +682 -0
- hyvideo/modules/__init__.py +26 -0
- hyvideo/modules/activation_layers.py +23 -0
- hyvideo/modules/attenion.py +257 -0
- hyvideo/modules/embed_layers.py +157 -0
- hyvideo/modules/mlp_layers.py +118 -0
- hyvideo/modules/models.py +870 -0
- hyvideo/modules/modulate_layers.py +76 -0
- hyvideo/modules/norm_layers.py +77 -0
- hyvideo/modules/posemb_layers.py +310 -0
- hyvideo/modules/token_refiner.py +236 -0
- hyvideo/prompt_rewrite.py +51 -0
- hyvideo/text_encoder/__init__.py +366 -0
- hyvideo/utils/__init__.py +0 -0
- hyvideo/utils/data_utils.py +15 -0
- hyvideo/utils/file_utils.py +70 -0
- hyvideo/utils/helpers.py +40 -0
- hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py +46 -0
- hyvideo/vae/__init__.py +62 -0
- hyvideo/vae/autoencoder_kl_causal_3d.py +603 -0
- hyvideo/vae/unet_causal_3d_blocks.py +764 -0
- hyvideo/vae/vae.py +355 -0
- requirements.txt +21 -0
- requirements_xdit.txt +16 -0
- sample_video.py +74 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/backbone.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/hunyuanvideo.pdf filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/overall.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/text_encoder.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
/ckpts/**/
|
LICENSE.txt
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
2 |
+
Tencent HunyuanVideo Release Date: December 3, 2024
|
3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
5 |
+
1. DEFINITIONS.
|
6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
14 |
+
i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
|
15 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo released at [https://github.com/Tencent/HunyuanVideo].
|
16 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
17 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
18 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
19 |
+
n. “including” shall mean including but not limited to.
|
20 |
+
2. GRANT OF RIGHTS.
|
21 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
22 |
+
3. DISTRIBUTION.
|
23 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
24 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
25 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
26 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
27 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
28 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
29 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
30 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
31 |
+
5. RULES OF USE.
|
32 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
33 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
34 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
35 |
+
6. INTELLECTUAL PROPERTY.
|
36 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
37 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
38 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
39 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
40 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
41 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
42 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
43 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
44 |
+
8. SURVIVAL AND TERMINATION.
|
45 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
46 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
47 |
+
9. GOVERNING LAW AND JURISDICTION.
|
48 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
49 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
50 |
+
|
51 |
+
EXHIBIT A
|
52 |
+
ACCEPTABLE USE POLICY
|
53 |
+
|
54 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
55 |
+
Last modified: November 5, 2024
|
56 |
+
|
57 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
58 |
+
1. Outside the Territory;
|
59 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
60 |
+
3. To harm Yourself or others;
|
61 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
62 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
63 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
64 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
65 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
66 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
67 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
68 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
69 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
70 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
71 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
72 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
73 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
74 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
75 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
76 |
+
19. For military purposes;
|
77 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
Notice
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Usage and Legal Notices:
|
2 |
+
|
3 |
+
Tencent is pleased to support the open source community by making Tencent HunyuanVideo available.
|
4 |
+
|
5 |
+
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
|
6 |
+
|
7 |
+
Tencent HunyuanVideo is licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT except for the third-party components listed below. Tencent HunyuanVideo does not impose any additional limitations beyond what is outlined in the repsective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
8 |
+
|
9 |
+
For avoidance of doubts, Tencent HunyuanVideo means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing may be made publicly available by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
10 |
+
|
11 |
+
|
12 |
+
Other dependencies and licenses:
|
13 |
+
|
14 |
+
|
15 |
+
Open Source Model Licensed under the Apache License Version 2.0:
|
16 |
+
The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
17 |
+
--------------------------------------------------------------------
|
18 |
+
1. diffusers
|
19 |
+
Copyright (c) diffusers original author and authors
|
20 |
+
Please note this software has been modified by Tencent in this distribution.
|
21 |
+
|
22 |
+
2. transformers
|
23 |
+
Copyright (c) transformers original author and authors
|
24 |
+
|
25 |
+
3. safetensors
|
26 |
+
Copyright (c) safetensors original author and authors
|
27 |
+
|
28 |
+
4. flux
|
29 |
+
Copyright (c) flux original author and authors
|
30 |
+
|
31 |
+
|
32 |
+
Terms of the Apache License Version 2.0:
|
33 |
+
--------------------------------------------------------------------
|
34 |
+
Apache License
|
35 |
+
|
36 |
+
Version 2.0, January 2004
|
37 |
+
|
38 |
+
http://www.apache.org/licenses/
|
39 |
+
|
40 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
41 |
+
1. Definitions.
|
42 |
+
|
43 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
44 |
+
|
45 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
46 |
+
|
47 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
48 |
+
|
49 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
50 |
+
|
51 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
52 |
+
|
53 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
54 |
+
|
55 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
56 |
+
|
57 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
58 |
+
|
59 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
60 |
+
|
61 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
62 |
+
|
63 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
64 |
+
|
65 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
66 |
+
|
67 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
68 |
+
|
69 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
70 |
+
|
71 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
72 |
+
|
73 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
74 |
+
|
75 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
76 |
+
|
77 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
78 |
+
|
79 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
80 |
+
|
81 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
82 |
+
|
83 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
84 |
+
|
85 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
86 |
+
|
87 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
88 |
+
|
89 |
+
END OF TERMS AND CONDITIONS
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
Open Source Software Licensed under the BSD 2-Clause License:
|
94 |
+
--------------------------------------------------------------------
|
95 |
+
1. imageio
|
96 |
+
Copyright (c) 2014-2022, imageio developers
|
97 |
+
All rights reserved.
|
98 |
+
|
99 |
+
|
100 |
+
Terms of the BSD 2-Clause License:
|
101 |
+
--------------------------------------------------------------------
|
102 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
103 |
+
|
104 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
105 |
+
|
106 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
107 |
+
|
108 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
Open Source Software Licensed under the BSD 3-Clause License:
|
113 |
+
--------------------------------------------------------------------
|
114 |
+
1. torchvision
|
115 |
+
Copyright (c) Soumith Chintala 2016,
|
116 |
+
All rights reserved.
|
117 |
+
|
118 |
+
2. flash-attn
|
119 |
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
120 |
+
All rights reserved.
|
121 |
+
|
122 |
+
|
123 |
+
Terms of the BSD 3-Clause License:
|
124 |
+
--------------------------------------------------------------------
|
125 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
126 |
+
|
127 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
128 |
+
|
129 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
130 |
+
|
131 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
132 |
+
|
133 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
138 |
+
--------------------------------------------------------------------
|
139 |
+
1. torch
|
140 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
141 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
142 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
143 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
144 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
145 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
146 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
147 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
148 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
149 |
+
|
150 |
+
|
151 |
+
A copy of the BSD 3-Clause is included in this file.
|
152 |
+
|
153 |
+
For the license of other third party components, please refer to the following URL:
|
154 |
+
https://github.com/pytorch/pytorch/tree/v2.1.1/third_party
|
155 |
+
|
156 |
+
|
157 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
158 |
+
--------------------------------------------------------------------
|
159 |
+
1. pandas
|
160 |
+
Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team
|
161 |
+
All rights reserved.
|
162 |
+
|
163 |
+
Copyright (c) 2011-2023, Open source contributors.
|
164 |
+
|
165 |
+
|
166 |
+
A copy of the BSD 3-Clause is included in this file.
|
167 |
+
|
168 |
+
For the license of other third party components, please refer to the following URL:
|
169 |
+
https://github.com/pandas-dev/pandas/tree/v2.0.3/LICENSES
|
170 |
+
|
171 |
+
|
172 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
173 |
+
--------------------------------------------------------------------
|
174 |
+
1. numpy
|
175 |
+
Copyright (c) 2005-2022, NumPy Developers.
|
176 |
+
All rights reserved.
|
177 |
+
|
178 |
+
|
179 |
+
A copy of the BSD 3-Clause is included in this file.
|
180 |
+
|
181 |
+
For the license of other third party components, please refer to the following URL:
|
182 |
+
https://github.com/numpy/numpy/blob/v1.24.4/LICENSES_bundled.txt
|
183 |
+
|
184 |
+
|
185 |
+
Open Source Software Licensed under the MIT License:
|
186 |
+
--------------------------------------------------------------------
|
187 |
+
1. einops
|
188 |
+
Copyright (c) 2018 Alex Rogozhnikov
|
189 |
+
|
190 |
+
2. loguru
|
191 |
+
Copyright (c) 2017
|
192 |
+
|
193 |
+
|
194 |
+
Terms of the MIT License:
|
195 |
+
--------------------------------------------------------------------
|
196 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
197 |
+
|
198 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
199 |
+
|
200 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
|
205 |
+
--------------------------------------------------------------------
|
206 |
+
1. tqdm
|
207 |
+
Copyright (c) 2013 noamraph
|
208 |
+
|
209 |
+
|
210 |
+
A copy of the MIT is included in this file.
|
211 |
+
|
212 |
+
For the license of other third party components, please refer to the following URL:
|
213 |
+
https://github.com/tqdm/tqdm/blob/v4.66.2/LICENCE
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
Open Source Model Licensed under the MIT License:
|
218 |
+
--------------------------------------------------------------------
|
219 |
+
1. clip-large
|
220 |
+
Copyright (c) 2021 OpenAI
|
221 |
+
|
222 |
+
|
223 |
+
A copy of the MIT is included in this file.
|
224 |
+
|
225 |
+
|
226 |
+
--------------------------------------------------------------------
|
227 |
+
We may also use other third-party components:
|
228 |
+
|
229 |
+
1. llava-llama3
|
230 |
+
|
231 |
+
Copyright (c) llava-llama3 original author and authors
|
232 |
+
|
233 |
+
URL: https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers#model
|
README.md
CHANGED
@@ -1,12 +1,172 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- ## **HunyuanVideo** -->
|
2 |
+
|
3 |
+
[中文阅读](./README_zh.md)
|
4 |
+
|
5 |
+
|
6 |
+
# HunyuanVideo: A Systematic Framework For Large Video Generation Model
|
7 |
+
<div align="center">
|
8 |
+
<a href="https://github.com/Tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo Code&message=Github&color=blue&logo=github-pages"></a>  
|
9 |
+
<a href="https://aivideo.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Web&color=green&logo=github-pages"></a>  
|
10 |
+
<a href="https://video.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Playground&message=Web&color=green&logo=github-pages"></a>  
|
11 |
+
<a href="https://arxiv.org/abs/2412.03603"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv:HunyuanVideo&color=red&logo=arxiv"></a>  
|
12 |
+
<a href="https://huggingface.co/tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo&message=HuggingFace&color=yellow"></a>    
|
13 |
+
<a href="https://huggingface.co/tencent/HunyuanVideo-PromptRewrite"><img src="https://img.shields.io/static/v1?label=HunyuanVideo-PromptRewrite&message=HuggingFace&color=yellow"></a>    
|
14 |
+
|
15 |
+
[![Replicate](https://replicate.com/zsxkib/hunyuan-video/badge)](https://replicate.com/zsxkib/hunyuan-video)
|
16 |
+
</div>
|
17 |
+
|
18 |
+
06/04/2025: Version 2.1 Integrated Tea Cache (https://github.com/ali-vilab/TeaCache) for even faster generations\
|
19 |
+
01/04/2025: Version 2.0 Full leverage of mmgp 3.0 (faster and even lower RAM requirements ! + support for compilation on Linux and WSL)\
|
20 |
+
22/12/2024: Version 1.0 First release\
|
21 |
+
|
22 |
+
*GPU Poor version by **DeepBeepMeep**. This great video generator can now run smoothly on a 12 GB to 24 GB GPU.*
|
23 |
+
|
24 |
+
This version has the following improvements over the original Hunyuan Video model:
|
25 |
+
- Reduce greatly the RAM requirements and VRAM requirements
|
26 |
+
- Much faster thanks to compilation and fast loading / unloading
|
27 |
+
- 5 profiles in order to able to run the model at a decent speed on a low end consumer config (32 GB of RAM and 12 VRAM) and to run it at a very good speed on a high end consumer config (48 GB of RAM and 24 GB of VRAM)
|
28 |
+
- Autodownloading of the needed model files
|
29 |
+
- Improved gradio interface with progression bar and more options
|
30 |
+
- Switch easily between Hunyuan and Fast Hunyuan models and quantized / non quantized models
|
31 |
+
- Much simpler installation
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
This fork by DeepBeepMeep is an integration of the mmpg module on the gradio_server.py.
|
36 |
+
|
37 |
+
It is an illustration on how one can set up on an existing model some fast and properly working CPU offloading with changing only a few lines of code in the core model.
|
38 |
+
|
39 |
+
For more information on how to use the mmpg module, please go to: https://github.com/deepbeepmeep/mmgp
|
40 |
+
|
41 |
+
You will find the original Hunyuan Video repository here: https://github.com/Tencent/HunyuanVideo
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
## **Abstract**
|
46 |
+
We present HunyuanVideo, a novel open-source video foundation model that exhibits performance in video generation that is comparable to, if not superior to, leading closed-source models. In order to train HunyuanVideo model, we adopt several key technologies for model learning, including data curation, image-video joint model training, and an efficient infrastructure designed to facilitate large-scale model training and inference. Additionally, through an effective strategy for scaling model architecture and dataset, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models.
|
47 |
+
|
48 |
+
We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion diversity, text-video alignment, and generation stability. According to professional human evaluation results, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and 3 top-performing Chinese video generative models. By releasing the code and weights of the foundation model and its applications, we aim to bridge the gap between closed-source and open-source video foundation models. This initiative will empower everyone in the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem.
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
### Installation Guide for Linux and Windows
|
53 |
+
|
54 |
+
We provide an `environment.yml` file for setting up a Conda environment.
|
55 |
+
Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
|
56 |
+
|
57 |
+
We recommend CUDA versions 12.4 or 11.8 for the manual installation.
|
58 |
+
|
59 |
+
```shell
|
60 |
+
# 1. Prepare conda environment
|
61 |
+
conda env create -f environment.yml
|
62 |
+
|
63 |
+
# 2. Activate the environment
|
64 |
+
conda activate HunyuanVideo
|
65 |
+
|
66 |
+
# 3. Install pip dependencies
|
67 |
+
python -m pip install -r requirements.txt
|
68 |
+
|
69 |
+
|
70 |
+
# 4.1 optional Flash attention support (easy to install on Linux but much harder on Windows)
|
71 |
+
python -m pip install flash-attn==2.7.2.post1
|
72 |
+
|
73 |
+
# 4.2 optional Sage attention support (30% faster, easy to install on Linux but much harder on Windows)
|
74 |
+
python -m pip install sageattention==1.0.6
|
75 |
+
|
76 |
+
```
|
77 |
+
|
78 |
+
### Profiles
|
79 |
+
You can choose between 5 profiles depending on your hardware:
|
80 |
+
- HighRAM_HighVRAM (1): at least 48 GB of RAM and 24 GB of VRAM : the fastest well suited for a RTX 3090 / RTX 4090 but consumes much more VRAM, adapted for fast shorter video
|
81 |
+
- HighRAM_LowVRAM (2): at least 48 GB of RAM and 12 GB of VRAM : a bit slower, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos
|
82 |
+
- LowRAM_HighVRAM (3): at least 32 GB of RAM and 24 GB of VRAM : adapted for RTX 3090 / RTX 4090 with limited RAM but at the cost of VRAM (shorter videos)
|
83 |
+
- LowRAM_LowVRAM (4): at least 32 GB of RAM and 12 GB of VRAM : if you have little VRAM or want to generate longer videos
|
84 |
+
- VerylowRAM_LowVRAM (5): at least 24 GB of RAM and 10 GB of VRAM : if you don't have much it won't be fast but maybe it will work
|
85 |
+
|
86 |
+
Profile 2 (High RAM) and 4 (Low RAM)are the most recommended profiles since they are versatile (support for long videos for a slight performance cost).\
|
87 |
+
However, a safe approach is to start from profile 5 (default profile) and then go down progressively to profile 4 and then to profile 2 as long as the app remains responsive or doesn't trigger any out of memory error.
|
88 |
+
|
89 |
+
|
90 |
+
### Run a Gradio Server on port 7860 (recommended)
|
91 |
+
```bash
|
92 |
+
python3 gradio_server.py
|
93 |
+
```
|
94 |
+
|
95 |
+
You will have the possibility to configure a RAM / VRAM profile by expanding the section *Video Engine Configuration* in the Web Interface.\
|
96 |
+
If by mistake you have chosen a configuration not supported by your system, you can force a profile while loading the app with the safe profile 5:
|
97 |
+
```bash
|
98 |
+
python3 gradio_server.py --profile 5
|
99 |
+
```
|
100 |
+
|
101 |
+
|
102 |
+
### Run through the command line
|
103 |
+
```bash
|
104 |
+
cd HunyuanVideo
|
105 |
+
|
106 |
+
python3 sample_video.py \
|
107 |
+
--video-size 720 1280 \
|
108 |
+
--video-length 129 \
|
109 |
+
--infer-steps 50 \
|
110 |
+
--prompt "A cat walks on the grass, realistic style." \
|
111 |
+
--flow-reverse \
|
112 |
+
--save-path ./results
|
113 |
+
```
|
114 |
+
|
115 |
+
Please note currently that profile and the models used need to be mentioned inside the *sample_video.py* file.
|
116 |
+
|
117 |
+
### More Configurations
|
118 |
+
|
119 |
+
We list some more useful configurations for easy usage:
|
120 |
+
|
121 |
+
| Argument | Default | Description |
|
122 |
+
|:----------------------:|:---------:|:-----------------------------------------:|
|
123 |
+
| `--prompt` | None | The text prompt for video generation |
|
124 |
+
| `--video-size` | 720 1280 | The size of the generated video |
|
125 |
+
| `--video-length` | 129 | The length of the generated video |
|
126 |
+
| `--infer-steps` | 50 | The number of steps for sampling |
|
127 |
+
| `--embedded-cfg-scale` | 6.0 | Embeded Classifier free guidance scale |
|
128 |
+
| `--flow-shift` | 7.0 | Shift factor for flow matching schedulers |
|
129 |
+
| `--flow-reverse` | False | If reverse, learning/sampling from t=1 -> t=0 |
|
130 |
+
| `--seed` | None | The random seed for generating video, if None, we init a random seed |
|
131 |
+
| `--use-cpu-offload` | False | Use CPU offload for the model load to save more memory, necessary for high-res video generation |
|
132 |
+
| `--save-path` | ./results | Path to save the generated video |
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
## 🔗 BibTeX
|
137 |
+
If you find [HunyuanVideo](https://arxiv.org/abs/2412.03603) useful for your research and applications, please cite using this BibTeX:
|
138 |
+
|
139 |
+
```BibTeX
|
140 |
+
@misc{kong2024hunyuanvideo,
|
141 |
+
title={HunyuanVideo: A Systematic Framework For Large Video Generative Models},
|
142 |
+
author={Weijie Kong, Qi Tian, Zijian Zhang, Rox Min, Zuozhuo Dai, Jin Zhou, Jiangfeng Xiong, Xin Li, Bo Wu, Jianwei Zhang, Kathrina Wu, Qin Lin, Aladdin Wang, Andong Wang, Changlin Li, Duojun Huang, Fang Yang, Hao Tan, Hongmei Wang, Jacob Song, Jiawang Bai, Jianbing Wu, Jinbao Xue, Joey Wang, Junkun Yuan, Kai Wang, Mengyang Liu, Pengyu Li, Shuai Li, Weiyan Wang, Wenqing Yu, Xinchi Deng, Yang Li, Yanxin Long, Yi Chen, Yutao Cui, Yuanbo Peng, Zhentao Yu, Zhiyu He, Zhiyong Xu, Zixiang Zhou, Zunnan Xu, Yangyu Tao, Qinglin Lu, Songtao Liu, Dax Zhou, Hongfa Wang, Yong Yang, Di Wang, Yuhong Liu, and Jie Jiang, along with Caesar Zhong},
|
143 |
+
year={2024},
|
144 |
+
archivePrefix={arXiv preprint arXiv:2412.03603},
|
145 |
+
primaryClass={cs.CV},
|
146 |
+
url={https://arxiv.org/abs/2412.03603},
|
147 |
+
}
|
148 |
+
```
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
## 🧩 Projects that use HunyuanVideo
|
153 |
+
|
154 |
+
If you develop/use HunyuanVideo in your projects, welcome to let us know.
|
155 |
+
|
156 |
+
- ComfyUI (with support for F8 Inference and Video2Video Generation): [ComfyUI-HunyuanVideoWrapper](https://github.com/kijai/ComfyUI-HunyuanVideoWrapper) by [Kijai](https://github.com/kijai)
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
## Acknowledgements
|
161 |
+
|
162 |
+
We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research and exploration.
|
163 |
+
Additionally, we also thank the Tencent Hunyuan Multimodal team for their help with the text encoder.
|
164 |
+
|
165 |
+
## Star History
|
166 |
+
<a href="https://star-history.com/#Tencent/HunyuanVideo&Date">
|
167 |
+
<picture>
|
168 |
+
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date&theme=dark" />
|
169 |
+
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
|
170 |
+
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
|
171 |
+
</picture>
|
172 |
+
</a>
|
README_zh.md
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- ## **HunyuanVideo** -->
|
2 |
+
|
3 |
+
[English](./README.md)
|
4 |
+
|
5 |
+
<p align="center">
|
6 |
+
<img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/logo.png" height=100>
|
7 |
+
</p>
|
8 |
+
|
9 |
+
# HunyuanVideo: A Systematic Framework For Large Video Generation Model
|
10 |
+
|
11 |
+
<div align="center">
|
12 |
+
<a href="https://github.com/Tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo Code&message=Github&color=blue&logo=github-pages"></a>  
|
13 |
+
<a href="https://aivideo.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Web&color=green&logo=github-pages"></a>  
|
14 |
+
<a href="https://video.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Playground&message=Web&color=green&logo=github-pages"></a>  
|
15 |
+
<a href="https://arxiv.org/abs/2412.03603"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv:HunyuanVideo&color=red&logo=arxiv"></a>  
|
16 |
+
<a href="https://huggingface.co/tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo&message=HuggingFace&color=yellow"></a>    
|
17 |
+
<a href="https://huggingface.co/tencent/HunyuanVideo-PromptRewrite"><img src="https://img.shields.io/static/v1?label=HunyuanVideo-PromptRewrite&message=HuggingFace&color=yellow"></a>    
|
18 |
+
|
19 |
+
[![Replicate](https://replicate.com/zsxkib/hunyuan-video/badge)](https://replicate.com/zsxkib/hunyuan-video)
|
20 |
+
</div>
|
21 |
+
<p align="center">
|
22 |
+
👋 加入我们的 <a href="assets/WECHAT.md" target="_blank">WeChat</a> 和 <a href="https://discord.gg/GpARqvrh" target="_blank">Discord</a>
|
23 |
+
</p>
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
-----
|
28 |
+
|
29 |
+
本仓库包含了 HunyuanVideo 项目的 PyTorch 模型定义、预训练权重和推理/采样代码。参考我们的项目页面 [project page](https://aivideo.hunyuan.tencent.com) 查看更多内容。
|
30 |
+
|
31 |
+
> [**HunyuanVideo: A Systematic Framework For Large Video Generation Model**](https://arxiv.org/abs/2412.03603) <br>
|
32 |
+
|
33 |
+
## 🎥 作品展示
|
34 |
+
<div align="center">
|
35 |
+
<video src="https://github.com/user-attachments/assets/f37925a3-7d42-40c9-8a9b-5a010c7198e2" width="50%">
|
36 |
+
</div>
|
37 |
+
|
38 |
+
注:由于 GitHub 的政策限制,上面的视频质量被大幅压缩。你可以从 [这里](https://aivideo.hunyuan.tencent.com/download/HunyuanVideo/material) 下载高质量版本。
|
39 |
+
|
40 |
+
## 🔥🔥🔥 更新!!
|
41 |
+
* 2024年12月03日: 🚀 开源 HunyuanVideo 多卡并行推理代码,由[xDiT](https://github.com/xdit-project/xDiT)提供。
|
42 |
+
* 2024年12月03日: 🤗 开源 HunyuanVideo 文生视频的推理代码和模型权重。
|
43 |
+
|
44 |
+
## 📑 开源计划
|
45 |
+
|
46 |
+
- HunyuanVideo (文生视频模型)
|
47 |
+
- [x] 推理代码
|
48 |
+
- [x] 模型权重
|
49 |
+
- [x] 多GPU序列并行推理(GPU 越多,推理速度越快)
|
50 |
+
- [x] Web Demo (Gradio)
|
51 |
+
- [ ] Penguin Video 基准测试集
|
52 |
+
- [ ] ComfyUI
|
53 |
+
- [ ] Diffusers
|
54 |
+
- [ ] 多GPU PipeFusion并行推理 (更低显存需求)
|
55 |
+
- HunyuanVideo (图生视频模型)
|
56 |
+
- [ ] 推理代码
|
57 |
+
- [ ] 模型权重
|
58 |
+
|
59 |
+
## 目录
|
60 |
+
- [HunyuanVideo: A Systematic Framework For Large Video Generation Model](#hunyuanvideo-a-systematic-framework-for-large-video-generation-model)
|
61 |
+
- [🎥 作品展示](#-作品展示)
|
62 |
+
- [🔥🔥🔥 更新!!](#-更新)
|
63 |
+
- [📑 开源计划](#-开源计划)
|
64 |
+
- [目录](#目录)
|
65 |
+
- [**摘要**](#摘要)
|
66 |
+
- [**HunyuanVideo 的架构**](#hunyuanvideo-的架构)
|
67 |
+
- [🎉 **亮点**](#-亮点)
|
68 |
+
- [**统一的图视频生成架构**](#统一的图视频生成架构)
|
69 |
+
- [**MLLM 文本编码器**](#mllm-文本编码器)
|
70 |
+
- [**3D VAE**](#3d-vae)
|
71 |
+
- [**Prompt 改写**](#prompt-改写)
|
72 |
+
- [📈 能力评估](#-能力评估)
|
73 |
+
- [📜 运行配置](#-运行配置)
|
74 |
+
- [🛠️ 安装和依赖](#️-安装和依赖)
|
75 |
+
- [Linux 安装指引](#linux-安装指引)
|
76 |
+
- [🧱 下载预训练模型](#-下载预训练模型)
|
77 |
+
- [🔑 推理](#-推理)
|
78 |
+
- [使用命令行](#使用命令行)
|
79 |
+
- [运行gradio服务](#运行gradio服务)
|
80 |
+
- [更多配置](#更多配置)
|
81 |
+
- [🚀 使用 xDiT 实现多卡并行推理](#-使用-xdit-实现多卡并行推理)
|
82 |
+
- [安装与 xDiT 兼容的依赖项](#安装与-xdit-兼容的依赖项)
|
83 |
+
- [使用命令行](#使用命令行-1)
|
84 |
+
- [🔗 BibTeX](#-bibtex)
|
85 |
+
- [🧩 使用 HunyuanVideo 的项目](#-使用-hunyuanvideo-的项目)
|
86 |
+
- [致谢](#致谢)
|
87 |
+
- [Star 趋势](#star-趋势)
|
88 |
+
---
|
89 |
+
|
90 |
+
## **摘要**
|
91 |
+
HunyuanVideo 是一个全新的开源视频生成大模型,具有与领先的闭源模型相媲美甚至更优的视频生成表现。为了训练 HunyuanVideo,我们采用了一个全面的框架,集成了数据整理、图像-视频联合模型训练和高效的基础设施以支持大规模模型训练和推理。此外,通过有效的模型架构和数据集扩展策略,我们成功地训练了一个拥有超过 130 亿参数的视频生成模型,使其成为最大的开源视频生成模型之一。
|
92 |
+
|
93 |
+
我们在模型结构的设计上做了大量的实验以确保其能拥有高质量的视觉效果、多样的运动、文本-视频对齐和生成稳定性。根据专业人员的评估结果,HunyuanVideo 在综合指标上优于以往的最先进模型,包括 Runway Gen-3、Luma 1.6 和 3 个中文社区表现最好的视频生成模型。**通过开源基础模型和应用模型的代码和权重,我们旨在弥合闭源和开源视频基础模型之间的差距,帮助社区中的每个人都能够尝试自己的想法,促进更加动态和活跃的视频生成生态。**
|
94 |
+
|
95 |
+
|
96 |
+
## **HunyuanVideo 的架构**
|
97 |
+
|
98 |
+
HunyuanVideo 是一个隐空间模型,训练时它采用了 3D VAE 压缩时间维度和空间维度的特征。文本提示通过一个大语言模型编码后作为条件输入模型,引导模型通过对高斯噪声的多步去噪,输出一个视频的隐空间表示。最后,推理时通过 3D VAE 解码器将隐空间表示解码为视频。
|
99 |
+
<p align="center">
|
100 |
+
<img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/overall.png" height=300>
|
101 |
+
</p>
|
102 |
+
|
103 |
+
## 🎉 **亮点**
|
104 |
+
### **统一的图视频生成架构**
|
105 |
+
|
106 |
+
HunyuanVideo 采用了 Transformer 和 Full Attention 的设计用于视频生成。具体来说,我们使用了一个“双流到单流”的混合模型设计用于视频生成。在双流阶段,视频和文本 token 通过并行的 Transformer Block 独立处理,使得每个模态可以学习适合自己的调制机制而不会相互干扰。在单流阶段,我们将视频和文本 token 连接起来并将它们输入到后续的 Transformer Block 中进行有效的多模态信息融合。这种设计捕捉了视觉和语义信息之间的复杂交互,增强了整体模型性能。
|
107 |
+
<p align="center">
|
108 |
+
<img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/backbone.png" height=350>
|
109 |
+
</p>
|
110 |
+
|
111 |
+
### **MLLM 文本编码器**
|
112 |
+
过去的视频生成模型通常使用预训练的 CLIP 和 T5-XXL 作为文本编码器,其中 CLIP 使用 Transformer Encoder,T5 使用 Encoder-Decoder 结构。HunyuanVideo 使用了一个预训练的 Multimodal Large Language Model (MLLM) 作为文本编码器,它具有以下优势:
|
113 |
+
* 与 T5 相比,MLLM 基于图文数据指令微调后在特征空间中具有更好的图像-文本对齐能力,这减轻了扩散模型中的图文对齐的难度;
|
114 |
+
* 与 CLIP 相比,MLLM 在图像的细节描述和复杂推理方面表现出更强的能力;
|
115 |
+
* MLLM 可以通过遵循系统指令实现零样本生成,帮助文本特征更多地关注关键信息。
|
116 |
+
|
117 |
+
由于 MLLM 是基于 Causal Attention 的,而 T5-XXL 使用了 Bidirectional Attention 为扩散模型提供更好的文本引导。因此,我们引入了一个额外的 token 优化器来增强文本特征。
|
118 |
+
<p align="center">
|
119 |
+
<img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/text_encoder.png" height=275>
|
120 |
+
</p>
|
121 |
+
|
122 |
+
### **3D VAE**
|
123 |
+
我们的 VAE 采用了 CausalConv3D 作为 HunyuanVideo 的编码器和解码器,用于压缩视频的时间维度和空间维度,其中时间维度压缩 4 倍,空间维度压缩 8 倍,压缩为 16 channels。这样可以显著减少后续 Transformer 模型的 token 数量,使我们能够在原始分辨率和帧率下训练视频生成模型。
|
124 |
+
<p align="center">
|
125 |
+
<img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/3dvae.png" height=150>
|
126 |
+
</p>
|
127 |
+
|
128 |
+
### **Prompt 改写**
|
129 |
+
为了解决用户输入文本提示的多样性和不一致性的困难,我们微调了 [Hunyuan-Large model](https://github.com/Tencent/Tencent-Hunyuan-Large) 模型作为我们的 prompt 改写模型,将用户输入的提示词改写为更适合模型偏好的写法。
|
130 |
+
|
131 |
+
我们提供了两个改写模式:正常模式和导演模式。两种模式的提示词见[这里](hyvideo/prompt_rewrite.py)。正常模式旨在增强视频生成模型对用户意图的理解,从而更准确地解释提供的指令。导演模式增强了诸如构图、光照和摄像机移动等方面的描述,倾向于生成视觉质量更高的视频。注意,这种增强有时可能会导致一些语义细节的丢失。
|
132 |
+
|
133 |
+
Prompt 改写模型可以直接使用 [Hunyuan-Large](https://github.com/Tencent/Tencent-Hunyuan-Large) 部署和推理. 我们开源了 prompt 改写模型的权重,见[这里](https://huggingface.co/Tencent/HunyuanVideo-PromptRewrite).
|
134 |
+
|
135 |
+
## 📈 能力评估
|
136 |
+
|
137 |
+
为了评估 HunyuanVideo 的能力,我们选择了四个闭源视频生成模型作为对比。我们总共使用了 1,533 个 prompt,每个 prompt 通过一次推理生成了相同数量的视频样本。为了公平比较,我们只进行了一次推理以避免任何挑选。在与其他方法比较时,我们保持了所有选择模型的默认设置,并确保了视频分辨率的一致性。视频根据三个标准进行评估:文本对齐、运动质量和视觉质量。在 60 多名专业评估人员评估后,HunyuanVideo 在综合指标上表现最好,特别是在运动质量方面表现较为突出。
|
138 |
+
|
139 |
+
<p align="center">
|
140 |
+
<table>
|
141 |
+
<thead>
|
142 |
+
<tr>
|
143 |
+
<th rowspan="2">模型</th> <th rowspan="2">是否开源</th> <th>时长</th> <th>文本对齐</th> <th>运动质量</th> <th rowspan="2">视觉质量</th> <th rowspan="2">综合评价</th> <th rowspan="2">排序</th>
|
144 |
+
</tr>
|
145 |
+
</thead>
|
146 |
+
<tbody>
|
147 |
+
<tr>
|
148 |
+
<td>HunyuanVideo (Ours)</td> <td> ✔ </td> <td>5s</td> <td>61.8%</td> <td>66.5%</td> <td>95.7%</td> <td>41.3%</td> <td>1</td>
|
149 |
+
</tr>
|
150 |
+
<tr>
|
151 |
+
<td>国内模型 A (API)</td> <td> ✘ </td> <td>5s</td> <td>62.6%</td> <td>61.7%</td> <td>95.6%</td> <td>37.7%</td> <td>2</td>
|
152 |
+
</tr>
|
153 |
+
<tr>
|
154 |
+
<td>国内模型 B (Web)</td> <td> ✘</td> <td>5s</td> <td>60.1%</td> <td>62.9%</td> <td>97.7%</td> <td>37.5%</td> <td>3</td>
|
155 |
+
</tr>
|
156 |
+
<tr>
|
157 |
+
<td>GEN-3 alpha (Web)</td> <td>✘</td> <td>6s</td> <td>47.7%</td> <td>54.7%</td> <td>97.5%</td> <td>27.4%</td> <td>4</td>
|
158 |
+
</tr>
|
159 |
+
<tr>
|
160 |
+
<td>Luma1.6 (API)</td><td>✘</td> <td>5s</td> <td>57.6%</td> <td>44.2%</td> <td>94.1%</td> <td>24.8%</td> <td>5</td>
|
161 |
+
</tr>
|
162 |
+
</tbody>
|
163 |
+
</table>
|
164 |
+
</p>
|
165 |
+
|
166 |
+
## 📜 运行配置
|
167 |
+
|
168 |
+
下表列出了运行 HunyuanVideo 模型使用文本生成视频的推荐配置(batch size = 1):
|
169 |
+
|
170 |
+
| 模型 | 分辨率<br/>(height/width/frame) | 峰值显存 |
|
171 |
+
|:--------------:|:--------------------------------:|:----------------:|
|
172 |
+
| HunyuanVideo | 720px1280px129f | 60G |
|
173 |
+
| HunyuanVideo | 544px960px129f | 45G |
|
174 |
+
|
175 |
+
* 本项目适用于使用 NVIDIA GPU 和支持 CUDA 的设备
|
176 |
+
* 模型在单张 80G GPU 上测试
|
177 |
+
* 运行 720px1280px129f 的最小显存要求是 60GB,544px960px129f 的最小显存要求是 45GB。
|
178 |
+
* 测试操作系统:Linux
|
179 |
+
|
180 |
+
## 🛠️ 安装和依赖
|
181 |
+
|
182 |
+
首先克隆 git 仓库:
|
183 |
+
```shell
|
184 |
+
git clone https://github.com/tencent/HunyuanVideo
|
185 |
+
cd HunyuanVideo
|
186 |
+
```
|
187 |
+
|
188 |
+
### Linux 安装指引
|
189 |
+
|
190 |
+
我们提供了 `environment.yml` 文件来设置 Conda 环境。Conda 的安装指南可以参考[这里](https://docs.anaconda.com/free/miniconda/index.html)。
|
191 |
+
|
192 |
+
我们推理使用 CUDA 12.4 或 11.8 的版本。
|
193 |
+
|
194 |
+
```shell
|
195 |
+
# 1. Prepare conda environment
|
196 |
+
conda env create -f environment.yml
|
197 |
+
|
198 |
+
# 2. Activate the environment
|
199 |
+
conda activate HunyuanVideo
|
200 |
+
|
201 |
+
# 3. Install pip dependencies
|
202 |
+
python -m pip install -r requirements.txt
|
203 |
+
|
204 |
+
# 4. Install flash attention v2 for acceleration (requires CUDA 11.8 or above)
|
205 |
+
python -m pip install ninja
|
206 |
+
python -m pip install git+https://github.com/Dao-AILab/[email protected]
|
207 |
+
```
|
208 |
+
|
209 |
+
如果在特定GPU型号上遭遇float point exception(core dump)问题,可尝试以下方案修复:
|
210 |
+
|
211 |
+
```shell
|
212 |
+
#选项1:确保已正确安装CUDA 12.4, CUBLAS>=12.4.5.8, and CUDNN>=9.00(或直接使用我们提供的CUDA12镜像)
|
213 |
+
pip install nvidia-cublas-cu12==12.4.5.8
|
214 |
+
export LD_LIBRARY_PATH=/opt/conda/lib/python3.8/site-packages/nvidia/cublas/lib/
|
215 |
+
|
216 |
+
#选项2:强制显式使用CUDA11.8编译的Pytorch版本以及其他所有软件包
|
217 |
+
pip uninstall -r requirements.txt # 确保卸载所有依赖包
|
218 |
+
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu118
|
219 |
+
pip install -r requirements.txt
|
220 |
+
python -m pip install git+https://github.com/Dao-AILab/[email protected]
|
221 |
+
```
|
222 |
+
|
223 |
+
另外,我们提供了一个预构建的 Docker 镜像,可以使用如下命令进行拉取和运行。
|
224 |
+
```shell
|
225 |
+
# 用于CUDA 12.4 (已更新避免float point exception)
|
226 |
+
docker pull hunyuanvideo/hunyuanvideo:cuda_12
|
227 |
+
docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
|
228 |
+
|
229 |
+
# 用于CUDA 11.8
|
230 |
+
docker pull hunyuanvideo/hunyuanvideo:cuda_11
|
231 |
+
docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_11
|
232 |
+
```
|
233 |
+
|
234 |
+
## 🧱 下载预训练模型
|
235 |
+
|
236 |
+
下载预训练模型参考[这里](ckpts/README.md)。
|
237 |
+
|
238 |
+
## 🔑 推理
|
239 |
+
我们在下表中列出了支持的高度/宽度/帧数设置。
|
240 |
+
|
241 |
+
| 分辨率 | h/w=9:16 | h/w=16:9 | h/w=4:3 | h/w=3:4 | h/w=1:1 |
|
242 |
+
|:---------------------:|:----------------------------:|:---------------:|:---------------:|:---------------:|:---------------:|
|
243 |
+
| 540p | 544px960px129f | 960px544px129f | 624px832px129f | 832px624px129f | 720px720px129f |
|
244 |
+
| 720p (推荐) | 720px1280px129f | 1280px720px129f | 1104px832px129f | 832px1104px129f | 960px960px129f |
|
245 |
+
|
246 |
+
### 使用命令行
|
247 |
+
|
248 |
+
```bash
|
249 |
+
cd HunyuanVideo
|
250 |
+
|
251 |
+
python3 sample_video.py \
|
252 |
+
--video-size 720 1280 \
|
253 |
+
--video-length 129 \
|
254 |
+
--infer-steps 50 \
|
255 |
+
--prompt "A cat walks on the grass, realistic style." \
|
256 |
+
--flow-reverse \
|
257 |
+
--use-cpu-offload \
|
258 |
+
--save-path ./results
|
259 |
+
```
|
260 |
+
|
261 |
+
### 运行gradio服务
|
262 |
+
```bash
|
263 |
+
python3 gradio_server.py --flow-reverse
|
264 |
+
|
265 |
+
# set SERVER_NAME and SERVER_PORT manually
|
266 |
+
# SERVER_NAME=0.0.0.0 SERVER_PORT=8081 python3 gradio_server.py --flow-reverse
|
267 |
+
```
|
268 |
+
|
269 |
+
### 更多配置
|
270 |
+
|
271 |
+
下面列出了更多关键配置项:
|
272 |
+
|
273 |
+
| 参数 | 默认值 | 描述 |
|
274 |
+
|:----------------------:|:---------:|:-----------------------------------------:|
|
275 |
+
| `--prompt` | None | 用于生成视频的 prompt |
|
276 |
+
| `--video-size` | 720 1280 | 生成视频的高度和宽度 |
|
277 |
+
| `--video-length` | 129 | 生成视频的帧数 |
|
278 |
+
| `--infer-steps` | 50 | 生成时采样的步数 |
|
279 |
+
| `--embedded-cfg-scale` | 6.0 | 文本的控制强度 |
|
280 |
+
| `--flow-shift` | 7.0 | 推理时 timestep 的 shift 系数,值越大,高噪区域采样步数越多 |
|
281 |
+
| `--flow-reverse` | False | If reverse, learning/sampling from t=1 -> t=0 |
|
282 |
+
| `--neg-prompt` | None | 负向词 |
|
283 |
+
| `--seed` | 0 | 随机种子 |
|
284 |
+
| `--use-cpu-offload` | False | 启用 CPU offload,可以节省显存 |
|
285 |
+
| `--save-path` | ./results | 保存路径 |
|
286 |
+
|
287 |
+
|
288 |
+
## 🚀 使用 xDiT 实现多卡并行推理
|
289 |
+
|
290 |
+
[xDiT](https://github.com/xdit-project/xDiT) 是一个针对多 GPU 集群的扩展推理引擎,用于扩展 Transformers(DiTs)。
|
291 |
+
它成功为各种 DiT 模型(包括 mochi-1、CogVideoX、Flux.1、SD3 等)提供了低延迟的并行推理解决方案。该存储库采用了 [Unified Sequence Parallelism (USP)](https://arxiv.org/abs/2405.07719) API 用于混元视频模型的并行推理。
|
292 |
+
|
293 |
+
### 安装与 xDiT 兼容的依赖项
|
294 |
+
|
295 |
+
```
|
296 |
+
# 1. 创建一个空白的 conda 环境
|
297 |
+
conda create -n hunyuanxdit python==3.10.9
|
298 |
+
conda activate hunyuanxdit
|
299 |
+
|
300 |
+
# 2. 使用 CUDA 11.8 安装 PyTorch 组件
|
301 |
+
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia
|
302 |
+
|
303 |
+
# 3. 安装 pip 依赖项
|
304 |
+
python -m pip install -r requirements_xdit.txt
|
305 |
+
```
|
306 |
+
|
307 |
+
您可以跳过上述步骤,直接拉取预构建的 Docker 镜像,这个镜像是从 [docker/Dockerfile_xDiT](./docker/Dockerfile_xDiT) 构建的
|
308 |
+
|
309 |
+
```
|
310 |
+
docker pull thufeifeibear/hunyuanvideo:latest
|
311 |
+
```
|
312 |
+
|
313 |
+
### 使用命令行
|
314 |
+
|
315 |
+
例如,可用如下命令使用8张GPU卡完成推理
|
316 |
+
|
317 |
+
```bash
|
318 |
+
cd HunyuanVideo
|
319 |
+
|
320 |
+
torchrun --nproc_per_node=8 sample_video_parallel.py \
|
321 |
+
--video-size 1280 720 \
|
322 |
+
--video-length 129 \
|
323 |
+
--infer-steps 50 \
|
324 |
+
--prompt "A cat walks on the grass, realistic style." \
|
325 |
+
--flow-reverse \
|
326 |
+
--seed 42 \
|
327 |
+
--ulysses_degree 8 \
|
328 |
+
--ring_degree 1 \
|
329 |
+
--save-path ./results
|
330 |
+
```
|
331 |
+
|
332 |
+
可以配置`--ulysses-degree`和`--ring-degree`来控制并行配置,可选参数如下。
|
333 |
+
|
334 |
+
<details>
|
335 |
+
<summary>支持的并行配置 (点击查看详情)</summary>
|
336 |
+
|
337 |
+
| --video-size | --video-length | --ulysses-degree x --ring-degree | --nproc_per_node |
|
338 |
+
|----------------------|----------------|----------------------------------|------------------|
|
339 |
+
| 1280 720 或 720 1280 | 129 | 8x1,4x2,2x4,1x8 | 8 |
|
340 |
+
| 1280 720 或 720 1280 | 129 | 1x5 | 5 |
|
341 |
+
| 1280 720 或 720 1280 | 129 | 4x1,2x2,1x4 | 4 |
|
342 |
+
| 1280 720 或 720 1280 | 129 | 3x1,1x3 | 3 |
|
343 |
+
| 1280 720 或 720 1280 | 129 | 2x1,1x2 | 2 |
|
344 |
+
| 1104 832 或 832 1104 | 129 | 4x1,2x2,1x4 | 4 |
|
345 |
+
| 1104 832 或 832 1104 | 129 | 3x1,1x3 | 3 |
|
346 |
+
| 1104 832 或 832 1104 | 129 | 2x1,1x2 | 2 |
|
347 |
+
| 960 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
|
348 |
+
| 960 960 | 129 | 4x1,2x2,1x4 | 4 |
|
349 |
+
| 960 960 | 129 | 3x1,1x3 | 3 |
|
350 |
+
| 960 960 | 129 | 1x2,2x1 | 2 |
|
351 |
+
| 960 544 或 544 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
|
352 |
+
| 960 544 或 544 960 | 129 | 4x1,2x2,1x4 | 4 |
|
353 |
+
| 960 544 或 544 960 | 129 | 3x1,1x3 | 3 |
|
354 |
+
| 960 544 或 544 960 | 129 | 1x2,2x1 | 2 |
|
355 |
+
| 832 624 或 624 832 | 129 | 4x1,2x2,1x4 | 4 |
|
356 |
+
| 624 832 或 624 832 | 129 | 3x1,1x3 | 3 |
|
357 |
+
| 832 624 或 624 832 | 129 | 2x1,1x2 | 2 |
|
358 |
+
| 720 720 | 129 | 1x5 | 5 |
|
359 |
+
| 720 720 | 129 | 3x1,1x3 | 3 |
|
360 |
+
|
361 |
+
</details>
|
362 |
+
|
363 |
+
<p align="center">
|
364 |
+
<table align="center">
|
365 |
+
<thead>
|
366 |
+
<tr>
|
367 |
+
<th colspan="4">在 8xGPU上生成1280x720 (129 帧 50 步)的时耗 (秒) </th>
|
368 |
+
</tr>
|
369 |
+
<tr>
|
370 |
+
<th>1</th>
|
371 |
+
<th>2</th>
|
372 |
+
<th>4</th>
|
373 |
+
<th>8</th>
|
374 |
+
</tr>
|
375 |
+
</thead>
|
376 |
+
<tbody>
|
377 |
+
<tr>
|
378 |
+
<th>1904.08</th>
|
379 |
+
<th>934.09 (2.04x)</th>
|
380 |
+
<th>514.08 (3.70x)</th>
|
381 |
+
<th>337.58 (5.64x)</th>
|
382 |
+
</tr>
|
383 |
+
|
384 |
+
</tbody>
|
385 |
+
</table>
|
386 |
+
</p>
|
387 |
+
|
388 |
+
|
389 |
+
## 🔗 BibTeX
|
390 |
+
如果您认为 [HunyuanVideo](https://arxiv.org/abs/2412.03603) 给您的研究和应用带来了一些帮助,可以通过下面的方式来引用:
|
391 |
+
|
392 |
+
```BibTeX
|
393 |
+
@misc{kong2024hunyuanvideo,
|
394 |
+
title={HunyuanVideo: A Systematic Framework For Large Video Generative Models},
|
395 |
+
author={Weijie Kong, Qi Tian, Zijian Zhang, Rox Min, Zuozhuo Dai, Jin Zhou, Jiangfeng Xiong, Xin Li, Bo Wu, Jianwei Zhang, Kathrina Wu, Qin Lin, Aladdin Wang, Andong Wang, Changlin Li, Duojun Huang, Fang Yang, Hao Tan, Hongmei Wang, Jacob Song, Jiawang Bai, Jianbing Wu, Jinbao Xue, Joey Wang, Junkun Yuan, Kai Wang, Mengyang Liu, Pengyu Li, Shuai Li, Weiyan Wang, Wenqing Yu, Xinchi Deng, Yang Li, Yanxin Long, Yi Chen, Yutao Cui, Yuanbo Peng, Zhentao Yu, Zhiyu He, Zhiyong Xu, Zixiang Zhou, Zunnan Xu, Yangyu Tao, Qinglin Lu, Songtao Liu, Dax Zhou, Hongfa Wang, Yong Yang, Di Wang, Yuhong Liu, and Jie Jiang, along with Caesar Zhong},
|
396 |
+
year={2024},
|
397 |
+
archivePrefix={arXiv preprint arXiv:2412.03603},
|
398 |
+
primaryClass={cs.CV}
|
399 |
+
}
|
400 |
+
```
|
401 |
+
|
402 |
+
|
403 |
+
|
404 |
+
## 🧩 使用 HunyuanVideo 的项目
|
405 |
+
|
406 |
+
如果您的项目中有开发或使用 HunyuanVideo,欢迎告知我们。
|
407 |
+
|
408 |
+
- ComfyUI (支持F8推理和Video2Video生成): [ComfyUI-HunyuanVideoWrapper](https://github.com/kijai/ComfyUI-HunyuanVideoWrapper) by [Kijai](https://github.com/kijai)
|
409 |
+
|
410 |
+
|
411 |
+
|
412 |
+
## 致谢
|
413 |
+
|
414 |
+
HunyuanVideo 的开源离不开诸多开源工作,这里我们特别感谢 [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) 的开源工作和探索。另外,我们也感谢腾讯混元多模态团队对 HunyuanVideo 适配多种文本编码器的支持。
|
415 |
+
|
416 |
+
|
417 |
+
## Star 趋势
|
418 |
+
|
419 |
+
<a href="https://star-history.com/#Tencent/HunyuanVideo&Date">
|
420 |
+
<picture>
|
421 |
+
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date&theme=dark" />
|
422 |
+
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
|
423 |
+
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
|
424 |
+
</picture>
|
425 |
+
</a>
|
assets/3dvae.png
ADDED
assets/WECHAT.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src=wechat.jpg width="60%"/>
|
3 |
+
|
4 |
+
<p> 扫码关注混元系列工作,加入「 Hunyuan Video 交流群」 </p>
|
5 |
+
<p> Scan the QR code to join the "Hunyuan Discussion Group" </p>
|
6 |
+
</div>
|
7 |
+
|
assets/backbone.png
ADDED
Git LFS Details
|
assets/hunyuanvideo.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5aea4522a926775eec39db8bdb4a9ced63a1c0a703fb7254e7b9fdc1dbe7f98
|
3 |
+
size 44603215
|
assets/logo.png
ADDED
assets/overall.png
ADDED
Git LFS Details
|
assets/text_encoder.png
ADDED
Git LFS Details
|
assets/wechat.jpg
ADDED
docker/Dockerfile_xDiT
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
2 |
+
|
3 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
4 |
+
|
5 |
+
# Install necessary packages
|
6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
7 |
+
python3 \
|
8 |
+
python3-pip \
|
9 |
+
python3-dev \
|
10 |
+
git \
|
11 |
+
wget \
|
12 |
+
bzip2 \
|
13 |
+
&& rm -rf /var/lib/apt/lists/*
|
14 |
+
|
15 |
+
# Install Miniconda
|
16 |
+
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \
|
17 |
+
&& bash /tmp/miniconda.sh -b -p /opt/conda \
|
18 |
+
&& rm /tmp/miniconda.sh
|
19 |
+
|
20 |
+
# Set up environment variables for Conda
|
21 |
+
ENV PATH="/opt/conda/bin:$PATH"
|
22 |
+
|
23 |
+
# Create a new conda environment
|
24 |
+
RUN conda create -n myenv python=3.10 -y
|
25 |
+
|
26 |
+
# Activate the conda environment
|
27 |
+
SHELL ["/bin/bash", "--login", "-c"]
|
28 |
+
RUN echo "source activate myenv" > ~/.bashrc
|
29 |
+
ENV PATH /opt/conda/envs/myenv/bin:$PATH
|
30 |
+
|
31 |
+
# Install PyTorch and other dependencies using conda
|
32 |
+
RUN conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia
|
33 |
+
|
34 |
+
COPY ./requirements_xdit.txt /tmp/requirements_xdit.txt
|
35 |
+
RUN pip install --no-cache-dir -r /tmp/requirements_xdit.txt
|
36 |
+
|
37 |
+
# Set working directory
|
38 |
+
WORKDIR /workspace
|
39 |
+
|
40 |
+
# Default command
|
41 |
+
CMD ["/bin/bash"]
|
environment.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: HunyuanVideo
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
dependencies:
|
6 |
+
- python=3.10.9
|
7 |
+
- pytorch=2.5.1
|
8 |
+
- pip
|
gradio_server.py
ADDED
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
from loguru import logger
|
5 |
+
from datetime import datetime
|
6 |
+
import gradio as gr
|
7 |
+
import random
|
8 |
+
import json
|
9 |
+
from hyvideo.utils.file_utils import save_videos_grid
|
10 |
+
from hyvideo.config import parse_args
|
11 |
+
from hyvideo.inference import HunyuanVideoSampler
|
12 |
+
from hyvideo.constants import NEGATIVE_PROMPT
|
13 |
+
from mmgp import offload, profile_type
|
14 |
+
|
15 |
+
|
16 |
+
args = parse_args()
|
17 |
+
|
18 |
+
|
19 |
+
force_profile_no = int(args.profile)
|
20 |
+
verbose_level = int(args.verbose)
|
21 |
+
|
22 |
+
transformer_choices=["ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan-video-t2v-720p/transformers/fast_hunyuan_video_720_quanto_int8.safetensors"]
|
23 |
+
text_encoder_choices = ["ckpts/text_encoder/llava-llama-3-8b-v1_1_fp16.safetensors", "ckpts/text_encoder/llava-llama-3-8b-v1_1_quanto_int8.safetensors"]
|
24 |
+
server_config_filename = "gradio_config.json"
|
25 |
+
|
26 |
+
if not Path(server_config_filename).is_file():
|
27 |
+
server_config = {"attention_mode" : "sdpa",
|
28 |
+
"transformer_filename": transformer_choices[1],
|
29 |
+
"text_encoder_filename" : text_encoder_choices[1],
|
30 |
+
"compile" : "",
|
31 |
+
"profile" : profile_type.LowRAM_LowVRAM }
|
32 |
+
|
33 |
+
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
34 |
+
writer.write(json.dumps(server_config))
|
35 |
+
else:
|
36 |
+
with open(server_config_filename, "r", encoding="utf-8") as reader:
|
37 |
+
text = reader.read()
|
38 |
+
server_config = json.loads(text)
|
39 |
+
|
40 |
+
transformer_filename = server_config["transformer_filename"]
|
41 |
+
text_encoder_filename = server_config["text_encoder_filename"]
|
42 |
+
attention_mode = server_config["attention_mode"]
|
43 |
+
profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
|
44 |
+
compile = server_config.get("compile", "")
|
45 |
+
|
46 |
+
#transformer_filename = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors"
|
47 |
+
#transformer_filename = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_quanto_int8.safetensors"
|
48 |
+
#transformer_filename = "ckpts/hunyuan-video-t2v-720p/transformers/fast_hunyuan_video_720_quanto_int8.safetensors"
|
49 |
+
|
50 |
+
#text_encoder_filename = "ckpts/text_encoder/llava-llama-3-8b-v1_1_fp16.safetensors"
|
51 |
+
#text_encoder_filename = "ckpts/text_encoder/llava-llama-3-8b-v1_1_quanto_int8.safetensors"
|
52 |
+
|
53 |
+
#attention_mode="sage"
|
54 |
+
#attention_mode="flash"
|
55 |
+
|
56 |
+
def download_models(transformer_filename, text_encoder_filename):
|
57 |
+
def computeList(filename):
|
58 |
+
pos = filename.rfind("/")
|
59 |
+
filename = filename[pos+1:]
|
60 |
+
if not "quanto" in filename:
|
61 |
+
return [filename]
|
62 |
+
pos = filename.rfind(".")
|
63 |
+
return [filename, filename[:pos] +"_map.json"]
|
64 |
+
|
65 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
66 |
+
repoId = "DeepBeepMeep/HunyuanVideo"
|
67 |
+
sourceFolderList = ["text_encoder_2", "text_encoder", "hunyuan-video-t2v-720p/vae", "hunyuan-video-t2v-720p/transformers" ]
|
68 |
+
fileList = [ [], ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , [], computeList(transformer_filename) ]
|
69 |
+
targetRoot = "ckpts/"
|
70 |
+
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
71 |
+
if len(files)==0:
|
72 |
+
if not Path(targetRoot + sourceFolder).exists():
|
73 |
+
snapshot_download(repo_id=repoId, allow_patterns=sourceFolder +"/*", local_dir= targetRoot)
|
74 |
+
else:
|
75 |
+
for onefile in files:
|
76 |
+
if not os.path.isfile(targetRoot + sourceFolder + "/" + onefile ):
|
77 |
+
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot, subfolder=sourceFolder)
|
78 |
+
|
79 |
+
|
80 |
+
download_models(transformer_filename, text_encoder_filename)
|
81 |
+
|
82 |
+
# models_root_path = Path(args.model_base)
|
83 |
+
# if not models_root_path.exists():
|
84 |
+
# raise ValueError(f"`models_root` not exists: {models_root_path}")
|
85 |
+
|
86 |
+
offload.default_verboseLevel = verbose_level
|
87 |
+
with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "r", encoding="utf-8") as reader:
|
88 |
+
text = reader.read()
|
89 |
+
vae_config= json.loads(text)
|
90 |
+
# reduce time window used by the VAE for temporal splitting (former time windows is too large for 24 GB)
|
91 |
+
if vae_config["sample_tsize"] == 64:
|
92 |
+
vae_config["sample_tsize"] = 32
|
93 |
+
with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "w", encoding="utf-8") as writer:
|
94 |
+
writer.write(json.dumps(vae_config))
|
95 |
+
|
96 |
+
args.flow_reverse = True
|
97 |
+
if profile == 5:
|
98 |
+
pinToMemory = False
|
99 |
+
partialPinning = False
|
100 |
+
else:
|
101 |
+
pinToMemory = True
|
102 |
+
import psutil
|
103 |
+
physical_memory= psutil.virtual_memory().total
|
104 |
+
partialPinning = physical_memory <= 2**30 * 32
|
105 |
+
|
106 |
+
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(transformer_filename, text_encoder_filename, attention_mode = attention_mode, pinToMemory = pinToMemory, partialPinning = partialPinning, args=args, device="cpu")
|
107 |
+
|
108 |
+
pipe = hunyuan_video_sampler.pipeline
|
109 |
+
|
110 |
+
offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False)
|
111 |
+
|
112 |
+
def apply_changes(
|
113 |
+
transformer_choice,
|
114 |
+
text_encoder_choice,
|
115 |
+
attention_choice,
|
116 |
+
compile_choice,
|
117 |
+
profile_choice,
|
118 |
+
):
|
119 |
+
server_config = {"attention_mode" : attention_choice,
|
120 |
+
"transformer_filename": transformer_choices[transformer_choice],
|
121 |
+
"text_encoder_filename" : text_encoder_choices[text_encoder_choice],
|
122 |
+
"compile" : compile_choice,
|
123 |
+
"profile" : profile_choice }
|
124 |
+
|
125 |
+
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
126 |
+
writer.write(json.dumps(server_config))
|
127 |
+
|
128 |
+
return "<h1>New Config file created. Please restart the Gradio Server</h1>"
|
129 |
+
|
130 |
+
|
131 |
+
from moviepy.editor import ImageSequenceClip
|
132 |
+
import numpy as np
|
133 |
+
|
134 |
+
def save_video(final_frames, output_path, fps=24):
|
135 |
+
assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)"
|
136 |
+
if final_frames.dtype != np.uint8:
|
137 |
+
final_frames = (final_frames * 255).astype(np.uint8)
|
138 |
+
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
139 |
+
|
140 |
+
|
141 |
+
def generate_video(
|
142 |
+
prompt,
|
143 |
+
resolution,
|
144 |
+
video_length,
|
145 |
+
seed,
|
146 |
+
num_inference_steps,
|
147 |
+
guidance_scale,
|
148 |
+
flow_shift,
|
149 |
+
embedded_guidance_scale,
|
150 |
+
tea_cache,
|
151 |
+
progress=gr.Progress(track_tqdm=True)
|
152 |
+
|
153 |
+
):
|
154 |
+
seed = None if seed == -1 else seed
|
155 |
+
width, height = resolution.split("x")
|
156 |
+
width, height = int(width), int(height)
|
157 |
+
negative_prompt = "" # not applicable in the inference
|
158 |
+
|
159 |
+
|
160 |
+
# TeaCache
|
161 |
+
trans = hunyuan_video_sampler.pipeline.transformer.__class__
|
162 |
+
trans.enable_teacache = tea_cache > 0
|
163 |
+
if trans.enable_teacache:
|
164 |
+
trans.num_steps = num_inference_steps
|
165 |
+
trans.cnt = 0
|
166 |
+
trans.rel_l1_thresh = 0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
|
167 |
+
trans.accumulated_rel_l1_distance = 0
|
168 |
+
trans.previous_modulated_input = None
|
169 |
+
trans.previous_residual = None
|
170 |
+
|
171 |
+
|
172 |
+
outputs = hunyuan_video_sampler.predict(
|
173 |
+
prompt=prompt,
|
174 |
+
height=height,
|
175 |
+
width=width,
|
176 |
+
video_length=(video_length // 4)* 4 + 1 ,
|
177 |
+
seed=seed,
|
178 |
+
negative_prompt=negative_prompt,
|
179 |
+
infer_steps=num_inference_steps,
|
180 |
+
guidance_scale=guidance_scale,
|
181 |
+
num_videos_per_prompt=1,
|
182 |
+
flow_shift=flow_shift,
|
183 |
+
batch_size=1,
|
184 |
+
embedded_guidance_scale=embedded_guidance_scale
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
from einops import rearrange
|
189 |
+
|
190 |
+
samples = outputs['samples']
|
191 |
+
sample = samples[0]
|
192 |
+
|
193 |
+
video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
194 |
+
|
195 |
+
save_path = os.path.join(os.getcwd(), "gradio_outputs")
|
196 |
+
os.makedirs(save_path, exist_ok=True)
|
197 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
198 |
+
file_name = f"{time_flag}_seed{outputs['seeds'][0]}_{outputs['prompts'][0][:100].replace('/','')}.mp4".replace(':',' ').replace('\\',' ')
|
199 |
+
video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
|
200 |
+
|
201 |
+
save_video(video, video_path )
|
202 |
+
print(f"New video saved to Path: "+video_path)
|
203 |
+
|
204 |
+
|
205 |
+
return video_path
|
206 |
+
|
207 |
+
def create_demo(model_path, save_path):
|
208 |
+
|
209 |
+
with gr.Blocks() as demo:
|
210 |
+
gr.Markdown("<div align=center><H1>HunyuanVideo<SUP>GP</SUP> by Tencent</H3></div>")
|
211 |
+
gr.Markdown("*GPU Poor version by **DeepBeepMeep**. Now this great video generator can run smoothly on a 24 GB rig.*")
|
212 |
+
gr.Markdown("Please be aware of these limits with profiles 2 and 4 if you have 24 GB of VRAM (RTX 3090 / RTX 4090):")
|
213 |
+
gr.Markdown("- max 192 frames for 848 x 480 ")
|
214 |
+
gr.Markdown("- max 86 frames for 1280 x 720")
|
215 |
+
gr.Markdown("In the worst case, one step should not take more than 2 minutes. If it the case you may be running out of RAM / VRAM. Try to generate fewer images / lower res / a less demanding profile.")
|
216 |
+
gr.Markdown("If you have a Linux / WSL system you may turn on compilation (see below) and will be able to generate an extra 30°% frames")
|
217 |
+
|
218 |
+
with gr.Accordion("Video Engine Configuration", open = False):
|
219 |
+
gr.Markdown("For the changes to be effective you will need to restart the gradio_server")
|
220 |
+
|
221 |
+
with gr.Column():
|
222 |
+
index = transformer_choices.index(transformer_filename)
|
223 |
+
index = 0 if index ==0 else index
|
224 |
+
|
225 |
+
transformer_choice = gr.Dropdown(
|
226 |
+
choices=[
|
227 |
+
("Hunyuan Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0),
|
228 |
+
("Hunyuan Video quantized to 8 bits (recommended) - the default engine but quantized", 1),
|
229 |
+
("Fast Hunyuan Video quantized to 8 bits - requires less than 10 steps but worse quality", 2),
|
230 |
+
],
|
231 |
+
value= index,
|
232 |
+
label="Transformer"
|
233 |
+
)
|
234 |
+
index = text_encoder_choices.index(text_encoder_filename)
|
235 |
+
index = 0 if index ==0 else index
|
236 |
+
|
237 |
+
gr.Markdown("Note that even if you choose a 16 bits Llava model below, depending on the profile it may be automatically quantized to 8 bits on the fly")
|
238 |
+
text_encoder_choice = gr.Dropdown(
|
239 |
+
choices=[
|
240 |
+
("Llava Llama 1.1 16 bits - unquantized text encoder, better quality uses more RAM", 0),
|
241 |
+
("Llava Llama 1.1 quantized to 8 bits - quantized text encoder, worse quality but uses less RAM", 1),
|
242 |
+
],
|
243 |
+
value= index,
|
244 |
+
label="Text Encoder"
|
245 |
+
)
|
246 |
+
attention_choice = gr.Dropdown(
|
247 |
+
choices=[
|
248 |
+
("Scale Dot Product Attention: default", "sdpa"),
|
249 |
+
("Flash: good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
|
250 |
+
("Sage: 30% faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
|
251 |
+
],
|
252 |
+
value= attention_mode,
|
253 |
+
label="Attention Type"
|
254 |
+
)
|
255 |
+
gr.Markdown("Beware: restarting the server or changing a resolution or video duration will trigger a recompilation that may last a few minutes")
|
256 |
+
compile_choice = gr.Dropdown(
|
257 |
+
choices=[
|
258 |
+
("ON: works only on Linux / WSL", "transformer"),
|
259 |
+
("OFF: no other choice if you have Windows without using WSL", "" ),
|
260 |
+
],
|
261 |
+
value= compile,
|
262 |
+
label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)"
|
263 |
+
)
|
264 |
+
profile_choice = gr.Dropdown(
|
265 |
+
choices=[
|
266 |
+
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for shorter videos a RTX 3090 / RTX 4090", 1),
|
267 |
+
("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
|
268 |
+
("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
|
269 |
+
("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
|
270 |
+
("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
|
271 |
+
],
|
272 |
+
value= profile,
|
273 |
+
label="Profile"
|
274 |
+
)
|
275 |
+
|
276 |
+
msg = gr.Markdown()
|
277 |
+
apply_btn = gr.Button("Apply Changes")
|
278 |
+
|
279 |
+
|
280 |
+
apply_btn.click(
|
281 |
+
fn=apply_changes,
|
282 |
+
inputs=[
|
283 |
+
transformer_choice,
|
284 |
+
text_encoder_choice,
|
285 |
+
attention_choice,
|
286 |
+
compile_choice,
|
287 |
+
profile_choice,
|
288 |
+
],
|
289 |
+
outputs=msg
|
290 |
+
)
|
291 |
+
|
292 |
+
with gr.Row():
|
293 |
+
with gr.Column():
|
294 |
+
prompt = gr.Textbox(label="Prompt", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.")
|
295 |
+
with gr.Row():
|
296 |
+
resolution = gr.Dropdown(
|
297 |
+
choices=[
|
298 |
+
# 720p
|
299 |
+
("1280x720 (16:9, 720p)", "1280x720"),
|
300 |
+
("720x1280 (9:16, 720p)", "720x1280"),
|
301 |
+
("1104x832 (4:3, 720p)", "1104x832"),
|
302 |
+
("832x1104 (3:4, 720p)", "832x1104"),
|
303 |
+
("960x960 (1:1, 720p)", "960x960"),
|
304 |
+
# 540p
|
305 |
+
("960x544 (16:9, 540p)", "960x544"),
|
306 |
+
("848x480 (16:9, 540p)", "848x480"),
|
307 |
+
("544x960 (9:16, 540p)", "544x960"),
|
308 |
+
("832x624 (4:3, 540p)", "832x624"),
|
309 |
+
("624x832 (3:4, 540p)", "624x832"),
|
310 |
+
("720x720 (1:1, 540p)", "720x720"),
|
311 |
+
],
|
312 |
+
value="848x480",
|
313 |
+
label="Resolution"
|
314 |
+
)
|
315 |
+
|
316 |
+
video_length = gr.Slider(5, 193, value=97, step=4, label="Number of frames (24 = 1s)")
|
317 |
+
|
318 |
+
# video_length = gr.Dropdown(
|
319 |
+
# label="Video Length",
|
320 |
+
# choices=[
|
321 |
+
# ("1.5s(41f)", 41),
|
322 |
+
# ("2s(65f)", 65),
|
323 |
+
# ("4s(97f)", 97),
|
324 |
+
# ("5s(129f)", 129),
|
325 |
+
# ],
|
326 |
+
# value=97,
|
327 |
+
# )
|
328 |
+
num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps")
|
329 |
+
show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
|
330 |
+
with gr.Row(visible=False) as advanced_row:
|
331 |
+
with gr.Column():
|
332 |
+
seed = gr.Number(value=-1, label="Seed (-1 for random)")
|
333 |
+
guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
|
334 |
+
flow_shift = gr.Slider(0.0, 25.0, value=7.0, step=0.1, label="Flow Shift")
|
335 |
+
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")
|
336 |
+
with gr.Row():
|
337 |
+
tea_cache_setting = gr.Dropdown(
|
338 |
+
choices=[
|
339 |
+
("Disabled", 0),
|
340 |
+
("Fast (x1.6 speed up)", 0.1),
|
341 |
+
("Faster (x2.1 speed up)", 0.15),
|
342 |
+
],
|
343 |
+
value=0,
|
344 |
+
label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video)"
|
345 |
+
)
|
346 |
+
|
347 |
+
show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
|
348 |
+
generate_btn = gr.Button("Generate")
|
349 |
+
|
350 |
+
with gr.Column():
|
351 |
+
output = gr.Video(label="Generated Video")
|
352 |
+
|
353 |
+
generate_btn.click(
|
354 |
+
fn=generate_video,
|
355 |
+
inputs=[
|
356 |
+
prompt,
|
357 |
+
resolution,
|
358 |
+
video_length,
|
359 |
+
seed,
|
360 |
+
num_inference_steps,
|
361 |
+
guidance_scale,
|
362 |
+
flow_shift,
|
363 |
+
embedded_guidance_scale,
|
364 |
+
tea_cache_setting
|
365 |
+
],
|
366 |
+
outputs=output
|
367 |
+
)
|
368 |
+
|
369 |
+
return demo
|
370 |
+
|
371 |
+
if __name__ == "__main__":
|
372 |
+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
373 |
+
server_name = os.getenv("SERVER_NAME", "0.0.0.0")
|
374 |
+
server_port = int(os.getenv("SERVER_PORT", "7860"))
|
375 |
+
demo = create_demo(args.model_base, args.save_path)
|
376 |
+
demo.launch(server_name=server_name, server_port=server_port)
|
hyvideo/__init__.py
ADDED
File without changes
|
hyvideo/config.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from .constants import *
|
3 |
+
import re
|
4 |
+
from .modules.models import HUNYUAN_VIDEO_CONFIG
|
5 |
+
|
6 |
+
|
7 |
+
def parse_args(namespace=None):
|
8 |
+
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
|
9 |
+
|
10 |
+
parser = add_network_args(parser)
|
11 |
+
parser = add_extra_models_args(parser)
|
12 |
+
parser = add_denoise_schedule_args(parser)
|
13 |
+
parser = add_inference_args(parser)
|
14 |
+
parser = add_parallel_args(parser)
|
15 |
+
|
16 |
+
args = parser.parse_args(namespace=namespace)
|
17 |
+
args = sanity_check_args(args)
|
18 |
+
|
19 |
+
return args
|
20 |
+
|
21 |
+
|
22 |
+
def add_network_args(parser: argparse.ArgumentParser):
|
23 |
+
group = parser.add_argument_group(title="HunyuanVideo network args")
|
24 |
+
|
25 |
+
group.add_argument(
|
26 |
+
"--profile",
|
27 |
+
type=str,
|
28 |
+
default=-1,
|
29 |
+
help="Profile No"
|
30 |
+
)
|
31 |
+
|
32 |
+
group.add_argument(
|
33 |
+
"--verbose",
|
34 |
+
type=str,
|
35 |
+
default=1,
|
36 |
+
help="Verbose level"
|
37 |
+
)
|
38 |
+
|
39 |
+
# Main model
|
40 |
+
group.add_argument(
|
41 |
+
"--model",
|
42 |
+
type=str,
|
43 |
+
choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
|
44 |
+
default="HYVideo-T/2-cfgdistill",
|
45 |
+
)
|
46 |
+
group.add_argument(
|
47 |
+
"--latent-channels",
|
48 |
+
type=str,
|
49 |
+
default=16,
|
50 |
+
help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
|
51 |
+
"it still needs to match the latent channels of the VAE model.",
|
52 |
+
)
|
53 |
+
group.add_argument(
|
54 |
+
"--precision",
|
55 |
+
type=str,
|
56 |
+
default="bf16",
|
57 |
+
choices=PRECISIONS,
|
58 |
+
help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
|
59 |
+
)
|
60 |
+
|
61 |
+
# RoPE
|
62 |
+
group.add_argument(
|
63 |
+
"--rope-theta", type=int, default=256, help="Theta used in RoPE."
|
64 |
+
)
|
65 |
+
return parser
|
66 |
+
|
67 |
+
|
68 |
+
def add_extra_models_args(parser: argparse.ArgumentParser):
|
69 |
+
group = parser.add_argument_group(
|
70 |
+
title="Extra models args, including vae, text encoders and tokenizers)"
|
71 |
+
)
|
72 |
+
|
73 |
+
# - VAE
|
74 |
+
group.add_argument(
|
75 |
+
"--vae",
|
76 |
+
type=str,
|
77 |
+
default="884-16c-hy",
|
78 |
+
choices=list(VAE_PATH),
|
79 |
+
help="Name of the VAE model.",
|
80 |
+
)
|
81 |
+
group.add_argument(
|
82 |
+
"--vae-precision",
|
83 |
+
type=str,
|
84 |
+
default="fp16",
|
85 |
+
choices=PRECISIONS,
|
86 |
+
help="Precision mode for the VAE model.",
|
87 |
+
)
|
88 |
+
group.add_argument(
|
89 |
+
"--vae-tiling",
|
90 |
+
action="store_true",
|
91 |
+
help="Enable tiling for the VAE model to save GPU memory.",
|
92 |
+
)
|
93 |
+
group.set_defaults(vae_tiling=True)
|
94 |
+
|
95 |
+
group.add_argument(
|
96 |
+
"--text-encoder",
|
97 |
+
type=str,
|
98 |
+
default="llm",
|
99 |
+
choices=list(TEXT_ENCODER_PATH),
|
100 |
+
help="Name of the text encoder model.",
|
101 |
+
)
|
102 |
+
group.add_argument(
|
103 |
+
"--text-encoder-precision",
|
104 |
+
type=str,
|
105 |
+
default="fp16",
|
106 |
+
choices=PRECISIONS,
|
107 |
+
help="Precision mode for the text encoder model.",
|
108 |
+
)
|
109 |
+
group.add_argument(
|
110 |
+
"--text-states-dim",
|
111 |
+
type=int,
|
112 |
+
default=4096,
|
113 |
+
help="Dimension of the text encoder hidden states.",
|
114 |
+
)
|
115 |
+
group.add_argument(
|
116 |
+
"--text-len", type=int, default=256, help="Maximum length of the text input."
|
117 |
+
)
|
118 |
+
group.add_argument(
|
119 |
+
"--tokenizer",
|
120 |
+
type=str,
|
121 |
+
default="llm",
|
122 |
+
choices=list(TOKENIZER_PATH),
|
123 |
+
help="Name of the tokenizer model.",
|
124 |
+
)
|
125 |
+
group.add_argument(
|
126 |
+
"--prompt-template",
|
127 |
+
type=str,
|
128 |
+
default="dit-llm-encode",
|
129 |
+
choices=PROMPT_TEMPLATE,
|
130 |
+
help="Image prompt template for the decoder-only text encoder model.",
|
131 |
+
)
|
132 |
+
group.add_argument(
|
133 |
+
"--prompt-template-video",
|
134 |
+
type=str,
|
135 |
+
default="dit-llm-encode-video",
|
136 |
+
choices=PROMPT_TEMPLATE,
|
137 |
+
help="Video prompt template for the decoder-only text encoder model.",
|
138 |
+
)
|
139 |
+
group.add_argument(
|
140 |
+
"--hidden-state-skip-layer",
|
141 |
+
type=int,
|
142 |
+
default=2,
|
143 |
+
help="Skip layer for hidden states.",
|
144 |
+
)
|
145 |
+
group.add_argument(
|
146 |
+
"--apply-final-norm",
|
147 |
+
action="store_true",
|
148 |
+
help="Apply final normalization to the used text encoder hidden states.",
|
149 |
+
)
|
150 |
+
|
151 |
+
# - CLIP
|
152 |
+
group.add_argument(
|
153 |
+
"--text-encoder-2",
|
154 |
+
type=str,
|
155 |
+
default="clipL",
|
156 |
+
choices=list(TEXT_ENCODER_PATH),
|
157 |
+
help="Name of the second text encoder model.",
|
158 |
+
)
|
159 |
+
group.add_argument(
|
160 |
+
"--text-encoder-precision-2",
|
161 |
+
type=str,
|
162 |
+
default="fp16",
|
163 |
+
choices=PRECISIONS,
|
164 |
+
help="Precision mode for the second text encoder model.",
|
165 |
+
)
|
166 |
+
group.add_argument(
|
167 |
+
"--text-states-dim-2",
|
168 |
+
type=int,
|
169 |
+
default=768,
|
170 |
+
help="Dimension of the second text encoder hidden states.",
|
171 |
+
)
|
172 |
+
group.add_argument(
|
173 |
+
"--tokenizer-2",
|
174 |
+
type=str,
|
175 |
+
default="clipL",
|
176 |
+
choices=list(TOKENIZER_PATH),
|
177 |
+
help="Name of the second tokenizer model.",
|
178 |
+
)
|
179 |
+
group.add_argument(
|
180 |
+
"--text-len-2",
|
181 |
+
type=int,
|
182 |
+
default=77,
|
183 |
+
help="Maximum length of the second text input.",
|
184 |
+
)
|
185 |
+
|
186 |
+
return parser
|
187 |
+
|
188 |
+
|
189 |
+
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
|
190 |
+
group = parser.add_argument_group(title="Denoise schedule args")
|
191 |
+
|
192 |
+
group.add_argument(
|
193 |
+
"--denoise-type",
|
194 |
+
type=str,
|
195 |
+
default="flow",
|
196 |
+
help="Denoise type for noised inputs.",
|
197 |
+
)
|
198 |
+
|
199 |
+
# Flow Matching
|
200 |
+
group.add_argument(
|
201 |
+
"--flow-shift",
|
202 |
+
type=float,
|
203 |
+
default=7.0,
|
204 |
+
help="Shift factor for flow matching schedulers.",
|
205 |
+
)
|
206 |
+
group.add_argument(
|
207 |
+
"--flow-reverse",
|
208 |
+
action="store_true",
|
209 |
+
help="If reverse, learning/sampling from t=1 -> t=0.",
|
210 |
+
)
|
211 |
+
group.add_argument(
|
212 |
+
"--flow-solver",
|
213 |
+
type=str,
|
214 |
+
default="euler",
|
215 |
+
help="Solver for flow matching.",
|
216 |
+
)
|
217 |
+
group.add_argument(
|
218 |
+
"--use-linear-quadratic-schedule",
|
219 |
+
action="store_true",
|
220 |
+
help="Use linear quadratic schedule for flow matching."
|
221 |
+
"Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
|
222 |
+
)
|
223 |
+
group.add_argument(
|
224 |
+
"--linear-schedule-end",
|
225 |
+
type=int,
|
226 |
+
default=25,
|
227 |
+
help="End step for linear quadratic schedule for flow matching.",
|
228 |
+
)
|
229 |
+
|
230 |
+
return parser
|
231 |
+
|
232 |
+
|
233 |
+
def add_inference_args(parser: argparse.ArgumentParser):
|
234 |
+
group = parser.add_argument_group(title="Inference args")
|
235 |
+
|
236 |
+
# ======================== Model loads ========================
|
237 |
+
group.add_argument(
|
238 |
+
"--model-base",
|
239 |
+
type=str,
|
240 |
+
default="ckpts",
|
241 |
+
help="Root path of all the models, including t2v models and extra models.",
|
242 |
+
)
|
243 |
+
group.add_argument(
|
244 |
+
"--dit-weight",
|
245 |
+
type=str,
|
246 |
+
default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
|
247 |
+
help="Path to the HunyuanVideo model. If None, search the model in the args.model_root."
|
248 |
+
"1. If it is a file, load the model directly."
|
249 |
+
"2. If it is a directory, search the model in the directory. Support two types of models: "
|
250 |
+
"1) named `pytorch_model_*.pt`"
|
251 |
+
"2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
|
252 |
+
)
|
253 |
+
group.add_argument(
|
254 |
+
"--model-resolution",
|
255 |
+
type=str,
|
256 |
+
default="540p",
|
257 |
+
choices=["540p", "720p"],
|
258 |
+
help="Root path of all the models, including t2v models and extra models.",
|
259 |
+
)
|
260 |
+
group.add_argument(
|
261 |
+
"--load-key",
|
262 |
+
type=str,
|
263 |
+
default="module",
|
264 |
+
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
|
265 |
+
)
|
266 |
+
group.add_argument(
|
267 |
+
"--use-cpu-offload",
|
268 |
+
action="store_true",
|
269 |
+
help="Use CPU offload for the model load.",
|
270 |
+
)
|
271 |
+
|
272 |
+
# ======================== Inference general setting ========================
|
273 |
+
group.add_argument(
|
274 |
+
"--batch-size",
|
275 |
+
type=int,
|
276 |
+
default=1,
|
277 |
+
help="Batch size for inference and evaluation.",
|
278 |
+
)
|
279 |
+
group.add_argument(
|
280 |
+
"--infer-steps",
|
281 |
+
type=int,
|
282 |
+
default=50,
|
283 |
+
help="Number of denoising steps for inference.",
|
284 |
+
)
|
285 |
+
group.add_argument(
|
286 |
+
"--disable-autocast",
|
287 |
+
action="store_true",
|
288 |
+
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
|
289 |
+
)
|
290 |
+
group.add_argument(
|
291 |
+
"--save-path",
|
292 |
+
type=str,
|
293 |
+
default="./results",
|
294 |
+
help="Path to save the generated samples.",
|
295 |
+
)
|
296 |
+
group.add_argument(
|
297 |
+
"--save-path-suffix",
|
298 |
+
type=str,
|
299 |
+
default="",
|
300 |
+
help="Suffix for the directory of saved samples.",
|
301 |
+
)
|
302 |
+
group.add_argument(
|
303 |
+
"--name-suffix",
|
304 |
+
type=str,
|
305 |
+
default="",
|
306 |
+
help="Suffix for the names of saved samples.",
|
307 |
+
)
|
308 |
+
group.add_argument(
|
309 |
+
"--num-videos",
|
310 |
+
type=int,
|
311 |
+
default=1,
|
312 |
+
help="Number of videos to generate for each prompt.",
|
313 |
+
)
|
314 |
+
# ---sample size---
|
315 |
+
group.add_argument(
|
316 |
+
"--video-size",
|
317 |
+
type=int,
|
318 |
+
nargs="+",
|
319 |
+
default=(720, 1280),
|
320 |
+
help="Video size for training. If a single value is provided, it will be used for both height "
|
321 |
+
"and width. If two values are provided, they will be used for height and width "
|
322 |
+
"respectively.",
|
323 |
+
)
|
324 |
+
group.add_argument(
|
325 |
+
"--video-length",
|
326 |
+
type=int,
|
327 |
+
default=129,
|
328 |
+
help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
|
329 |
+
)
|
330 |
+
# --- prompt ---
|
331 |
+
group.add_argument(
|
332 |
+
"--prompt",
|
333 |
+
type=str,
|
334 |
+
default=None,
|
335 |
+
help="Prompt for sampling during evaluation.",
|
336 |
+
)
|
337 |
+
group.add_argument(
|
338 |
+
"--seed-type",
|
339 |
+
type=str,
|
340 |
+
default="auto",
|
341 |
+
choices=["file", "random", "fixed", "auto"],
|
342 |
+
help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
|
343 |
+
"random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
|
344 |
+
"seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
|
345 |
+
"fixed `seed` value.",
|
346 |
+
)
|
347 |
+
group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
|
348 |
+
|
349 |
+
# Classifier-Free Guidance
|
350 |
+
group.add_argument(
|
351 |
+
"--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
|
352 |
+
)
|
353 |
+
group.add_argument(
|
354 |
+
"--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
|
355 |
+
)
|
356 |
+
group.add_argument(
|
357 |
+
"--embedded-cfg-scale",
|
358 |
+
type=float,
|
359 |
+
default=6.0,
|
360 |
+
help="Embeded classifier free guidance scale.",
|
361 |
+
)
|
362 |
+
|
363 |
+
group.add_argument(
|
364 |
+
"--reproduce",
|
365 |
+
action="store_true",
|
366 |
+
help="Enable reproducibility by setting random seeds and deterministic algorithms.",
|
367 |
+
)
|
368 |
+
|
369 |
+
return parser
|
370 |
+
|
371 |
+
|
372 |
+
def add_parallel_args(parser: argparse.ArgumentParser):
|
373 |
+
group = parser.add_argument_group(title="Parallel args")
|
374 |
+
|
375 |
+
# ======================== Model loads ========================
|
376 |
+
group.add_argument(
|
377 |
+
"--ulysses-degree",
|
378 |
+
type=int,
|
379 |
+
default=1,
|
380 |
+
help="Ulysses degree.",
|
381 |
+
)
|
382 |
+
group.add_argument(
|
383 |
+
"--ring-degree",
|
384 |
+
type=int,
|
385 |
+
default=1,
|
386 |
+
help="Ulysses degree.",
|
387 |
+
)
|
388 |
+
|
389 |
+
return parser
|
390 |
+
|
391 |
+
|
392 |
+
def sanity_check_args(args):
|
393 |
+
# VAE channels
|
394 |
+
vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
|
395 |
+
if not re.match(vae_pattern, args.vae):
|
396 |
+
raise ValueError(
|
397 |
+
f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
|
398 |
+
)
|
399 |
+
vae_channels = int(args.vae.split("-")[1][:-1])
|
400 |
+
if args.latent_channels is None:
|
401 |
+
args.latent_channels = vae_channels
|
402 |
+
if vae_channels != args.latent_channels:
|
403 |
+
raise ValueError(
|
404 |
+
f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
|
405 |
+
)
|
406 |
+
return args
|
hyvideo/constants.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"C_SCALE",
|
6 |
+
"PROMPT_TEMPLATE",
|
7 |
+
"MODEL_BASE",
|
8 |
+
"PRECISIONS",
|
9 |
+
"NORMALIZATION_TYPE",
|
10 |
+
"ACTIVATION_TYPE",
|
11 |
+
"VAE_PATH",
|
12 |
+
"TEXT_ENCODER_PATH",
|
13 |
+
"TOKENIZER_PATH",
|
14 |
+
"TEXT_PROJECTION",
|
15 |
+
"DATA_TYPE",
|
16 |
+
"NEGATIVE_PROMPT",
|
17 |
+
]
|
18 |
+
|
19 |
+
PRECISION_TO_TYPE = {
|
20 |
+
'fp32': torch.float32,
|
21 |
+
'fp16': torch.float16,
|
22 |
+
'bf16': torch.bfloat16,
|
23 |
+
}
|
24 |
+
|
25 |
+
# =================== Constant Values =====================
|
26 |
+
# Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
|
27 |
+
# overflow error when tensorboard logging values.
|
28 |
+
C_SCALE = 1_000_000_000_000_000
|
29 |
+
|
30 |
+
# When using decoder-only models, we must provide a prompt template to instruct the text encoder
|
31 |
+
# on how to generate the text.
|
32 |
+
# --------------------------------------------------------------------
|
33 |
+
PROMPT_TEMPLATE_ENCODE = (
|
34 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
35 |
+
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
36 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
37 |
+
)
|
38 |
+
PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
39 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
40 |
+
"1. The main content and theme of the video."
|
41 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
42 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
43 |
+
"4. background environment, light, style and atmosphere."
|
44 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
45 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
46 |
+
)
|
47 |
+
|
48 |
+
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
49 |
+
|
50 |
+
PROMPT_TEMPLATE = {
|
51 |
+
"dit-llm-encode": {
|
52 |
+
"template": PROMPT_TEMPLATE_ENCODE,
|
53 |
+
"crop_start": 36,
|
54 |
+
},
|
55 |
+
"dit-llm-encode-video": {
|
56 |
+
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
57 |
+
"crop_start": 95,
|
58 |
+
},
|
59 |
+
}
|
60 |
+
|
61 |
+
# ======================= Model ======================
|
62 |
+
PRECISIONS = {"fp32", "fp16", "bf16"}
|
63 |
+
NORMALIZATION_TYPE = {"layer", "rms"}
|
64 |
+
ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
|
65 |
+
|
66 |
+
# =================== Model Path =====================
|
67 |
+
MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
|
68 |
+
|
69 |
+
# =================== Data =======================
|
70 |
+
DATA_TYPE = {"image", "video", "image_video"}
|
71 |
+
|
72 |
+
# 3D VAE
|
73 |
+
VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
|
74 |
+
|
75 |
+
# Text Encoder
|
76 |
+
TEXT_ENCODER_PATH = {
|
77 |
+
"clipL": f"{MODEL_BASE}/text_encoder_2",
|
78 |
+
"llm": f"{MODEL_BASE}/text_encoder",
|
79 |
+
}
|
80 |
+
|
81 |
+
# Tokenizer
|
82 |
+
TOKENIZER_PATH = {
|
83 |
+
"clipL": f"{MODEL_BASE}/text_encoder_2",
|
84 |
+
"llm": f"{MODEL_BASE}/text_encoder",
|
85 |
+
}
|
86 |
+
|
87 |
+
TEXT_PROJECTION = {
|
88 |
+
"linear", # Default, an nn.Linear() layer
|
89 |
+
"single_refiner", # Single TokenRefiner. Refer to LI-DiT
|
90 |
+
}
|
hyvideo/diffusion/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .pipelines import HunyuanVideoPipeline
|
2 |
+
from .schedulers import FlowMatchDiscreteScheduler
|
hyvideo/diffusion/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .pipeline_hunyuan_video import HunyuanVideoPipeline
|
hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py
ADDED
@@ -0,0 +1,1103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
import inspect
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
import numpy as np
|
24 |
+
from dataclasses import dataclass
|
25 |
+
from packaging import version
|
26 |
+
|
27 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
28 |
+
from diffusers.configuration_utils import FrozenDict
|
29 |
+
from diffusers.image_processor import VaeImageProcessor
|
30 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
31 |
+
from diffusers.models import AutoencoderKL
|
32 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
33 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
34 |
+
from diffusers.utils import (
|
35 |
+
USE_PEFT_BACKEND,
|
36 |
+
deprecate,
|
37 |
+
logging,
|
38 |
+
replace_example_docstring,
|
39 |
+
scale_lora_layers,
|
40 |
+
unscale_lora_layers,
|
41 |
+
)
|
42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
44 |
+
from diffusers.utils import BaseOutput
|
45 |
+
|
46 |
+
from ...constants import PRECISION_TO_TYPE
|
47 |
+
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
48 |
+
from ...text_encoder import TextEncoder
|
49 |
+
from ...modules import HYVideoDiffusionTransformer
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
52 |
+
|
53 |
+
EXAMPLE_DOC_STRING = """"""
|
54 |
+
|
55 |
+
|
56 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
57 |
+
"""
|
58 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
59 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
60 |
+
"""
|
61 |
+
std_text = noise_pred_text.std(
|
62 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
|
63 |
+
)
|
64 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
65 |
+
# rescale the results from guidance (fixes overexposure)
|
66 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
67 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
68 |
+
noise_cfg = (
|
69 |
+
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
70 |
+
)
|
71 |
+
return noise_cfg
|
72 |
+
|
73 |
+
|
74 |
+
def retrieve_timesteps(
|
75 |
+
scheduler,
|
76 |
+
num_inference_steps: Optional[int] = None,
|
77 |
+
device: Optional[Union[str, torch.device]] = None,
|
78 |
+
timesteps: Optional[List[int]] = None,
|
79 |
+
sigmas: Optional[List[float]] = None,
|
80 |
+
**kwargs,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
84 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
scheduler (`SchedulerMixin`):
|
88 |
+
The scheduler to get timesteps from.
|
89 |
+
num_inference_steps (`int`):
|
90 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
91 |
+
must be `None`.
|
92 |
+
device (`str` or `torch.device`, *optional*):
|
93 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
94 |
+
timesteps (`List[int]`, *optional*):
|
95 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
96 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
97 |
+
sigmas (`List[float]`, *optional*):
|
98 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
99 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
103 |
+
second element is the number of inference steps.
|
104 |
+
"""
|
105 |
+
if timesteps is not None and sigmas is not None:
|
106 |
+
raise ValueError(
|
107 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
108 |
+
)
|
109 |
+
if timesteps is not None:
|
110 |
+
accepts_timesteps = "timesteps" in set(
|
111 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
112 |
+
)
|
113 |
+
if not accepts_timesteps:
|
114 |
+
raise ValueError(
|
115 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
116 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
117 |
+
)
|
118 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
119 |
+
timesteps = scheduler.timesteps
|
120 |
+
num_inference_steps = len(timesteps)
|
121 |
+
elif sigmas is not None:
|
122 |
+
accept_sigmas = "sigmas" in set(
|
123 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
124 |
+
)
|
125 |
+
if not accept_sigmas:
|
126 |
+
raise ValueError(
|
127 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
128 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
129 |
+
)
|
130 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
131 |
+
timesteps = scheduler.timesteps
|
132 |
+
num_inference_steps = len(timesteps)
|
133 |
+
else:
|
134 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
135 |
+
timesteps = scheduler.timesteps
|
136 |
+
return timesteps, num_inference_steps
|
137 |
+
|
138 |
+
|
139 |
+
@dataclass
|
140 |
+
class HunyuanVideoPipelineOutput(BaseOutput):
|
141 |
+
videos: Union[torch.Tensor, np.ndarray]
|
142 |
+
|
143 |
+
|
144 |
+
class HunyuanVideoPipeline(DiffusionPipeline):
|
145 |
+
r"""
|
146 |
+
Pipeline for text-to-video generation using HunyuanVideo.
|
147 |
+
|
148 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
149 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
150 |
+
|
151 |
+
Args:
|
152 |
+
vae ([`AutoencoderKL`]):
|
153 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
154 |
+
text_encoder ([`TextEncoder`]):
|
155 |
+
Frozen text-encoder.
|
156 |
+
text_encoder_2 ([`TextEncoder`]):
|
157 |
+
Frozen text-encoder_2.
|
158 |
+
transformer ([`HYVideoDiffusionTransformer`]):
|
159 |
+
A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
|
160 |
+
scheduler ([`SchedulerMixin`]):
|
161 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
162 |
+
"""
|
163 |
+
|
164 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
165 |
+
_optional_components = ["text_encoder_2"]
|
166 |
+
_exclude_from_cpu_offload = ["transformer"]
|
167 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
168 |
+
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
vae: AutoencoderKL,
|
172 |
+
text_encoder: TextEncoder,
|
173 |
+
transformer: HYVideoDiffusionTransformer,
|
174 |
+
scheduler: KarrasDiffusionSchedulers,
|
175 |
+
text_encoder_2: Optional[TextEncoder] = None,
|
176 |
+
progress_bar_config: Dict[str, Any] = None,
|
177 |
+
args=None,
|
178 |
+
):
|
179 |
+
super().__init__()
|
180 |
+
|
181 |
+
# ==========================================================================================
|
182 |
+
if progress_bar_config is None:
|
183 |
+
progress_bar_config = {}
|
184 |
+
if not hasattr(self, "_progress_bar_config"):
|
185 |
+
self._progress_bar_config = {}
|
186 |
+
self._progress_bar_config.update(progress_bar_config)
|
187 |
+
|
188 |
+
self.args = args
|
189 |
+
# ==========================================================================================
|
190 |
+
|
191 |
+
if (
|
192 |
+
hasattr(scheduler.config, "steps_offset")
|
193 |
+
and scheduler.config.steps_offset != 1
|
194 |
+
):
|
195 |
+
deprecation_message = (
|
196 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
197 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
198 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
199 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
200 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
201 |
+
" file"
|
202 |
+
)
|
203 |
+
deprecate(
|
204 |
+
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
|
205 |
+
)
|
206 |
+
new_config = dict(scheduler.config)
|
207 |
+
new_config["steps_offset"] = 1
|
208 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
209 |
+
|
210 |
+
if (
|
211 |
+
hasattr(scheduler.config, "clip_sample")
|
212 |
+
and scheduler.config.clip_sample is True
|
213 |
+
):
|
214 |
+
deprecation_message = (
|
215 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
216 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
217 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
218 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
219 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
220 |
+
)
|
221 |
+
deprecate(
|
222 |
+
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
|
223 |
+
)
|
224 |
+
new_config = dict(scheduler.config)
|
225 |
+
new_config["clip_sample"] = False
|
226 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
227 |
+
|
228 |
+
self.register_modules(
|
229 |
+
vae=vae,
|
230 |
+
text_encoder=text_encoder,
|
231 |
+
transformer=transformer,
|
232 |
+
scheduler=scheduler,
|
233 |
+
text_encoder_2=text_encoder_2,
|
234 |
+
)
|
235 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
236 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
237 |
+
|
238 |
+
def encode_prompt(
|
239 |
+
self,
|
240 |
+
prompt,
|
241 |
+
device,
|
242 |
+
num_videos_per_prompt,
|
243 |
+
do_classifier_free_guidance,
|
244 |
+
negative_prompt=None,
|
245 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
246 |
+
attention_mask: Optional[torch.Tensor] = None,
|
247 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
248 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
249 |
+
lora_scale: Optional[float] = None,
|
250 |
+
clip_skip: Optional[int] = None,
|
251 |
+
text_encoder: Optional[TextEncoder] = None,
|
252 |
+
data_type: Optional[str] = "image",
|
253 |
+
):
|
254 |
+
r"""
|
255 |
+
Encodes the prompt into text encoder hidden states.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
prompt (`str` or `List[str]`, *optional*):
|
259 |
+
prompt to be encoded
|
260 |
+
device: (`torch.device`):
|
261 |
+
torch device
|
262 |
+
num_videos_per_prompt (`int`):
|
263 |
+
number of videos that should be generated per prompt
|
264 |
+
do_classifier_free_guidance (`bool`):
|
265 |
+
whether to use classifier free guidance or not
|
266 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
267 |
+
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
268 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
269 |
+
less than `1`).
|
270 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
271 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
272 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
273 |
+
attention_mask (`torch.Tensor`, *optional*):
|
274 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
275 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
276 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
277 |
+
argument.
|
278 |
+
negative_attention_mask (`torch.Tensor`, *optional*):
|
279 |
+
lora_scale (`float`, *optional*):
|
280 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
281 |
+
clip_skip (`int`, *optional*):
|
282 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
283 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
284 |
+
text_encoder (TextEncoder, *optional*):
|
285 |
+
data_type (`str`, *optional*):
|
286 |
+
"""
|
287 |
+
if text_encoder is None:
|
288 |
+
text_encoder = self.text_encoder
|
289 |
+
|
290 |
+
# set lora scale so that monkey patched LoRA
|
291 |
+
# function of text encoder can correctly access it
|
292 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
293 |
+
self._lora_scale = lora_scale
|
294 |
+
|
295 |
+
# dynamically adjust the LoRA scale
|
296 |
+
if not USE_PEFT_BACKEND:
|
297 |
+
adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
|
298 |
+
else:
|
299 |
+
scale_lora_layers(text_encoder.model, lora_scale)
|
300 |
+
|
301 |
+
if prompt is not None and isinstance(prompt, str):
|
302 |
+
batch_size = 1
|
303 |
+
elif prompt is not None and isinstance(prompt, list):
|
304 |
+
batch_size = len(prompt)
|
305 |
+
else:
|
306 |
+
batch_size = prompt_embeds.shape[0]
|
307 |
+
|
308 |
+
if prompt_embeds is None:
|
309 |
+
# textual inversion: process multi-vector tokens if necessary
|
310 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
311 |
+
prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
|
312 |
+
|
313 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
314 |
+
|
315 |
+
if clip_skip is None:
|
316 |
+
prompt_outputs = text_encoder.encode(
|
317 |
+
text_inputs, data_type=data_type, device=device
|
318 |
+
)
|
319 |
+
prompt_embeds = prompt_outputs.hidden_state
|
320 |
+
else:
|
321 |
+
prompt_outputs = text_encoder.encode(
|
322 |
+
text_inputs,
|
323 |
+
output_hidden_states=True,
|
324 |
+
data_type=data_type,
|
325 |
+
device=device,
|
326 |
+
)
|
327 |
+
# Access the `hidden_states` first, that contains a tuple of
|
328 |
+
# all the hidden states from the encoder layers. Then index into
|
329 |
+
# the tuple to access the hidden states from the desired layer.
|
330 |
+
prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
|
331 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
332 |
+
# representations. The `last_hidden_states` that we typically use for
|
333 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
334 |
+
# layer.
|
335 |
+
prompt_embeds = text_encoder.model.text_model.final_layer_norm(
|
336 |
+
prompt_embeds
|
337 |
+
)
|
338 |
+
|
339 |
+
attention_mask = prompt_outputs.attention_mask
|
340 |
+
if attention_mask is not None:
|
341 |
+
attention_mask = attention_mask.to(device)
|
342 |
+
bs_embed, seq_len = attention_mask.shape
|
343 |
+
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
|
344 |
+
attention_mask = attention_mask.view(
|
345 |
+
bs_embed * num_videos_per_prompt, seq_len
|
346 |
+
)
|
347 |
+
|
348 |
+
if text_encoder is not None:
|
349 |
+
prompt_embeds_dtype = text_encoder.dtype
|
350 |
+
elif self.transformer is not None:
|
351 |
+
prompt_embeds_dtype = self.transformer.dtype
|
352 |
+
else:
|
353 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
354 |
+
|
355 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
356 |
+
|
357 |
+
if prompt_embeds.ndim == 2:
|
358 |
+
bs_embed, _ = prompt_embeds.shape
|
359 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
360 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
361 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
|
362 |
+
else:
|
363 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
364 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
365 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
366 |
+
prompt_embeds = prompt_embeds.view(
|
367 |
+
bs_embed * num_videos_per_prompt, seq_len, -1
|
368 |
+
)
|
369 |
+
|
370 |
+
# get unconditional embeddings for classifier free guidance
|
371 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
372 |
+
uncond_tokens: List[str]
|
373 |
+
if negative_prompt is None:
|
374 |
+
uncond_tokens = [""] * batch_size
|
375 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
376 |
+
raise TypeError(
|
377 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
378 |
+
f" {type(prompt)}."
|
379 |
+
)
|
380 |
+
elif isinstance(negative_prompt, str):
|
381 |
+
uncond_tokens = [negative_prompt]
|
382 |
+
elif batch_size != len(negative_prompt):
|
383 |
+
raise ValueError(
|
384 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
385 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
386 |
+
" the batch size of `prompt`."
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
uncond_tokens = negative_prompt
|
390 |
+
|
391 |
+
# textual inversion: process multi-vector tokens if necessary
|
392 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
393 |
+
uncond_tokens = self.maybe_convert_prompt(
|
394 |
+
uncond_tokens, text_encoder.tokenizer
|
395 |
+
)
|
396 |
+
|
397 |
+
# max_length = prompt_embeds.shape[1]
|
398 |
+
uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
|
399 |
+
|
400 |
+
negative_prompt_outputs = text_encoder.encode(
|
401 |
+
uncond_input, data_type=data_type, device=device
|
402 |
+
)
|
403 |
+
negative_prompt_embeds = negative_prompt_outputs.hidden_state
|
404 |
+
|
405 |
+
negative_attention_mask = negative_prompt_outputs.attention_mask
|
406 |
+
if negative_attention_mask is not None:
|
407 |
+
negative_attention_mask = negative_attention_mask.to(device)
|
408 |
+
_, seq_len = negative_attention_mask.shape
|
409 |
+
negative_attention_mask = negative_attention_mask.repeat(
|
410 |
+
1, num_videos_per_prompt
|
411 |
+
)
|
412 |
+
negative_attention_mask = negative_attention_mask.view(
|
413 |
+
batch_size * num_videos_per_prompt, seq_len
|
414 |
+
)
|
415 |
+
|
416 |
+
if do_classifier_free_guidance:
|
417 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
418 |
+
seq_len = negative_prompt_embeds.shape[1]
|
419 |
+
|
420 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
421 |
+
dtype=prompt_embeds_dtype, device=device
|
422 |
+
)
|
423 |
+
|
424 |
+
if negative_prompt_embeds.ndim == 2:
|
425 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
426 |
+
1, num_videos_per_prompt
|
427 |
+
)
|
428 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
429 |
+
batch_size * num_videos_per_prompt, -1
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
433 |
+
1, num_videos_per_prompt, 1
|
434 |
+
)
|
435 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
436 |
+
batch_size * num_videos_per_prompt, seq_len, -1
|
437 |
+
)
|
438 |
+
|
439 |
+
if text_encoder is not None:
|
440 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
441 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
442 |
+
unscale_lora_layers(text_encoder.model, lora_scale)
|
443 |
+
|
444 |
+
return (
|
445 |
+
prompt_embeds,
|
446 |
+
negative_prompt_embeds,
|
447 |
+
attention_mask,
|
448 |
+
negative_attention_mask,
|
449 |
+
)
|
450 |
+
|
451 |
+
def decode_latents(self, latents, enable_tiling=True):
|
452 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
453 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
454 |
+
|
455 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
456 |
+
if enable_tiling:
|
457 |
+
self.vae.enable_tiling()
|
458 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
459 |
+
else:
|
460 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
461 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
462 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
463 |
+
if image.ndim == 4:
|
464 |
+
image = image.cpu().permute(0, 2, 3, 1).float()
|
465 |
+
else:
|
466 |
+
image = image.cpu().float()
|
467 |
+
return image
|
468 |
+
|
469 |
+
def prepare_extra_func_kwargs(self, func, kwargs):
|
470 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
471 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
472 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
473 |
+
# and should be between [0, 1]
|
474 |
+
extra_step_kwargs = {}
|
475 |
+
|
476 |
+
for k, v in kwargs.items():
|
477 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
478 |
+
if accepts:
|
479 |
+
extra_step_kwargs[k] = v
|
480 |
+
return extra_step_kwargs
|
481 |
+
|
482 |
+
def check_inputs(
|
483 |
+
self,
|
484 |
+
prompt,
|
485 |
+
height,
|
486 |
+
width,
|
487 |
+
video_length,
|
488 |
+
callback_steps,
|
489 |
+
negative_prompt=None,
|
490 |
+
prompt_embeds=None,
|
491 |
+
negative_prompt_embeds=None,
|
492 |
+
callback_on_step_end_tensor_inputs=None,
|
493 |
+
vae_ver="88-4c-sd",
|
494 |
+
):
|
495 |
+
if height % 8 != 0 or width % 8 != 0:
|
496 |
+
raise ValueError(
|
497 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
498 |
+
)
|
499 |
+
|
500 |
+
if video_length is not None:
|
501 |
+
if "884" in vae_ver:
|
502 |
+
if video_length != 1 and (video_length - 1) % 4 != 0:
|
503 |
+
raise ValueError(
|
504 |
+
f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
|
505 |
+
)
|
506 |
+
elif "888" in vae_ver:
|
507 |
+
if video_length != 1 and (video_length - 1) % 8 != 0:
|
508 |
+
raise ValueError(
|
509 |
+
f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
|
510 |
+
)
|
511 |
+
|
512 |
+
if callback_steps is not None and (
|
513 |
+
not isinstance(callback_steps, int) or callback_steps <= 0
|
514 |
+
):
|
515 |
+
raise ValueError(
|
516 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
517 |
+
f" {type(callback_steps)}."
|
518 |
+
)
|
519 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
520 |
+
k in self._callback_tensor_inputs
|
521 |
+
for k in callback_on_step_end_tensor_inputs
|
522 |
+
):
|
523 |
+
raise ValueError(
|
524 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
525 |
+
)
|
526 |
+
|
527 |
+
if prompt is not None and prompt_embeds is not None:
|
528 |
+
raise ValueError(
|
529 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
530 |
+
" only forward one of the two."
|
531 |
+
)
|
532 |
+
elif prompt is None and prompt_embeds is None:
|
533 |
+
raise ValueError(
|
534 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
535 |
+
)
|
536 |
+
elif prompt is not None and (
|
537 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
538 |
+
):
|
539 |
+
raise ValueError(
|
540 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
541 |
+
)
|
542 |
+
|
543 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
544 |
+
raise ValueError(
|
545 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
546 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
547 |
+
)
|
548 |
+
|
549 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
550 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
551 |
+
raise ValueError(
|
552 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
553 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
554 |
+
f" {negative_prompt_embeds.shape}."
|
555 |
+
)
|
556 |
+
|
557 |
+
|
558 |
+
def prepare_latents(
|
559 |
+
self,
|
560 |
+
batch_size,
|
561 |
+
num_channels_latents,
|
562 |
+
height,
|
563 |
+
width,
|
564 |
+
video_length,
|
565 |
+
dtype,
|
566 |
+
device,
|
567 |
+
generator,
|
568 |
+
latents=None,
|
569 |
+
):
|
570 |
+
shape = (
|
571 |
+
batch_size,
|
572 |
+
num_channels_latents,
|
573 |
+
video_length,
|
574 |
+
int(height) // self.vae_scale_factor,
|
575 |
+
int(width) // self.vae_scale_factor,
|
576 |
+
)
|
577 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
578 |
+
raise ValueError(
|
579 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
580 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
581 |
+
)
|
582 |
+
|
583 |
+
if latents is None:
|
584 |
+
latents = randn_tensor(
|
585 |
+
shape, generator=generator, device=device, dtype=dtype
|
586 |
+
)
|
587 |
+
else:
|
588 |
+
latents = latents.to(device)
|
589 |
+
|
590 |
+
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
|
591 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
592 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
593 |
+
latents = latents * self.scheduler.init_noise_sigma
|
594 |
+
return latents
|
595 |
+
|
596 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
597 |
+
def get_guidance_scale_embedding(
|
598 |
+
self,
|
599 |
+
w: torch.Tensor,
|
600 |
+
embedding_dim: int = 512,
|
601 |
+
dtype: torch.dtype = torch.float32,
|
602 |
+
) -> torch.Tensor:
|
603 |
+
"""
|
604 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
605 |
+
|
606 |
+
Args:
|
607 |
+
w (`torch.Tensor`):
|
608 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
609 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
610 |
+
Dimension of the embeddings to generate.
|
611 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
612 |
+
Data type of the generated embeddings.
|
613 |
+
|
614 |
+
Returns:
|
615 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
616 |
+
"""
|
617 |
+
assert len(w.shape) == 1
|
618 |
+
w = w * 1000.0
|
619 |
+
|
620 |
+
half_dim = embedding_dim // 2
|
621 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
622 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
623 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
624 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
625 |
+
if embedding_dim % 2 == 1: # zero pad
|
626 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
627 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
628 |
+
return emb
|
629 |
+
|
630 |
+
@property
|
631 |
+
def guidance_scale(self):
|
632 |
+
return self._guidance_scale
|
633 |
+
|
634 |
+
@property
|
635 |
+
def guidance_rescale(self):
|
636 |
+
return self._guidance_rescale
|
637 |
+
|
638 |
+
@property
|
639 |
+
def clip_skip(self):
|
640 |
+
return self._clip_skip
|
641 |
+
|
642 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
643 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
644 |
+
# corresponds to doing no classifier free guidance.
|
645 |
+
@property
|
646 |
+
def do_classifier_free_guidance(self):
|
647 |
+
# return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
|
648 |
+
return self._guidance_scale > 1
|
649 |
+
|
650 |
+
@property
|
651 |
+
def cross_attention_kwargs(self):
|
652 |
+
return self._cross_attention_kwargs
|
653 |
+
|
654 |
+
@property
|
655 |
+
def num_timesteps(self):
|
656 |
+
return self._num_timesteps
|
657 |
+
|
658 |
+
@property
|
659 |
+
def interrupt(self):
|
660 |
+
return self._interrupt
|
661 |
+
|
662 |
+
@torch.no_grad()
|
663 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
664 |
+
def __call__(
|
665 |
+
self,
|
666 |
+
prompt: Union[str, List[str]],
|
667 |
+
height: int,
|
668 |
+
width: int,
|
669 |
+
video_length: int,
|
670 |
+
data_type: str = "video",
|
671 |
+
num_inference_steps: int = 50,
|
672 |
+
timesteps: List[int] = None,
|
673 |
+
sigmas: List[float] = None,
|
674 |
+
guidance_scale: float = 7.5,
|
675 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
676 |
+
num_videos_per_prompt: Optional[int] = 1,
|
677 |
+
eta: float = 0.0,
|
678 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
679 |
+
latents: Optional[torch.Tensor] = None,
|
680 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
681 |
+
attention_mask: Optional[torch.Tensor] = None,
|
682 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
683 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
684 |
+
output_type: Optional[str] = "pil",
|
685 |
+
return_dict: bool = True,
|
686 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
687 |
+
guidance_rescale: float = 0.0,
|
688 |
+
clip_skip: Optional[int] = None,
|
689 |
+
callback_on_step_end: Optional[
|
690 |
+
Union[
|
691 |
+
Callable[[int, int, Dict], None],
|
692 |
+
PipelineCallback,
|
693 |
+
MultiPipelineCallbacks,
|
694 |
+
]
|
695 |
+
] = None,
|
696 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
697 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
698 |
+
vae_ver: str = "88-4c-sd",
|
699 |
+
enable_tiling: bool = False,
|
700 |
+
n_tokens: Optional[int] = None,
|
701 |
+
embedded_guidance_scale: Optional[float] = None,
|
702 |
+
**kwargs,
|
703 |
+
):
|
704 |
+
r"""
|
705 |
+
The call function to the pipeline for generation.
|
706 |
+
|
707 |
+
Args:
|
708 |
+
prompt (`str` or `List[str]`):
|
709 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
710 |
+
height (`int`):
|
711 |
+
The height in pixels of the generated image.
|
712 |
+
width (`int`):
|
713 |
+
The width in pixels of the generated image.
|
714 |
+
video_length (`int`):
|
715 |
+
The number of frames in the generated video.
|
716 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
717 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
718 |
+
expense of slower inference.
|
719 |
+
timesteps (`List[int]`, *optional*):
|
720 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
721 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
722 |
+
passed will be used. Must be in descending order.
|
723 |
+
sigmas (`List[float]`, *optional*):
|
724 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
725 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
726 |
+
will be used.
|
727 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
728 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
729 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
730 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
731 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
732 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
733 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
734 |
+
The number of images to generate per prompt.
|
735 |
+
eta (`float`, *optional*, defaults to 0.0):
|
736 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
737 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
738 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
739 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
740 |
+
generation deterministic.
|
741 |
+
latents (`torch.Tensor`, *optional*):
|
742 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
743 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
744 |
+
tensor is generated by sampling using the supplied random `generator`.
|
745 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
746 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
747 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
748 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
749 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
750 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
751 |
+
|
752 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
753 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
754 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
755 |
+
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
|
756 |
+
plain tuple.
|
757 |
+
cross_attention_kwargs (`dict`, *optional*):
|
758 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
759 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
760 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
761 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
762 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
763 |
+
using zero terminal SNR.
|
764 |
+
clip_skip (`int`, *optional*):
|
765 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
766 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
767 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
768 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
769 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
770 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
771 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
772 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
773 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
774 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
775 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
776 |
+
|
777 |
+
Examples:
|
778 |
+
|
779 |
+
Returns:
|
780 |
+
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
781 |
+
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
|
782 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
783 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
784 |
+
"not-safe-for-work" (nsfw) content.
|
785 |
+
"""
|
786 |
+
callback = kwargs.pop("callback", None)
|
787 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
788 |
+
|
789 |
+
if callback is not None:
|
790 |
+
deprecate(
|
791 |
+
"callback",
|
792 |
+
"1.0.0",
|
793 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
794 |
+
)
|
795 |
+
if callback_steps is not None:
|
796 |
+
deprecate(
|
797 |
+
"callback_steps",
|
798 |
+
"1.0.0",
|
799 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
800 |
+
)
|
801 |
+
|
802 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
803 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
804 |
+
|
805 |
+
# 0. Default height and width to unet
|
806 |
+
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
807 |
+
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
808 |
+
# to deal with lora scaling and other possible forward hooks
|
809 |
+
|
810 |
+
# 1. Check inputs. Raise error if not correct
|
811 |
+
self.check_inputs(
|
812 |
+
prompt,
|
813 |
+
height,
|
814 |
+
width,
|
815 |
+
video_length,
|
816 |
+
callback_steps,
|
817 |
+
negative_prompt,
|
818 |
+
prompt_embeds,
|
819 |
+
negative_prompt_embeds,
|
820 |
+
callback_on_step_end_tensor_inputs,
|
821 |
+
vae_ver=vae_ver,
|
822 |
+
)
|
823 |
+
|
824 |
+
self._guidance_scale = guidance_scale
|
825 |
+
self._guidance_rescale = guidance_rescale
|
826 |
+
self._clip_skip = clip_skip
|
827 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
828 |
+
self._interrupt = False
|
829 |
+
|
830 |
+
# 2. Define call parameters
|
831 |
+
if prompt is not None and isinstance(prompt, str):
|
832 |
+
batch_size = 1
|
833 |
+
elif prompt is not None and isinstance(prompt, list):
|
834 |
+
batch_size = len(prompt)
|
835 |
+
else:
|
836 |
+
batch_size = prompt_embeds.shape[0]
|
837 |
+
|
838 |
+
device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
|
839 |
+
|
840 |
+
# 3. Encode input prompt
|
841 |
+
lora_scale = (
|
842 |
+
self.cross_attention_kwargs.get("scale", None)
|
843 |
+
if self.cross_attention_kwargs is not None
|
844 |
+
else None
|
845 |
+
)
|
846 |
+
|
847 |
+
(
|
848 |
+
prompt_embeds,
|
849 |
+
negative_prompt_embeds,
|
850 |
+
prompt_mask,
|
851 |
+
negative_prompt_mask,
|
852 |
+
) = self.encode_prompt(
|
853 |
+
prompt,
|
854 |
+
device,
|
855 |
+
num_videos_per_prompt,
|
856 |
+
self.do_classifier_free_guidance,
|
857 |
+
negative_prompt,
|
858 |
+
prompt_embeds=prompt_embeds,
|
859 |
+
attention_mask=attention_mask,
|
860 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
861 |
+
negative_attention_mask=negative_attention_mask,
|
862 |
+
lora_scale=lora_scale,
|
863 |
+
clip_skip=self.clip_skip,
|
864 |
+
data_type=data_type,
|
865 |
+
)
|
866 |
+
if self.text_encoder_2 is not None:
|
867 |
+
(
|
868 |
+
prompt_embeds_2,
|
869 |
+
negative_prompt_embeds_2,
|
870 |
+
prompt_mask_2,
|
871 |
+
negative_prompt_mask_2,
|
872 |
+
) = self.encode_prompt(
|
873 |
+
prompt,
|
874 |
+
device,
|
875 |
+
num_videos_per_prompt,
|
876 |
+
self.do_classifier_free_guidance,
|
877 |
+
negative_prompt,
|
878 |
+
prompt_embeds=None,
|
879 |
+
attention_mask=None,
|
880 |
+
negative_prompt_embeds=None,
|
881 |
+
negative_attention_mask=None,
|
882 |
+
lora_scale=lora_scale,
|
883 |
+
clip_skip=self.clip_skip,
|
884 |
+
text_encoder=self.text_encoder_2,
|
885 |
+
data_type=data_type,
|
886 |
+
)
|
887 |
+
else:
|
888 |
+
prompt_embeds_2 = None
|
889 |
+
negative_prompt_embeds_2 = None
|
890 |
+
prompt_mask_2 = None
|
891 |
+
negative_prompt_mask_2 = None
|
892 |
+
|
893 |
+
# For classifier free guidance, we need to do two forward passes.
|
894 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
895 |
+
# to avoid doing two forward passes
|
896 |
+
if self.do_classifier_free_guidance:
|
897 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
898 |
+
if prompt_mask is not None:
|
899 |
+
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
|
900 |
+
if prompt_embeds_2 is not None:
|
901 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
902 |
+
if prompt_mask_2 is not None:
|
903 |
+
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
|
904 |
+
|
905 |
+
|
906 |
+
# 4. Prepare timesteps
|
907 |
+
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
|
908 |
+
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
|
909 |
+
)
|
910 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
911 |
+
self.scheduler,
|
912 |
+
num_inference_steps,
|
913 |
+
device,
|
914 |
+
timesteps,
|
915 |
+
sigmas,
|
916 |
+
**extra_set_timesteps_kwargs,
|
917 |
+
)
|
918 |
+
|
919 |
+
if "884" in vae_ver:
|
920 |
+
video_length = (video_length - 1) // 4 + 1
|
921 |
+
elif "888" in vae_ver:
|
922 |
+
video_length = (video_length - 1) // 8 + 1
|
923 |
+
else:
|
924 |
+
video_length = video_length
|
925 |
+
|
926 |
+
# 5. Prepare latent variables
|
927 |
+
num_channels_latents = self.transformer.config.in_channels
|
928 |
+
latents = self.prepare_latents(
|
929 |
+
batch_size * num_videos_per_prompt,
|
930 |
+
num_channels_latents,
|
931 |
+
height,
|
932 |
+
width,
|
933 |
+
video_length,
|
934 |
+
prompt_embeds.dtype,
|
935 |
+
device,
|
936 |
+
generator,
|
937 |
+
latents,
|
938 |
+
)
|
939 |
+
|
940 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
941 |
+
extra_step_kwargs = self.prepare_extra_func_kwargs(
|
942 |
+
self.scheduler.step,
|
943 |
+
{"generator": generator, "eta": eta},
|
944 |
+
)
|
945 |
+
|
946 |
+
target_dtype = PRECISION_TO_TYPE[self.args.precision]
|
947 |
+
autocast_enabled = (
|
948 |
+
target_dtype != torch.float32
|
949 |
+
) and not self.args.disable_autocast
|
950 |
+
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
|
951 |
+
vae_autocast_enabled = (
|
952 |
+
vae_dtype != torch.float32
|
953 |
+
) and not self.args.disable_autocast
|
954 |
+
|
955 |
+
# 7. Denoising loop
|
956 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
957 |
+
self._num_timesteps = len(timesteps)
|
958 |
+
|
959 |
+
# if is_progress_bar:
|
960 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
961 |
+
for i, t in enumerate(timesteps):
|
962 |
+
if self.interrupt:
|
963 |
+
continue
|
964 |
+
import os
|
965 |
+
if os.path.isfile("abort"):
|
966 |
+
continue
|
967 |
+
|
968 |
+
# expand the latents if we are doing classifier free guidance
|
969 |
+
latent_model_input = (
|
970 |
+
torch.cat([latents] * 2)
|
971 |
+
if self.do_classifier_free_guidance
|
972 |
+
else latents
|
973 |
+
)
|
974 |
+
latent_model_input = self.scheduler.scale_model_input(
|
975 |
+
latent_model_input, t
|
976 |
+
)
|
977 |
+
|
978 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
979 |
+
guidance_expand = (
|
980 |
+
torch.tensor(
|
981 |
+
[embedded_guidance_scale] * latent_model_input.shape[0],
|
982 |
+
dtype=torch.float32,
|
983 |
+
device=device,
|
984 |
+
).to(target_dtype)
|
985 |
+
* 1000.0
|
986 |
+
if embedded_guidance_scale is not None
|
987 |
+
else None
|
988 |
+
)
|
989 |
+
|
990 |
+
# predict the noise residual
|
991 |
+
with torch.autocast(
|
992 |
+
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
|
993 |
+
):
|
994 |
+
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
995 |
+
latent_model_input, # [2, 16, 33, 24, 42]
|
996 |
+
t_expand, # [2]
|
997 |
+
text_states=prompt_embeds, # [2, 256, 4096]
|
998 |
+
text_mask=prompt_mask, # [2, 256]
|
999 |
+
text_states_2=prompt_embeds_2, # [2, 768]
|
1000 |
+
freqs_cos=freqs_cis[0], # [seqlen, head_dim]
|
1001 |
+
freqs_sin=freqs_cis[1], # [seqlen, head_dim]
|
1002 |
+
guidance=guidance_expand,
|
1003 |
+
return_dict=True,
|
1004 |
+
)[
|
1005 |
+
"x"
|
1006 |
+
]
|
1007 |
+
|
1008 |
+
# perform guidance
|
1009 |
+
if self.do_classifier_free_guidance:
|
1010 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1011 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
1012 |
+
noise_pred_text - noise_pred_uncond
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1016 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1017 |
+
noise_pred = rescale_noise_cfg(
|
1018 |
+
noise_pred,
|
1019 |
+
noise_pred_text,
|
1020 |
+
guidance_rescale=self.guidance_rescale,
|
1021 |
+
)
|
1022 |
+
|
1023 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1024 |
+
latents = self.scheduler.step(
|
1025 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
1026 |
+
)[0]
|
1027 |
+
|
1028 |
+
if callback_on_step_end is not None:
|
1029 |
+
callback_kwargs = {}
|
1030 |
+
for k in callback_on_step_end_tensor_inputs:
|
1031 |
+
callback_kwargs[k] = locals()[k]
|
1032 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1033 |
+
|
1034 |
+
latents = callback_outputs.pop("latents", latents)
|
1035 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1036 |
+
negative_prompt_embeds = callback_outputs.pop(
|
1037 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
# call the callback, if provided
|
1041 |
+
if i == len(timesteps) - 1 or (
|
1042 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
1043 |
+
):
|
1044 |
+
if progress_bar is not None:
|
1045 |
+
progress_bar.update()
|
1046 |
+
if callback is not None and i % callback_steps == 0:
|
1047 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1048 |
+
callback(step_idx, t, latents)
|
1049 |
+
|
1050 |
+
if not output_type == "latent":
|
1051 |
+
expand_temporal_dim = False
|
1052 |
+
if len(latents.shape) == 4:
|
1053 |
+
if isinstance(self.vae, AutoencoderKLCausal3D):
|
1054 |
+
latents = latents.unsqueeze(2)
|
1055 |
+
expand_temporal_dim = True
|
1056 |
+
elif len(latents.shape) == 5:
|
1057 |
+
pass
|
1058 |
+
else:
|
1059 |
+
raise ValueError(
|
1060 |
+
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
if (
|
1064 |
+
hasattr(self.vae.config, "shift_factor")
|
1065 |
+
and self.vae.config.shift_factor
|
1066 |
+
):
|
1067 |
+
latents = (
|
1068 |
+
latents / self.vae.config.scaling_factor
|
1069 |
+
+ self.vae.config.shift_factor
|
1070 |
+
)
|
1071 |
+
else:
|
1072 |
+
latents = latents / self.vae.config.scaling_factor
|
1073 |
+
|
1074 |
+
with torch.autocast(
|
1075 |
+
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
|
1076 |
+
):
|
1077 |
+
if enable_tiling:
|
1078 |
+
self.vae.enable_tiling()
|
1079 |
+
image = self.vae.decode(
|
1080 |
+
latents, return_dict=False, generator=generator
|
1081 |
+
)[0]
|
1082 |
+
else:
|
1083 |
+
image = self.vae.decode(
|
1084 |
+
latents, return_dict=False, generator=generator
|
1085 |
+
)[0]
|
1086 |
+
|
1087 |
+
if expand_temporal_dim or image.shape[2] == 1:
|
1088 |
+
image = image.squeeze(2)
|
1089 |
+
|
1090 |
+
else:
|
1091 |
+
image = latents
|
1092 |
+
|
1093 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
1094 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
1095 |
+
image = image.cpu().float()
|
1096 |
+
|
1097 |
+
# Offload all models
|
1098 |
+
self.maybe_free_model_hooks()
|
1099 |
+
|
1100 |
+
if not return_dict:
|
1101 |
+
return image
|
1102 |
+
|
1103 |
+
return HunyuanVideoPipelineOutput(videos=image)
|
hyvideo/diffusion/schedulers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import BaseOutput, logging
|
28 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
|
36 |
+
"""
|
37 |
+
Output class for the scheduler's `step` function output.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
42 |
+
denoising loop.
|
43 |
+
"""
|
44 |
+
|
45 |
+
prev_sample: torch.FloatTensor
|
46 |
+
|
47 |
+
|
48 |
+
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
49 |
+
"""
|
50 |
+
Euler scheduler.
|
51 |
+
|
52 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
53 |
+
methods the library implements for all schedulers such as loading and saving.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
num_train_timesteps (`int`, defaults to 1000):
|
57 |
+
The number of diffusion steps to train the model.
|
58 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
59 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
60 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
61 |
+
shift (`float`, defaults to 1.0):
|
62 |
+
The shift value for the timestep schedule.
|
63 |
+
reverse (`bool`, defaults to `True`):
|
64 |
+
Whether to reverse the timestep schedule.
|
65 |
+
"""
|
66 |
+
|
67 |
+
_compatibles = []
|
68 |
+
order = 1
|
69 |
+
|
70 |
+
@register_to_config
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
num_train_timesteps: int = 1000,
|
74 |
+
shift: float = 1.0,
|
75 |
+
reverse: bool = True,
|
76 |
+
solver: str = "euler",
|
77 |
+
n_tokens: Optional[int] = None,
|
78 |
+
):
|
79 |
+
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
|
80 |
+
|
81 |
+
if not reverse:
|
82 |
+
sigmas = sigmas.flip(0)
|
83 |
+
|
84 |
+
self.sigmas = sigmas
|
85 |
+
# the value fed to model
|
86 |
+
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
|
87 |
+
|
88 |
+
self._step_index = None
|
89 |
+
self._begin_index = None
|
90 |
+
|
91 |
+
self.supported_solver = ["euler"]
|
92 |
+
if solver not in self.supported_solver:
|
93 |
+
raise ValueError(
|
94 |
+
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
|
95 |
+
)
|
96 |
+
|
97 |
+
@property
|
98 |
+
def step_index(self):
|
99 |
+
"""
|
100 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
101 |
+
"""
|
102 |
+
return self._step_index
|
103 |
+
|
104 |
+
@property
|
105 |
+
def begin_index(self):
|
106 |
+
"""
|
107 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
108 |
+
"""
|
109 |
+
return self._begin_index
|
110 |
+
|
111 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
112 |
+
def set_begin_index(self, begin_index: int = 0):
|
113 |
+
"""
|
114 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
begin_index (`int`):
|
118 |
+
The begin index for the scheduler.
|
119 |
+
"""
|
120 |
+
self._begin_index = begin_index
|
121 |
+
|
122 |
+
def _sigma_to_t(self, sigma):
|
123 |
+
return sigma * self.config.num_train_timesteps
|
124 |
+
|
125 |
+
def set_timesteps(
|
126 |
+
self,
|
127 |
+
num_inference_steps: int,
|
128 |
+
device: Union[str, torch.device] = None,
|
129 |
+
n_tokens: int = None,
|
130 |
+
):
|
131 |
+
"""
|
132 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
133 |
+
|
134 |
+
Args:
|
135 |
+
num_inference_steps (`int`):
|
136 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
137 |
+
device (`str` or `torch.device`, *optional*):
|
138 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
139 |
+
n_tokens (`int`, *optional*):
|
140 |
+
Number of tokens in the input sequence.
|
141 |
+
"""
|
142 |
+
self.num_inference_steps = num_inference_steps
|
143 |
+
|
144 |
+
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
|
145 |
+
sigmas = self.sd3_time_shift(sigmas)
|
146 |
+
|
147 |
+
if not self.config.reverse:
|
148 |
+
sigmas = 1 - sigmas
|
149 |
+
|
150 |
+
self.sigmas = sigmas
|
151 |
+
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
|
152 |
+
dtype=torch.float32, device=device
|
153 |
+
)
|
154 |
+
|
155 |
+
# Reset step index
|
156 |
+
self._step_index = None
|
157 |
+
|
158 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
159 |
+
if schedule_timesteps is None:
|
160 |
+
schedule_timesteps = self.timesteps
|
161 |
+
|
162 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
163 |
+
|
164 |
+
# The sigma index that is taken for the **very** first `step`
|
165 |
+
# is always the second index (or the last index if there is only 1)
|
166 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
167 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
168 |
+
pos = 1 if len(indices) > 1 else 0
|
169 |
+
|
170 |
+
return indices[pos].item()
|
171 |
+
|
172 |
+
def _init_step_index(self, timestep):
|
173 |
+
if self.begin_index is None:
|
174 |
+
if isinstance(timestep, torch.Tensor):
|
175 |
+
timestep = timestep.to(self.timesteps.device)
|
176 |
+
self._step_index = self.index_for_timestep(timestep)
|
177 |
+
else:
|
178 |
+
self._step_index = self._begin_index
|
179 |
+
|
180 |
+
def scale_model_input(
|
181 |
+
self, sample: torch.Tensor, timestep: Optional[int] = None
|
182 |
+
) -> torch.Tensor:
|
183 |
+
return sample
|
184 |
+
|
185 |
+
def sd3_time_shift(self, t: torch.Tensor):
|
186 |
+
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
|
187 |
+
|
188 |
+
def step(
|
189 |
+
self,
|
190 |
+
model_output: torch.FloatTensor,
|
191 |
+
timestep: Union[float, torch.FloatTensor],
|
192 |
+
sample: torch.FloatTensor,
|
193 |
+
return_dict: bool = True,
|
194 |
+
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
|
195 |
+
"""
|
196 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
197 |
+
process from the learned model outputs (most often the predicted noise).
|
198 |
+
|
199 |
+
Args:
|
200 |
+
model_output (`torch.FloatTensor`):
|
201 |
+
The direct output from learned diffusion model.
|
202 |
+
timestep (`float`):
|
203 |
+
The current discrete timestep in the diffusion chain.
|
204 |
+
sample (`torch.FloatTensor`):
|
205 |
+
A current instance of a sample created by the diffusion process.
|
206 |
+
generator (`torch.Generator`, *optional*):
|
207 |
+
A random number generator.
|
208 |
+
n_tokens (`int`, *optional*):
|
209 |
+
Number of tokens in the input sequence.
|
210 |
+
return_dict (`bool`):
|
211 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
212 |
+
tuple.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
216 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
217 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
218 |
+
"""
|
219 |
+
|
220 |
+
if (
|
221 |
+
isinstance(timestep, int)
|
222 |
+
or isinstance(timestep, torch.IntTensor)
|
223 |
+
or isinstance(timestep, torch.LongTensor)
|
224 |
+
):
|
225 |
+
raise ValueError(
|
226 |
+
(
|
227 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
228 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
229 |
+
" one of the `scheduler.timesteps` as a timestep."
|
230 |
+
),
|
231 |
+
)
|
232 |
+
|
233 |
+
if self.step_index is None:
|
234 |
+
self._init_step_index(timestep)
|
235 |
+
|
236 |
+
# Upcast to avoid precision issues when computing prev_sample
|
237 |
+
sample = sample.to(torch.float32)
|
238 |
+
|
239 |
+
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
|
240 |
+
|
241 |
+
if self.config.solver == "euler":
|
242 |
+
prev_sample = sample + model_output.to(torch.float32) * dt
|
243 |
+
else:
|
244 |
+
raise ValueError(
|
245 |
+
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
|
246 |
+
)
|
247 |
+
|
248 |
+
# upon completion increase step index by one
|
249 |
+
self._step_index += 1
|
250 |
+
|
251 |
+
if not return_dict:
|
252 |
+
return (prev_sample,)
|
253 |
+
|
254 |
+
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
|
255 |
+
|
256 |
+
def __len__(self):
|
257 |
+
return self.config.num_train_timesteps
|
hyvideo/inference.py
ADDED
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
import functools
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from loguru import logger
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
|
13 |
+
from hyvideo.vae import load_vae
|
14 |
+
from hyvideo.modules import load_model
|
15 |
+
from hyvideo.text_encoder import TextEncoder
|
16 |
+
from hyvideo.utils.data_utils import align_to
|
17 |
+
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
|
18 |
+
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
|
19 |
+
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
|
20 |
+
|
21 |
+
try:
|
22 |
+
import xfuser
|
23 |
+
from xfuser.core.distributed import (
|
24 |
+
get_sequence_parallel_world_size,
|
25 |
+
get_sequence_parallel_rank,
|
26 |
+
get_sp_group,
|
27 |
+
initialize_model_parallel,
|
28 |
+
init_distributed_environment
|
29 |
+
)
|
30 |
+
except:
|
31 |
+
xfuser = None
|
32 |
+
get_sequence_parallel_world_size = None
|
33 |
+
get_sequence_parallel_rank = None
|
34 |
+
get_sp_group = None
|
35 |
+
initialize_model_parallel = None
|
36 |
+
init_distributed_environment = None
|
37 |
+
|
38 |
+
|
39 |
+
def parallelize_transformer(pipe):
|
40 |
+
transformer = pipe.transformer
|
41 |
+
original_forward = transformer.forward
|
42 |
+
|
43 |
+
@functools.wraps(transformer.__class__.forward)
|
44 |
+
def new_forward(
|
45 |
+
self,
|
46 |
+
x: torch.Tensor,
|
47 |
+
t: torch.Tensor, # Should be in range(0, 1000).
|
48 |
+
text_states: torch.Tensor = None,
|
49 |
+
text_mask: torch.Tensor = None, # Now we don't use it.
|
50 |
+
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
|
51 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
52 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
53 |
+
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
|
54 |
+
return_dict: bool = True,
|
55 |
+
):
|
56 |
+
if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
|
57 |
+
# try to split x by height
|
58 |
+
split_dim = -2
|
59 |
+
elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
|
60 |
+
# try to split x by width
|
61 |
+
split_dim = -1
|
62 |
+
else:
|
63 |
+
raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({get_sequence_parallel_world_size()}) parts evenly")
|
64 |
+
|
65 |
+
# patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
|
66 |
+
temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2
|
67 |
+
|
68 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(),dim=split_dim)[get_sequence_parallel_rank()]
|
69 |
+
|
70 |
+
dim_thw = freqs_cos.shape[-1]
|
71 |
+
freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
|
72 |
+
freqs_cos = torch.chunk(freqs_cos, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
|
73 |
+
freqs_cos = freqs_cos.reshape(-1, dim_thw)
|
74 |
+
dim_thw = freqs_sin.shape[-1]
|
75 |
+
freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
|
76 |
+
freqs_sin = torch.chunk(freqs_sin, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
|
77 |
+
freqs_sin = freqs_sin.reshape(-1, dim_thw)
|
78 |
+
|
79 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
80 |
+
|
81 |
+
for block in transformer.double_blocks + transformer.single_blocks:
|
82 |
+
block.hybrid_seq_parallel_attn = xFuserLongContextAttention()
|
83 |
+
|
84 |
+
output = original_forward(
|
85 |
+
x,
|
86 |
+
t,
|
87 |
+
text_states,
|
88 |
+
text_mask,
|
89 |
+
text_states_2,
|
90 |
+
freqs_cos,
|
91 |
+
freqs_sin,
|
92 |
+
guidance,
|
93 |
+
return_dict,
|
94 |
+
)
|
95 |
+
|
96 |
+
return_dict = not isinstance(output, tuple)
|
97 |
+
sample = output["x"]
|
98 |
+
sample = get_sp_group().all_gather(sample, dim=split_dim)
|
99 |
+
output["x"] = sample
|
100 |
+
return output
|
101 |
+
|
102 |
+
new_forward = new_forward.__get__(transformer)
|
103 |
+
transformer.forward = new_forward
|
104 |
+
|
105 |
+
|
106 |
+
class Inference(object):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
args,
|
110 |
+
vae,
|
111 |
+
vae_kwargs,
|
112 |
+
text_encoder,
|
113 |
+
model,
|
114 |
+
text_encoder_2=None,
|
115 |
+
pipeline=None,
|
116 |
+
use_cpu_offload=False,
|
117 |
+
device=None,
|
118 |
+
logger=None,
|
119 |
+
parallel_args=None,
|
120 |
+
):
|
121 |
+
self.vae = vae
|
122 |
+
self.vae_kwargs = vae_kwargs
|
123 |
+
|
124 |
+
self.text_encoder = text_encoder
|
125 |
+
self.text_encoder_2 = text_encoder_2
|
126 |
+
|
127 |
+
self.model = model
|
128 |
+
self.pipeline = pipeline
|
129 |
+
self.use_cpu_offload = use_cpu_offload
|
130 |
+
|
131 |
+
self.args = args
|
132 |
+
self.device = (
|
133 |
+
device
|
134 |
+
if device is not None
|
135 |
+
else "cuda"
|
136 |
+
if torch.cuda.is_available()
|
137 |
+
else "cpu"
|
138 |
+
)
|
139 |
+
self.logger = logger
|
140 |
+
self.parallel_args = parallel_args
|
141 |
+
|
142 |
+
@classmethod
|
143 |
+
def from_pretrained(cls, pretrained_model_path, text_encoder_path, args, device=None, **kwargs):
|
144 |
+
"""
|
145 |
+
Initialize the Inference pipeline.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
|
149 |
+
args (argparse.Namespace): The arguments for the pipeline.
|
150 |
+
device (int): The device for inference. Default is 0.
|
151 |
+
"""
|
152 |
+
# ========================================================================
|
153 |
+
logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
|
154 |
+
|
155 |
+
# ==================== Initialize Distributed Environment ================
|
156 |
+
if args.ulysses_degree > 1 or args.ring_degree > 1:
|
157 |
+
assert xfuser is not None, \
|
158 |
+
"Ulysses Attention and Ring Attention requires xfuser package."
|
159 |
+
|
160 |
+
assert args.use_cpu_offload is False, \
|
161 |
+
"Cannot enable use_cpu_offload in the distributed environment."
|
162 |
+
|
163 |
+
dist.init_process_group("nccl")
|
164 |
+
|
165 |
+
assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
|
166 |
+
"number of GPUs should be equal to ring_degree * ulysses_degree."
|
167 |
+
|
168 |
+
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
169 |
+
|
170 |
+
initialize_model_parallel(
|
171 |
+
sequence_parallel_degree=dist.get_world_size(),
|
172 |
+
ring_degree=args.ring_degree,
|
173 |
+
ulysses_degree=args.ulysses_degree,
|
174 |
+
)
|
175 |
+
device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
|
176 |
+
else:
|
177 |
+
if device is None:
|
178 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
179 |
+
|
180 |
+
parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
|
181 |
+
|
182 |
+
# ======================== Get the args path =============================
|
183 |
+
|
184 |
+
# Disable gradient
|
185 |
+
torch.set_grad_enabled(False)
|
186 |
+
|
187 |
+
# =========================== Build main model ===========================
|
188 |
+
logger.info("Building model...")
|
189 |
+
pinToMemory = kwargs.pop("pinToMemory")
|
190 |
+
partialPinning = kwargs.pop("partialPinning")
|
191 |
+
# factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
|
192 |
+
factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[args.precision]}
|
193 |
+
in_channels = args.latent_channels
|
194 |
+
out_channels = args.latent_channels
|
195 |
+
|
196 |
+
model = load_model(
|
197 |
+
args,
|
198 |
+
in_channels=in_channels,
|
199 |
+
out_channels=out_channels,
|
200 |
+
factor_kwargs=factor_kwargs,
|
201 |
+
)
|
202 |
+
# model = model.to(device)
|
203 |
+
# model = Inference.load_state_dict(args, model, pretrained_model_path)
|
204 |
+
|
205 |
+
logger.info(f"Loading torch model {pretrained_model_path}...")
|
206 |
+
|
207 |
+
from mmgp import offload
|
208 |
+
offload.load_model_data(model, pretrained_model_path, pinToMemory = pinToMemory, partialPinning = partialPinning)
|
209 |
+
|
210 |
+
model.eval()
|
211 |
+
|
212 |
+
# ============================= Build extra models ========================
|
213 |
+
# VAE
|
214 |
+
vae, _, s_ratio, t_ratio = load_vae(
|
215 |
+
args.vae,
|
216 |
+
args.vae_precision,
|
217 |
+
logger=logger,
|
218 |
+
device=device if not args.use_cpu_offload else "cpu",
|
219 |
+
)
|
220 |
+
vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
|
221 |
+
|
222 |
+
# Text encoder
|
223 |
+
if args.prompt_template_video is not None:
|
224 |
+
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
|
225 |
+
"crop_start", 0
|
226 |
+
)
|
227 |
+
elif args.prompt_template is not None:
|
228 |
+
crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
|
229 |
+
else:
|
230 |
+
crop_start = 0
|
231 |
+
max_length = args.text_len + crop_start
|
232 |
+
|
233 |
+
# prompt_template
|
234 |
+
prompt_template = (
|
235 |
+
PROMPT_TEMPLATE[args.prompt_template]
|
236 |
+
if args.prompt_template is not None
|
237 |
+
else None
|
238 |
+
)
|
239 |
+
|
240 |
+
# prompt_template_video
|
241 |
+
prompt_template_video = (
|
242 |
+
PROMPT_TEMPLATE[args.prompt_template_video]
|
243 |
+
if args.prompt_template_video is not None
|
244 |
+
else None
|
245 |
+
)
|
246 |
+
|
247 |
+
text_encoder = TextEncoder(
|
248 |
+
text_encoder_type=args.text_encoder,
|
249 |
+
max_length=max_length,
|
250 |
+
text_encoder_precision=args.text_encoder_precision,
|
251 |
+
tokenizer_type=args.tokenizer,
|
252 |
+
prompt_template=prompt_template,
|
253 |
+
prompt_template_video=prompt_template_video,
|
254 |
+
hidden_state_skip_layer=args.hidden_state_skip_layer,
|
255 |
+
apply_final_norm=args.apply_final_norm,
|
256 |
+
reproduce=args.reproduce,
|
257 |
+
logger=logger,
|
258 |
+
device=device if not args.use_cpu_offload else "cpu",
|
259 |
+
text_encoder_path = text_encoder_path
|
260 |
+
)
|
261 |
+
text_encoder_2 = None
|
262 |
+
if args.text_encoder_2 is not None:
|
263 |
+
text_encoder_2 = TextEncoder(
|
264 |
+
text_encoder_type=args.text_encoder_2,
|
265 |
+
max_length=args.text_len_2,
|
266 |
+
text_encoder_precision=args.text_encoder_precision_2,
|
267 |
+
tokenizer_type=args.tokenizer_2,
|
268 |
+
reproduce=args.reproduce,
|
269 |
+
logger=logger,
|
270 |
+
device=device if not args.use_cpu_offload else "cpu",
|
271 |
+
)
|
272 |
+
|
273 |
+
return cls(
|
274 |
+
args=args,
|
275 |
+
vae=vae,
|
276 |
+
vae_kwargs=vae_kwargs,
|
277 |
+
text_encoder=text_encoder,
|
278 |
+
text_encoder_2=text_encoder_2,
|
279 |
+
model=model,
|
280 |
+
use_cpu_offload=args.use_cpu_offload,
|
281 |
+
device=device,
|
282 |
+
logger=logger,
|
283 |
+
parallel_args=parallel_args
|
284 |
+
)
|
285 |
+
|
286 |
+
@staticmethod
|
287 |
+
def load_state_dict(args, model, pretrained_model_path):
|
288 |
+
load_key = args.load_key
|
289 |
+
dit_weight = Path(args.dit_weight)
|
290 |
+
|
291 |
+
if dit_weight is None:
|
292 |
+
model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
|
293 |
+
files = list(model_dir.glob("*.pt"))
|
294 |
+
if len(files) == 0:
|
295 |
+
raise ValueError(f"No model weights found in {model_dir}")
|
296 |
+
if str(files[0]).startswith("pytorch_model_"):
|
297 |
+
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
|
298 |
+
bare_model = True
|
299 |
+
elif any(str(f).endswith("_model_states.pt") for f in files):
|
300 |
+
files = [f for f in files if str(f).endswith("_model_states.pt")]
|
301 |
+
model_path = files[0]
|
302 |
+
if len(files) > 1:
|
303 |
+
logger.warning(
|
304 |
+
f"Multiple model weights found in {dit_weight}, using {model_path}"
|
305 |
+
)
|
306 |
+
bare_model = False
|
307 |
+
else:
|
308 |
+
raise ValueError(
|
309 |
+
f"Invalid model path: {dit_weight} with unrecognized weight format: "
|
310 |
+
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
|
311 |
+
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
|
312 |
+
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
|
313 |
+
f"specific weight file, please provide the full path to the file."
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
if dit_weight.is_dir():
|
317 |
+
files = list(dit_weight.glob("*.pt"))
|
318 |
+
if len(files) == 0:
|
319 |
+
raise ValueError(f"No model weights found in {dit_weight}")
|
320 |
+
if str(files[0]).startswith("pytorch_model_"):
|
321 |
+
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
|
322 |
+
bare_model = True
|
323 |
+
elif any(str(f).endswith("_model_states.pt") for f in files):
|
324 |
+
files = [f for f in files if str(f).endswith("_model_states.pt")]
|
325 |
+
model_path = files[0]
|
326 |
+
if len(files) > 1:
|
327 |
+
logger.warning(
|
328 |
+
f"Multiple model weights found in {dit_weight}, using {model_path}"
|
329 |
+
)
|
330 |
+
bare_model = False
|
331 |
+
else:
|
332 |
+
raise ValueError(
|
333 |
+
f"Invalid model path: {dit_weight} with unrecognized weight format: "
|
334 |
+
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
|
335 |
+
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
|
336 |
+
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
|
337 |
+
f"specific weight file, please provide the full path to the file."
|
338 |
+
)
|
339 |
+
elif dit_weight.is_file():
|
340 |
+
model_path = dit_weight
|
341 |
+
bare_model = "unknown"
|
342 |
+
else:
|
343 |
+
raise ValueError(f"Invalid model path: {dit_weight}")
|
344 |
+
|
345 |
+
if not model_path.exists():
|
346 |
+
raise ValueError(f"model_path not exists: {model_path}")
|
347 |
+
logger.info(f"Loading torch model {model_path}...")
|
348 |
+
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
|
349 |
+
|
350 |
+
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
|
351 |
+
bare_model = False
|
352 |
+
if bare_model is False:
|
353 |
+
if load_key in state_dict:
|
354 |
+
state_dict = state_dict[load_key]
|
355 |
+
else:
|
356 |
+
raise KeyError(
|
357 |
+
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
|
358 |
+
f"are: {list(state_dict.keys())}."
|
359 |
+
)
|
360 |
+
model.load_state_dict(state_dict, strict=True, assign = True )
|
361 |
+
return model
|
362 |
+
|
363 |
+
@staticmethod
|
364 |
+
def parse_size(size):
|
365 |
+
if isinstance(size, int):
|
366 |
+
size = [size]
|
367 |
+
if not isinstance(size, (list, tuple)):
|
368 |
+
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
|
369 |
+
if len(size) == 1:
|
370 |
+
size = [size[0], size[0]]
|
371 |
+
if len(size) != 2:
|
372 |
+
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
|
373 |
+
return size
|
374 |
+
|
375 |
+
|
376 |
+
class HunyuanVideoSampler(Inference):
|
377 |
+
def __init__(
|
378 |
+
self,
|
379 |
+
args,
|
380 |
+
vae,
|
381 |
+
vae_kwargs,
|
382 |
+
text_encoder,
|
383 |
+
model,
|
384 |
+
text_encoder_2=None,
|
385 |
+
pipeline=None,
|
386 |
+
use_cpu_offload=False,
|
387 |
+
device=0,
|
388 |
+
logger=None,
|
389 |
+
parallel_args=None
|
390 |
+
):
|
391 |
+
super().__init__(
|
392 |
+
args,
|
393 |
+
vae,
|
394 |
+
vae_kwargs,
|
395 |
+
text_encoder,
|
396 |
+
model,
|
397 |
+
text_encoder_2=text_encoder_2,
|
398 |
+
pipeline=pipeline,
|
399 |
+
use_cpu_offload=use_cpu_offload,
|
400 |
+
device=device,
|
401 |
+
logger=logger,
|
402 |
+
parallel_args=parallel_args
|
403 |
+
)
|
404 |
+
|
405 |
+
self.pipeline = self.load_diffusion_pipeline(
|
406 |
+
args=args,
|
407 |
+
vae=self.vae,
|
408 |
+
text_encoder=self.text_encoder,
|
409 |
+
text_encoder_2=self.text_encoder_2,
|
410 |
+
model=self.model,
|
411 |
+
device=self.device,
|
412 |
+
)
|
413 |
+
|
414 |
+
self.default_negative_prompt = NEGATIVE_PROMPT
|
415 |
+
|
416 |
+
def load_diffusion_pipeline(
|
417 |
+
self,
|
418 |
+
args,
|
419 |
+
vae,
|
420 |
+
text_encoder,
|
421 |
+
text_encoder_2,
|
422 |
+
model,
|
423 |
+
scheduler=None,
|
424 |
+
device=None,
|
425 |
+
progress_bar_config=None,
|
426 |
+
data_type="video",
|
427 |
+
):
|
428 |
+
"""Load the denoising scheduler for inference."""
|
429 |
+
if scheduler is None:
|
430 |
+
if args.denoise_type == "flow":
|
431 |
+
scheduler = FlowMatchDiscreteScheduler(
|
432 |
+
shift=args.flow_shift,
|
433 |
+
reverse=args.flow_reverse,
|
434 |
+
solver=args.flow_solver,
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
raise ValueError(f"Invalid denoise type {args.denoise_type}")
|
438 |
+
|
439 |
+
pipeline = HunyuanVideoPipeline(
|
440 |
+
vae=vae,
|
441 |
+
text_encoder=text_encoder,
|
442 |
+
text_encoder_2=text_encoder_2,
|
443 |
+
transformer=model,
|
444 |
+
scheduler=scheduler,
|
445 |
+
progress_bar_config=progress_bar_config,
|
446 |
+
args=args,
|
447 |
+
)
|
448 |
+
# if self.use_cpu_offload:
|
449 |
+
# pipeline.enable_sequential_cpu_offload()
|
450 |
+
# else:
|
451 |
+
# pipeline = pipeline.to(device)
|
452 |
+
|
453 |
+
return pipeline
|
454 |
+
|
455 |
+
def get_rotary_pos_embed(self, video_length, height, width):
|
456 |
+
target_ndim = 3
|
457 |
+
ndim = 5 - 2
|
458 |
+
# 884
|
459 |
+
if "884" in self.args.vae:
|
460 |
+
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
|
461 |
+
elif "888" in self.args.vae:
|
462 |
+
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
|
463 |
+
else:
|
464 |
+
latents_size = [video_length, height // 8, width // 8]
|
465 |
+
|
466 |
+
if isinstance(self.model.patch_size, int):
|
467 |
+
assert all(s % self.model.patch_size == 0 for s in latents_size), (
|
468 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
|
469 |
+
f"but got {latents_size}."
|
470 |
+
)
|
471 |
+
rope_sizes = [s // self.model.patch_size for s in latents_size]
|
472 |
+
elif isinstance(self.model.patch_size, list):
|
473 |
+
assert all(
|
474 |
+
s % self.model.patch_size[idx] == 0
|
475 |
+
for idx, s in enumerate(latents_size)
|
476 |
+
), (
|
477 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
|
478 |
+
f"but got {latents_size}."
|
479 |
+
)
|
480 |
+
rope_sizes = [
|
481 |
+
s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
|
482 |
+
]
|
483 |
+
|
484 |
+
if len(rope_sizes) != target_ndim:
|
485 |
+
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
|
486 |
+
head_dim = self.model.hidden_size // self.model.heads_num
|
487 |
+
rope_dim_list = self.model.rope_dim_list
|
488 |
+
if rope_dim_list is None:
|
489 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
490 |
+
assert (
|
491 |
+
sum(rope_dim_list) == head_dim
|
492 |
+
), "sum(rope_dim_list) should equal to head_dim of attention layer"
|
493 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
494 |
+
rope_dim_list,
|
495 |
+
rope_sizes,
|
496 |
+
theta=self.args.rope_theta,
|
497 |
+
use_real=True,
|
498 |
+
theta_rescale_factor=1,
|
499 |
+
)
|
500 |
+
return freqs_cos, freqs_sin
|
501 |
+
|
502 |
+
@torch.no_grad()
|
503 |
+
def predict(
|
504 |
+
self,
|
505 |
+
prompt,
|
506 |
+
height=192,
|
507 |
+
width=336,
|
508 |
+
video_length=129,
|
509 |
+
seed=None,
|
510 |
+
negative_prompt=None,
|
511 |
+
infer_steps=50,
|
512 |
+
guidance_scale=6,
|
513 |
+
flow_shift=5.0,
|
514 |
+
embedded_guidance_scale=None,
|
515 |
+
batch_size=1,
|
516 |
+
num_videos_per_prompt=1,
|
517 |
+
**kwargs,
|
518 |
+
):
|
519 |
+
"""
|
520 |
+
Predict the image/video from the given text.
|
521 |
+
|
522 |
+
Args:
|
523 |
+
prompt (str or List[str]): The input text.
|
524 |
+
kwargs:
|
525 |
+
height (int): The height of the output video. Default is 192.
|
526 |
+
width (int): The width of the output video. Default is 336.
|
527 |
+
video_length (int): The frame number of the output video. Default is 129.
|
528 |
+
seed (int or List[str]): The random seed for the generation. Default is a random integer.
|
529 |
+
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
|
530 |
+
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
|
531 |
+
num_images_per_prompt (int): The number of images per prompt. Default is 1.
|
532 |
+
infer_steps (int): The number of inference steps. Default is 100.
|
533 |
+
"""
|
534 |
+
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
|
535 |
+
assert seed is not None, \
|
536 |
+
"You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
|
537 |
+
|
538 |
+
parallelize_transformer(self.pipeline)
|
539 |
+
|
540 |
+
out_dict = dict()
|
541 |
+
|
542 |
+
# ========================================================================
|
543 |
+
# Arguments: seed
|
544 |
+
# ========================================================================
|
545 |
+
if isinstance(seed, torch.Tensor):
|
546 |
+
seed = seed.tolist()
|
547 |
+
if seed is None:
|
548 |
+
seeds = [
|
549 |
+
random.randint(0, 1_000_000)
|
550 |
+
for _ in range(batch_size * num_videos_per_prompt)
|
551 |
+
]
|
552 |
+
elif isinstance(seed, int):
|
553 |
+
seeds = [
|
554 |
+
seed + i
|
555 |
+
for _ in range(batch_size)
|
556 |
+
for i in range(num_videos_per_prompt)
|
557 |
+
]
|
558 |
+
elif isinstance(seed, (list, tuple)):
|
559 |
+
if len(seed) == batch_size:
|
560 |
+
seeds = [
|
561 |
+
int(seed[i]) + j
|
562 |
+
for i in range(batch_size)
|
563 |
+
for j in range(num_videos_per_prompt)
|
564 |
+
]
|
565 |
+
elif len(seed) == batch_size * num_videos_per_prompt:
|
566 |
+
seeds = [int(s) for s in seed]
|
567 |
+
else:
|
568 |
+
raise ValueError(
|
569 |
+
f"Length of seed must be equal to number of prompt(batch_size) or "
|
570 |
+
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
|
571 |
+
)
|
572 |
+
else:
|
573 |
+
raise ValueError(
|
574 |
+
f"Seed must be an integer, a list of integers, or None, got {seed}."
|
575 |
+
)
|
576 |
+
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
|
577 |
+
out_dict["seeds"] = seeds
|
578 |
+
|
579 |
+
# ========================================================================
|
580 |
+
# Arguments: target_width, target_height, target_video_length
|
581 |
+
# ========================================================================
|
582 |
+
if width <= 0 or height <= 0 or video_length <= 0:
|
583 |
+
raise ValueError(
|
584 |
+
f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
|
585 |
+
)
|
586 |
+
if (video_length - 1) % 4 != 0:
|
587 |
+
raise ValueError(
|
588 |
+
f"`video_length-1` must be a multiple of 4, got {video_length}"
|
589 |
+
)
|
590 |
+
|
591 |
+
logger.info(
|
592 |
+
f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
|
593 |
+
)
|
594 |
+
|
595 |
+
target_height = align_to(height, 16)
|
596 |
+
target_width = align_to(width, 16)
|
597 |
+
target_video_length = video_length
|
598 |
+
|
599 |
+
out_dict["size"] = (target_height, target_width, target_video_length)
|
600 |
+
|
601 |
+
# ========================================================================
|
602 |
+
# Arguments: prompt, new_prompt, negative_prompt
|
603 |
+
# ========================================================================
|
604 |
+
if not isinstance(prompt, str):
|
605 |
+
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
|
606 |
+
prompt = [prompt.strip()]
|
607 |
+
|
608 |
+
# negative prompt
|
609 |
+
if negative_prompt is None or negative_prompt == "":
|
610 |
+
negative_prompt = self.default_negative_prompt
|
611 |
+
if not isinstance(negative_prompt, str):
|
612 |
+
raise TypeError(
|
613 |
+
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
|
614 |
+
)
|
615 |
+
negative_prompt = [negative_prompt.strip()]
|
616 |
+
|
617 |
+
# ========================================================================
|
618 |
+
# Scheduler
|
619 |
+
# ========================================================================
|
620 |
+
scheduler = FlowMatchDiscreteScheduler(
|
621 |
+
shift=flow_shift,
|
622 |
+
reverse=self.args.flow_reverse,
|
623 |
+
solver=self.args.flow_solver
|
624 |
+
)
|
625 |
+
self.pipeline.scheduler = scheduler
|
626 |
+
|
627 |
+
# ========================================================================
|
628 |
+
# Build Rope freqs
|
629 |
+
# ========================================================================
|
630 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed(
|
631 |
+
target_video_length, target_height, target_width
|
632 |
+
)
|
633 |
+
n_tokens = freqs_cos.shape[0]
|
634 |
+
|
635 |
+
# ========================================================================
|
636 |
+
# Print infer args
|
637 |
+
# ========================================================================
|
638 |
+
debug_str = f"""
|
639 |
+
height: {target_height}
|
640 |
+
width: {target_width}
|
641 |
+
video_length: {target_video_length}
|
642 |
+
prompt: {prompt}
|
643 |
+
neg_prompt: {negative_prompt}
|
644 |
+
seed: {seed}
|
645 |
+
infer_steps: {infer_steps}
|
646 |
+
num_videos_per_prompt: {num_videos_per_prompt}
|
647 |
+
guidance_scale: {guidance_scale}
|
648 |
+
n_tokens: {n_tokens}
|
649 |
+
flow_shift: {flow_shift}
|
650 |
+
embedded_guidance_scale: {embedded_guidance_scale}"""
|
651 |
+
logger.debug(debug_str)
|
652 |
+
|
653 |
+
# ========================================================================
|
654 |
+
# Pipeline inference
|
655 |
+
# ========================================================================
|
656 |
+
start_time = time.time()
|
657 |
+
samples = self.pipeline(
|
658 |
+
prompt=prompt,
|
659 |
+
height=target_height,
|
660 |
+
width=target_width,
|
661 |
+
video_length=target_video_length,
|
662 |
+
num_inference_steps=infer_steps,
|
663 |
+
guidance_scale=guidance_scale,
|
664 |
+
negative_prompt=negative_prompt,
|
665 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
666 |
+
generator=generator,
|
667 |
+
output_type="pil",
|
668 |
+
freqs_cis=(freqs_cos, freqs_sin),
|
669 |
+
n_tokens=n_tokens,
|
670 |
+
embedded_guidance_scale=embedded_guidance_scale,
|
671 |
+
data_type="video" if target_video_length > 1 else "image",
|
672 |
+
is_progress_bar=True,
|
673 |
+
vae_ver=self.args.vae,
|
674 |
+
enable_tiling=self.args.vae_tiling,
|
675 |
+
)[0]
|
676 |
+
out_dict["samples"] = samples
|
677 |
+
out_dict["prompts"] = prompt
|
678 |
+
|
679 |
+
gen_time = time.time() - start_time
|
680 |
+
logger.info(f"Success, time: {gen_time}")
|
681 |
+
|
682 |
+
return out_dict
|
hyvideo/modules/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
|
2 |
+
|
3 |
+
|
4 |
+
def load_model(args, in_channels, out_channels, factor_kwargs):
|
5 |
+
"""load hunyuan video model
|
6 |
+
|
7 |
+
Args:
|
8 |
+
args (dict): model args
|
9 |
+
in_channels (int): input channels number
|
10 |
+
out_channels (int): output channels number
|
11 |
+
factor_kwargs (dict): factor kwargs
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
model (nn.Module): The hunyuan video model
|
15 |
+
"""
|
16 |
+
if args.model in HUNYUAN_VIDEO_CONFIG.keys():
|
17 |
+
model = HYVideoDiffusionTransformer(
|
18 |
+
args,
|
19 |
+
in_channels=in_channels,
|
20 |
+
out_channels=out_channels,
|
21 |
+
**HUNYUAN_VIDEO_CONFIG[args.model],
|
22 |
+
**factor_kwargs,
|
23 |
+
)
|
24 |
+
return model
|
25 |
+
else:
|
26 |
+
raise NotImplementedError()
|
hyvideo/modules/activation_layers.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation_layer(act_type):
|
5 |
+
"""get activation layer
|
6 |
+
|
7 |
+
Args:
|
8 |
+
act_type (str): the activation type
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
torch.nn.functional: the activation layer
|
12 |
+
"""
|
13 |
+
if act_type == "gelu":
|
14 |
+
return lambda: nn.GELU()
|
15 |
+
elif act_type == "gelu_tanh":
|
16 |
+
# Approximate `tanh` requires torch >= 1.13
|
17 |
+
return lambda: nn.GELU(approximate="tanh")
|
18 |
+
elif act_type == "relu":
|
19 |
+
return nn.ReLU
|
20 |
+
elif act_type == "silu":
|
21 |
+
return nn.SiLU
|
22 |
+
else:
|
23 |
+
raise ValueError(f"Unknown activation type: {act_type}")
|
hyvideo/modules/attenion.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.metadata
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
try:
|
9 |
+
import flash_attn
|
10 |
+
from flash_attn.flash_attn_interface import _flash_attn_forward
|
11 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
12 |
+
except ImportError:
|
13 |
+
flash_attn = None
|
14 |
+
flash_attn_varlen_func = None
|
15 |
+
_flash_attn_forward = None
|
16 |
+
|
17 |
+
try:
|
18 |
+
from sageattention import sageattn_varlen
|
19 |
+
def sageattn_varlen_wrapper(
|
20 |
+
q,
|
21 |
+
k,
|
22 |
+
v,
|
23 |
+
cu_seqlens_q,
|
24 |
+
cu_seqlens_kv,
|
25 |
+
max_seqlen_q,
|
26 |
+
max_seqlen_kv,
|
27 |
+
):
|
28 |
+
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
29 |
+
except ImportError:
|
30 |
+
sageattn_varlen_wrapper = None
|
31 |
+
|
32 |
+
MEMORY_LAYOUT = {
|
33 |
+
"sdpa": (
|
34 |
+
lambda x: x.transpose(1, 2),
|
35 |
+
lambda x: x.transpose(1, 2),
|
36 |
+
),
|
37 |
+
"sage": (
|
38 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
39 |
+
lambda x: x,
|
40 |
+
),
|
41 |
+
"flash": (
|
42 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
43 |
+
lambda x: x,
|
44 |
+
),
|
45 |
+
"torch": (
|
46 |
+
lambda x: x.transpose(1, 2),
|
47 |
+
lambda x: x.transpose(1, 2),
|
48 |
+
),
|
49 |
+
"vanilla": (
|
50 |
+
lambda x: x.transpose(1, 2),
|
51 |
+
lambda x: x.transpose(1, 2),
|
52 |
+
),
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
def get_cu_seqlens(text_mask, img_len):
|
57 |
+
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
|
58 |
+
|
59 |
+
Args:
|
60 |
+
text_mask (torch.Tensor): the mask of text
|
61 |
+
img_len (int): the length of image
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
torch.Tensor: the calculated cu_seqlens for flash attention
|
65 |
+
"""
|
66 |
+
batch_size = text_mask.shape[0]
|
67 |
+
text_len = text_mask.sum(dim=1)
|
68 |
+
max_len = text_mask.shape[1] + img_len
|
69 |
+
|
70 |
+
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
|
71 |
+
|
72 |
+
for i in range(batch_size):
|
73 |
+
s = text_len[i] + img_len
|
74 |
+
s1 = i * max_len + s
|
75 |
+
s2 = (i + 1) * max_len
|
76 |
+
cu_seqlens[2 * i + 1] = s1
|
77 |
+
cu_seqlens[2 * i + 2] = s2
|
78 |
+
|
79 |
+
return cu_seqlens
|
80 |
+
|
81 |
+
|
82 |
+
def attention(
|
83 |
+
q,
|
84 |
+
k,
|
85 |
+
v,
|
86 |
+
mode="flash",
|
87 |
+
drop_rate=0,
|
88 |
+
attn_mask=None,
|
89 |
+
causal=False,
|
90 |
+
cu_seqlens_q=None,
|
91 |
+
cu_seqlens_kv=None,
|
92 |
+
max_seqlen_q=None,
|
93 |
+
max_seqlen_kv=None,
|
94 |
+
batch_size=1,
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
Perform QKV self attention.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
101 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
102 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
103 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
104 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
105 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
106 |
+
(default: None)
|
107 |
+
causal (bool): Whether to use causal attention. (default: False)
|
108 |
+
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
109 |
+
used to index into q.
|
110 |
+
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
111 |
+
used to index into kv.
|
112 |
+
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
113 |
+
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
117 |
+
"""
|
118 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
119 |
+
q = pre_attn_layout(q)
|
120 |
+
k = pre_attn_layout(k)
|
121 |
+
v = pre_attn_layout(v)
|
122 |
+
|
123 |
+
if mode == "torch":
|
124 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
125 |
+
attn_mask = attn_mask.to(q.dtype)
|
126 |
+
x = F.scaled_dot_product_attention(
|
127 |
+
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
128 |
+
)
|
129 |
+
|
130 |
+
elif mode == "sdpa":
|
131 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
132 |
+
attn_mask = attn_mask.to(q.dtype)
|
133 |
+
x = F.scaled_dot_product_attention(
|
134 |
+
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
135 |
+
)
|
136 |
+
|
137 |
+
elif mode == "sage":
|
138 |
+
x = sageattn_varlen_wrapper(
|
139 |
+
q,
|
140 |
+
k,
|
141 |
+
v,
|
142 |
+
cu_seqlens_q,
|
143 |
+
cu_seqlens_kv,
|
144 |
+
max_seqlen_q,
|
145 |
+
max_seqlen_kv,
|
146 |
+
)
|
147 |
+
# x with shape [(bxs), a, d]
|
148 |
+
x = x.view(
|
149 |
+
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
|
150 |
+
) # reshape x to [b, s, a, d]
|
151 |
+
|
152 |
+
elif mode == "flash":
|
153 |
+
x = flash_attn_varlen_func(
|
154 |
+
q,
|
155 |
+
k,
|
156 |
+
v,
|
157 |
+
cu_seqlens_q,
|
158 |
+
cu_seqlens_kv,
|
159 |
+
max_seqlen_q,
|
160 |
+
max_seqlen_kv,
|
161 |
+
)
|
162 |
+
# x with shape [(bxs), a, d]
|
163 |
+
x = x.view(
|
164 |
+
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
|
165 |
+
) # reshape x to [b, s, a, d]
|
166 |
+
elif mode == "vanilla":
|
167 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
168 |
+
|
169 |
+
b, a, s, _ = q.shape
|
170 |
+
s1 = k.size(2)
|
171 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
172 |
+
if causal:
|
173 |
+
# Only applied to self attention
|
174 |
+
assert (
|
175 |
+
attn_mask is None
|
176 |
+
), "Causal mask and attn_mask cannot be used together"
|
177 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
|
178 |
+
diagonal=0
|
179 |
+
)
|
180 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
181 |
+
attn_bias.to(q.dtype)
|
182 |
+
|
183 |
+
if attn_mask is not None:
|
184 |
+
if attn_mask.dtype == torch.bool:
|
185 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
186 |
+
else:
|
187 |
+
attn_bias += attn_mask
|
188 |
+
|
189 |
+
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
|
190 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
191 |
+
attn += attn_bias
|
192 |
+
attn = attn.softmax(dim=-1)
|
193 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
194 |
+
x = attn @ v
|
195 |
+
else:
|
196 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
197 |
+
|
198 |
+
x = post_attn_layout(x)
|
199 |
+
b, s, a, d = x.shape
|
200 |
+
out = x.reshape(b, s, -1)
|
201 |
+
return out
|
202 |
+
|
203 |
+
|
204 |
+
def parallel_attention(
|
205 |
+
hybrid_seq_parallel_attn,
|
206 |
+
q,
|
207 |
+
k,
|
208 |
+
v,
|
209 |
+
img_q_len,
|
210 |
+
img_kv_len,
|
211 |
+
cu_seqlens_q,
|
212 |
+
cu_seqlens_kv
|
213 |
+
):
|
214 |
+
attn1 = hybrid_seq_parallel_attn(
|
215 |
+
None,
|
216 |
+
q[:, :img_q_len, :, :],
|
217 |
+
k[:, :img_kv_len, :, :],
|
218 |
+
v[:, :img_kv_len, :, :],
|
219 |
+
dropout_p=0.0,
|
220 |
+
causal=False,
|
221 |
+
joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
|
222 |
+
joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
|
223 |
+
joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
|
224 |
+
joint_strategy="rear",
|
225 |
+
)
|
226 |
+
if flash_attn.__version__ >= '2.7.0':
|
227 |
+
attn2, *_ = _flash_attn_forward(
|
228 |
+
q[:,cu_seqlens_q[1]:],
|
229 |
+
k[:,cu_seqlens_kv[1]:],
|
230 |
+
v[:,cu_seqlens_kv[1]:],
|
231 |
+
dropout_p=0.0,
|
232 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
233 |
+
causal=False,
|
234 |
+
window_size_left=-1,
|
235 |
+
window_size_right=-1,
|
236 |
+
softcap=0.0,
|
237 |
+
alibi_slopes=None,
|
238 |
+
return_softmax=False,
|
239 |
+
)
|
240 |
+
else:
|
241 |
+
attn2, *_ = _flash_attn_forward(
|
242 |
+
q[:,cu_seqlens_q[1]:],
|
243 |
+
k[:,cu_seqlens_kv[1]:],
|
244 |
+
v[:,cu_seqlens_kv[1]:],
|
245 |
+
dropout_p=0.0,
|
246 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
247 |
+
causal=False,
|
248 |
+
window_size=(-1, -1),
|
249 |
+
softcap=0.0,
|
250 |
+
alibi_slopes=None,
|
251 |
+
return_softmax=False,
|
252 |
+
)
|
253 |
+
attn = torch.cat([attn1, attn2], dim=1)
|
254 |
+
b, s, a, d = attn.shape
|
255 |
+
attn = attn.reshape(b, s, -1)
|
256 |
+
|
257 |
+
return attn
|
hyvideo/modules/embed_layers.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
|
6 |
+
from ..utils.helpers import to_2tuple
|
7 |
+
|
8 |
+
|
9 |
+
class PatchEmbed(nn.Module):
|
10 |
+
"""2D Image to Patch Embedding
|
11 |
+
|
12 |
+
Image to Patch Embedding using Conv2d
|
13 |
+
|
14 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
15 |
+
|
16 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
17 |
+
|
18 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
19 |
+
|
20 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
patch_size=16,
|
26 |
+
in_chans=3,
|
27 |
+
embed_dim=768,
|
28 |
+
norm_layer=None,
|
29 |
+
flatten=True,
|
30 |
+
bias=True,
|
31 |
+
dtype=None,
|
32 |
+
device=None,
|
33 |
+
):
|
34 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
35 |
+
super().__init__()
|
36 |
+
patch_size = to_2tuple(patch_size)
|
37 |
+
self.patch_size = patch_size
|
38 |
+
self.flatten = flatten
|
39 |
+
|
40 |
+
self.proj = nn.Conv3d(
|
41 |
+
in_chans,
|
42 |
+
embed_dim,
|
43 |
+
kernel_size=patch_size,
|
44 |
+
stride=patch_size,
|
45 |
+
bias=bias,
|
46 |
+
**factory_kwargs
|
47 |
+
)
|
48 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
49 |
+
if bias:
|
50 |
+
nn.init.zeros_(self.proj.bias)
|
51 |
+
|
52 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.proj(x)
|
56 |
+
if self.flatten:
|
57 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
58 |
+
x = self.norm(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
class TextProjection(nn.Module):
|
63 |
+
"""
|
64 |
+
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
65 |
+
|
66 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
70 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
71 |
+
super().__init__()
|
72 |
+
self.linear_1 = nn.Linear(
|
73 |
+
in_features=in_channels,
|
74 |
+
out_features=hidden_size,
|
75 |
+
bias=True,
|
76 |
+
**factory_kwargs
|
77 |
+
)
|
78 |
+
self.act_1 = act_layer()
|
79 |
+
self.linear_2 = nn.Linear(
|
80 |
+
in_features=hidden_size,
|
81 |
+
out_features=hidden_size,
|
82 |
+
bias=True,
|
83 |
+
**factory_kwargs
|
84 |
+
)
|
85 |
+
|
86 |
+
def forward(self, caption):
|
87 |
+
hidden_states = self.linear_1(caption)
|
88 |
+
hidden_states = self.act_1(hidden_states)
|
89 |
+
hidden_states = self.linear_2(hidden_states)
|
90 |
+
return hidden_states
|
91 |
+
|
92 |
+
|
93 |
+
def timestep_embedding(t, dim, max_period=10000):
|
94 |
+
"""
|
95 |
+
Create sinusoidal timestep embeddings.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
99 |
+
dim (int): the dimension of the output.
|
100 |
+
max_period (int): controls the minimum frequency of the embeddings.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
104 |
+
|
105 |
+
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
106 |
+
"""
|
107 |
+
half = dim // 2
|
108 |
+
freqs = torch.exp(
|
109 |
+
-math.log(max_period)
|
110 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
111 |
+
/ half
|
112 |
+
).to(device=t.device)
|
113 |
+
args = t[:, None].float() * freqs[None]
|
114 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
115 |
+
if dim % 2:
|
116 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
117 |
+
return embedding
|
118 |
+
|
119 |
+
|
120 |
+
class TimestepEmbedder(nn.Module):
|
121 |
+
"""
|
122 |
+
Embeds scalar timesteps into vector representations.
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
hidden_size,
|
128 |
+
act_layer,
|
129 |
+
frequency_embedding_size=256,
|
130 |
+
max_period=10000,
|
131 |
+
out_size=None,
|
132 |
+
dtype=None,
|
133 |
+
device=None,
|
134 |
+
):
|
135 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
136 |
+
super().__init__()
|
137 |
+
self.frequency_embedding_size = frequency_embedding_size
|
138 |
+
self.max_period = max_period
|
139 |
+
if out_size is None:
|
140 |
+
out_size = hidden_size
|
141 |
+
|
142 |
+
self.mlp = nn.Sequential(
|
143 |
+
nn.Linear(
|
144 |
+
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
|
145 |
+
),
|
146 |
+
act_layer(),
|
147 |
+
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
148 |
+
)
|
149 |
+
nn.init.normal_(self.mlp[0].weight, std=0.02)
|
150 |
+
nn.init.normal_(self.mlp[2].weight, std=0.02)
|
151 |
+
|
152 |
+
def forward(self, t):
|
153 |
+
t_freq = timestep_embedding(
|
154 |
+
t, self.frequency_embedding_size, self.max_period
|
155 |
+
).type(self.mlp[0].weight.dtype)
|
156 |
+
t_emb = self.mlp(t_freq)
|
157 |
+
return t_emb
|
hyvideo/modules/mlp_layers.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from timm library:
|
2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .modulate_layers import modulate
|
10 |
+
from ..utils.helpers import to_2tuple
|
11 |
+
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
in_channels,
|
19 |
+
hidden_channels=None,
|
20 |
+
out_features=None,
|
21 |
+
act_layer=nn.GELU,
|
22 |
+
norm_layer=None,
|
23 |
+
bias=True,
|
24 |
+
drop=0.0,
|
25 |
+
use_conv=False,
|
26 |
+
device=None,
|
27 |
+
dtype=None,
|
28 |
+
):
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
out_features = out_features or in_channels
|
32 |
+
hidden_channels = hidden_channels or in_channels
|
33 |
+
bias = to_2tuple(bias)
|
34 |
+
drop_probs = to_2tuple(drop)
|
35 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
36 |
+
|
37 |
+
self.fc1 = linear_layer(
|
38 |
+
in_channels, hidden_channels, bias=bias[0], **factory_kwargs
|
39 |
+
)
|
40 |
+
self.act = act_layer()
|
41 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
42 |
+
self.norm = (
|
43 |
+
norm_layer(hidden_channels, **factory_kwargs)
|
44 |
+
if norm_layer is not None
|
45 |
+
else nn.Identity()
|
46 |
+
)
|
47 |
+
self.fc2 = linear_layer(
|
48 |
+
hidden_channels, out_features, bias=bias[1], **factory_kwargs
|
49 |
+
)
|
50 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = self.fc1(x)
|
54 |
+
x = self.act(x)
|
55 |
+
x = self.drop1(x)
|
56 |
+
x = self.norm(x)
|
57 |
+
x = self.fc2(x)
|
58 |
+
x = self.drop2(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
#
|
63 |
+
class MLPEmbedder(nn.Module):
|
64 |
+
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
|
65 |
+
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
|
66 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
67 |
+
super().__init__()
|
68 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
|
69 |
+
self.silu = nn.SiLU()
|
70 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
|
71 |
+
|
72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
73 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
74 |
+
|
75 |
+
|
76 |
+
class FinalLayer(nn.Module):
|
77 |
+
"""The final layer of DiT."""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
|
81 |
+
):
|
82 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
# Just use LayerNorm for the final layer
|
86 |
+
self.norm_final = nn.LayerNorm(
|
87 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
88 |
+
)
|
89 |
+
if isinstance(patch_size, int):
|
90 |
+
self.linear = nn.Linear(
|
91 |
+
hidden_size,
|
92 |
+
patch_size * patch_size * out_channels,
|
93 |
+
bias=True,
|
94 |
+
**factory_kwargs
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
self.linear = nn.Linear(
|
98 |
+
hidden_size,
|
99 |
+
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
|
100 |
+
bias=True,
|
101 |
+
)
|
102 |
+
nn.init.zeros_(self.linear.weight)
|
103 |
+
nn.init.zeros_(self.linear.bias)
|
104 |
+
|
105 |
+
# Here we don't distinguish between the modulate types. Just use the simple one.
|
106 |
+
self.adaLN_modulation = nn.Sequential(
|
107 |
+
act_layer(),
|
108 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
109 |
+
)
|
110 |
+
# Zero-initialize the modulation
|
111 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
112 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
113 |
+
|
114 |
+
def forward(self, x, c):
|
115 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
116 |
+
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
117 |
+
x = self.linear(x)
|
118 |
+
return x
|
hyvideo/modules/models.py
ADDED
@@ -0,0 +1,870 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Tuple, Optional, Union, Dict
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from diffusers.models import ModelMixin
|
9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
10 |
+
|
11 |
+
from .activation_layers import get_activation_layer
|
12 |
+
from .norm_layers import get_norm_layer
|
13 |
+
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
|
14 |
+
from .attenion import attention, parallel_attention, get_cu_seqlens
|
15 |
+
from .posemb_layers import apply_rotary_emb
|
16 |
+
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
|
17 |
+
from .modulate_layers import ModulateDiT, modulate, apply_gate
|
18 |
+
from .token_refiner import SingleTokenRefiner
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
class MMDoubleStreamBlock(nn.Module):
|
22 |
+
"""
|
23 |
+
A multimodal dit block with seperate modulation for
|
24 |
+
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
|
25 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
hidden_size: int,
|
31 |
+
heads_num: int,
|
32 |
+
mlp_width_ratio: float,
|
33 |
+
mlp_act_type: str = "gelu_tanh",
|
34 |
+
qk_norm: bool = True,
|
35 |
+
qk_norm_type: str = "rms",
|
36 |
+
qkv_bias: bool = False,
|
37 |
+
dtype: Optional[torch.dtype] = None,
|
38 |
+
device: Optional[torch.device] = None,
|
39 |
+
attention_mode: str = "sdpa",
|
40 |
+
):
|
41 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self.attention_mode = attention_mode
|
45 |
+
self.deterministic = False
|
46 |
+
self.heads_num = heads_num
|
47 |
+
head_dim = hidden_size // heads_num
|
48 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
49 |
+
|
50 |
+
self.img_mod = ModulateDiT(
|
51 |
+
hidden_size,
|
52 |
+
factor=6,
|
53 |
+
act_layer=get_activation_layer("silu"),
|
54 |
+
**factory_kwargs,
|
55 |
+
)
|
56 |
+
self.img_norm1 = nn.LayerNorm(
|
57 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
58 |
+
)
|
59 |
+
|
60 |
+
self.img_attn_qkv = nn.Linear(
|
61 |
+
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
62 |
+
)
|
63 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
64 |
+
self.img_attn_q_norm = (
|
65 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
66 |
+
if qk_norm
|
67 |
+
else nn.Identity()
|
68 |
+
)
|
69 |
+
self.img_attn_k_norm = (
|
70 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
71 |
+
if qk_norm
|
72 |
+
else nn.Identity()
|
73 |
+
)
|
74 |
+
self.img_attn_proj = nn.Linear(
|
75 |
+
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
76 |
+
)
|
77 |
+
|
78 |
+
self.img_norm2 = nn.LayerNorm(
|
79 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
80 |
+
)
|
81 |
+
self.img_mlp = MLP(
|
82 |
+
hidden_size,
|
83 |
+
mlp_hidden_dim,
|
84 |
+
act_layer=get_activation_layer(mlp_act_type),
|
85 |
+
bias=True,
|
86 |
+
**factory_kwargs,
|
87 |
+
)
|
88 |
+
|
89 |
+
self.txt_mod = ModulateDiT(
|
90 |
+
hidden_size,
|
91 |
+
factor=6,
|
92 |
+
act_layer=get_activation_layer("silu"),
|
93 |
+
**factory_kwargs,
|
94 |
+
)
|
95 |
+
self.txt_norm1 = nn.LayerNorm(
|
96 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
97 |
+
)
|
98 |
+
|
99 |
+
self.txt_attn_qkv = nn.Linear(
|
100 |
+
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
101 |
+
)
|
102 |
+
self.txt_attn_q_norm = (
|
103 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
104 |
+
if qk_norm
|
105 |
+
else nn.Identity()
|
106 |
+
)
|
107 |
+
self.txt_attn_k_norm = (
|
108 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
109 |
+
if qk_norm
|
110 |
+
else nn.Identity()
|
111 |
+
)
|
112 |
+
self.txt_attn_proj = nn.Linear(
|
113 |
+
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
114 |
+
)
|
115 |
+
|
116 |
+
self.txt_norm2 = nn.LayerNorm(
|
117 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
118 |
+
)
|
119 |
+
self.txt_mlp = MLP(
|
120 |
+
hidden_size,
|
121 |
+
mlp_hidden_dim,
|
122 |
+
act_layer=get_activation_layer(mlp_act_type),
|
123 |
+
bias=True,
|
124 |
+
**factory_kwargs,
|
125 |
+
)
|
126 |
+
self.hybrid_seq_parallel_attn = None
|
127 |
+
|
128 |
+
def enable_deterministic(self):
|
129 |
+
self.deterministic = True
|
130 |
+
|
131 |
+
def disable_deterministic(self):
|
132 |
+
self.deterministic = False
|
133 |
+
|
134 |
+
def forward(
|
135 |
+
self,
|
136 |
+
img: torch.Tensor,
|
137 |
+
txt: torch.Tensor,
|
138 |
+
vec: torch.Tensor,
|
139 |
+
attn_mask = None,
|
140 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
141 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
142 |
+
max_seqlen_q: Optional[int] = None,
|
143 |
+
max_seqlen_kv: Optional[int] = None,
|
144 |
+
freqs_cis: tuple = None,
|
145 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
146 |
+
(
|
147 |
+
img_mod1_shift,
|
148 |
+
img_mod1_scale,
|
149 |
+
img_mod1_gate,
|
150 |
+
img_mod2_shift,
|
151 |
+
img_mod2_scale,
|
152 |
+
img_mod2_gate,
|
153 |
+
) = self.img_mod(vec).chunk(6, dim=-1)
|
154 |
+
(
|
155 |
+
txt_mod1_shift,
|
156 |
+
txt_mod1_scale,
|
157 |
+
txt_mod1_gate,
|
158 |
+
txt_mod2_shift,
|
159 |
+
txt_mod2_scale,
|
160 |
+
txt_mod2_gate,
|
161 |
+
) = self.txt_mod(vec).chunk(6, dim=-1)
|
162 |
+
|
163 |
+
# Prepare image for attention.
|
164 |
+
img_modulated = self.img_norm1(img)
|
165 |
+
img_modulated = modulate(
|
166 |
+
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
|
167 |
+
)
|
168 |
+
img_qkv = self.img_attn_qkv(img_modulated)
|
169 |
+
img_q, img_k, img_v = rearrange(
|
170 |
+
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
|
171 |
+
)
|
172 |
+
# Apply QK-Norm if needed
|
173 |
+
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
174 |
+
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
175 |
+
|
176 |
+
# Apply RoPE if needed.
|
177 |
+
if freqs_cis is not None:
|
178 |
+
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
179 |
+
assert (
|
180 |
+
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
|
181 |
+
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
|
182 |
+
img_q, img_k = img_qq, img_kk
|
183 |
+
|
184 |
+
# Prepare txt for attention.
|
185 |
+
txt_modulated = self.txt_norm1(txt)
|
186 |
+
txt_modulated = modulate(
|
187 |
+
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
|
188 |
+
)
|
189 |
+
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
190 |
+
txt_q, txt_k, txt_v = rearrange(
|
191 |
+
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
|
192 |
+
)
|
193 |
+
# Apply QK-Norm if needed.
|
194 |
+
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
195 |
+
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
196 |
+
|
197 |
+
# Run actual attention.
|
198 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
199 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
200 |
+
v = torch.cat((img_v, txt_v), dim=1)
|
201 |
+
# assert (
|
202 |
+
# cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
|
203 |
+
# ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
|
204 |
+
|
205 |
+
# attention computation start
|
206 |
+
if not self.hybrid_seq_parallel_attn:
|
207 |
+
attn = attention(
|
208 |
+
q,
|
209 |
+
k,
|
210 |
+
v,
|
211 |
+
mode=self.attention_mode,
|
212 |
+
attn_mask=attn_mask,
|
213 |
+
cu_seqlens_q=cu_seqlens_q,
|
214 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
215 |
+
max_seqlen_q=max_seqlen_q,
|
216 |
+
max_seqlen_kv=max_seqlen_kv,
|
217 |
+
batch_size=img_k.shape[0],
|
218 |
+
)
|
219 |
+
else:
|
220 |
+
attn = parallel_attention(
|
221 |
+
self.hybrid_seq_parallel_attn,
|
222 |
+
q,
|
223 |
+
k,
|
224 |
+
v,
|
225 |
+
img_q_len=img_q.shape[1],
|
226 |
+
img_kv_len=img_k.shape[1],
|
227 |
+
cu_seqlens_q=cu_seqlens_q,
|
228 |
+
cu_seqlens_kv=cu_seqlens_kv
|
229 |
+
)
|
230 |
+
|
231 |
+
# attention computation end
|
232 |
+
|
233 |
+
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
|
234 |
+
|
235 |
+
# Calculate the img bloks.
|
236 |
+
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
237 |
+
img = img + apply_gate(
|
238 |
+
self.img_mlp(
|
239 |
+
modulate(
|
240 |
+
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
|
241 |
+
)
|
242 |
+
),
|
243 |
+
gate=img_mod2_gate,
|
244 |
+
)
|
245 |
+
|
246 |
+
# Calculate the txt bloks.
|
247 |
+
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
248 |
+
txt = txt + apply_gate(
|
249 |
+
self.txt_mlp(
|
250 |
+
modulate(
|
251 |
+
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
|
252 |
+
)
|
253 |
+
),
|
254 |
+
gate=txt_mod2_gate,
|
255 |
+
)
|
256 |
+
|
257 |
+
return img, txt
|
258 |
+
|
259 |
+
|
260 |
+
class MMSingleStreamBlock(nn.Module):
|
261 |
+
"""
|
262 |
+
A DiT block with parallel linear layers as described in
|
263 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
264 |
+
Also refer to (SD3): https://arxiv.org/abs/2403.03206
|
265 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
266 |
+
"""
|
267 |
+
|
268 |
+
def __init__(
|
269 |
+
self,
|
270 |
+
hidden_size: int,
|
271 |
+
heads_num: int,
|
272 |
+
mlp_width_ratio: float = 4.0,
|
273 |
+
mlp_act_type: str = "gelu_tanh",
|
274 |
+
qk_norm: bool = True,
|
275 |
+
qk_norm_type: str = "rms",
|
276 |
+
qk_scale: float = None,
|
277 |
+
dtype: Optional[torch.dtype] = None,
|
278 |
+
device: Optional[torch.device] = None,
|
279 |
+
attention_mode: str = "sdpa",
|
280 |
+
):
|
281 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
282 |
+
super().__init__()
|
283 |
+
self.attention_mode = attention_mode
|
284 |
+
self.deterministic = False
|
285 |
+
self.hidden_size = hidden_size
|
286 |
+
self.heads_num = heads_num
|
287 |
+
head_dim = hidden_size // heads_num
|
288 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
289 |
+
self.mlp_hidden_dim = mlp_hidden_dim
|
290 |
+
self.scale = qk_scale or head_dim ** -0.5
|
291 |
+
|
292 |
+
# qkv and mlp_in
|
293 |
+
self.linear1 = nn.Linear(
|
294 |
+
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
|
295 |
+
)
|
296 |
+
# proj and mlp_out
|
297 |
+
self.linear2 = nn.Linear(
|
298 |
+
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
|
299 |
+
)
|
300 |
+
|
301 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
302 |
+
self.q_norm = (
|
303 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
304 |
+
if qk_norm
|
305 |
+
else nn.Identity()
|
306 |
+
)
|
307 |
+
self.k_norm = (
|
308 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
309 |
+
if qk_norm
|
310 |
+
else nn.Identity()
|
311 |
+
)
|
312 |
+
|
313 |
+
self.pre_norm = nn.LayerNorm(
|
314 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
315 |
+
)
|
316 |
+
|
317 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
318 |
+
self.modulation = ModulateDiT(
|
319 |
+
hidden_size,
|
320 |
+
factor=3,
|
321 |
+
act_layer=get_activation_layer("silu"),
|
322 |
+
**factory_kwargs,
|
323 |
+
)
|
324 |
+
self.hybrid_seq_parallel_attn = None
|
325 |
+
|
326 |
+
def enable_deterministic(self):
|
327 |
+
self.deterministic = True
|
328 |
+
|
329 |
+
def disable_deterministic(self):
|
330 |
+
self.deterministic = False
|
331 |
+
|
332 |
+
def forward(
|
333 |
+
self,
|
334 |
+
x: torch.Tensor,
|
335 |
+
vec: torch.Tensor,
|
336 |
+
txt_len: int,
|
337 |
+
attn_mask= None,
|
338 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
339 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
340 |
+
max_seqlen_q: Optional[int] = None,
|
341 |
+
max_seqlen_kv: Optional[int] = None,
|
342 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
343 |
+
|
344 |
+
) -> torch.Tensor:
|
345 |
+
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
346 |
+
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
347 |
+
qkv, mlp = torch.split(
|
348 |
+
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
|
349 |
+
)
|
350 |
+
|
351 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
352 |
+
|
353 |
+
# Apply QK-Norm if needed.
|
354 |
+
q = self.q_norm(q).to(v)
|
355 |
+
k = self.k_norm(k).to(v)
|
356 |
+
|
357 |
+
# Apply RoPE if needed.
|
358 |
+
if freqs_cis is not None:
|
359 |
+
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
360 |
+
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
361 |
+
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
362 |
+
assert (
|
363 |
+
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
|
364 |
+
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
|
365 |
+
img_q, img_k = img_qq, img_kk
|
366 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
367 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
368 |
+
|
369 |
+
# Compute attention.
|
370 |
+
# assert (
|
371 |
+
# cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
|
372 |
+
# ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
|
373 |
+
|
374 |
+
# attention computation start
|
375 |
+
if not self.hybrid_seq_parallel_attn:
|
376 |
+
attn = attention(
|
377 |
+
q,
|
378 |
+
k,
|
379 |
+
v,
|
380 |
+
mode=self.attention_mode,
|
381 |
+
attn_mask=attn_mask,
|
382 |
+
cu_seqlens_q=cu_seqlens_q,
|
383 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
384 |
+
max_seqlen_q=max_seqlen_q,
|
385 |
+
max_seqlen_kv=max_seqlen_kv,
|
386 |
+
batch_size=x.shape[0],
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
attn = parallel_attention(
|
390 |
+
self.hybrid_seq_parallel_attn,
|
391 |
+
q,
|
392 |
+
k,
|
393 |
+
v,
|
394 |
+
img_q_len=img_q.shape[1],
|
395 |
+
img_kv_len=img_k.shape[1],
|
396 |
+
cu_seqlens_q=cu_seqlens_q,
|
397 |
+
cu_seqlens_kv=cu_seqlens_kv
|
398 |
+
)
|
399 |
+
# attention computation end
|
400 |
+
|
401 |
+
# Compute activation in mlp stream, cat again and run second linear layer.
|
402 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
403 |
+
return x + apply_gate(output, gate=mod_gate)
|
404 |
+
|
405 |
+
|
406 |
+
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
407 |
+
"""
|
408 |
+
HunyuanVideo Transformer backbone
|
409 |
+
|
410 |
+
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
|
411 |
+
|
412 |
+
Reference:
|
413 |
+
[1] Flux.1: https://github.com/black-forest-labs/flux
|
414 |
+
[2] MMDiT: http://arxiv.org/abs/2403.03206
|
415 |
+
|
416 |
+
Parameters
|
417 |
+
----------
|
418 |
+
args: argparse.Namespace
|
419 |
+
The arguments parsed by argparse.
|
420 |
+
patch_size: list
|
421 |
+
The size of the patch.
|
422 |
+
in_channels: int
|
423 |
+
The number of input channels.
|
424 |
+
out_channels: int
|
425 |
+
The number of output channels.
|
426 |
+
hidden_size: int
|
427 |
+
The hidden size of the transformer backbone.
|
428 |
+
heads_num: int
|
429 |
+
The number of attention heads.
|
430 |
+
mlp_width_ratio: float
|
431 |
+
The ratio of the hidden size of the MLP in the transformer block.
|
432 |
+
mlp_act_type: str
|
433 |
+
The activation function of the MLP in the transformer block.
|
434 |
+
depth_double_blocks: int
|
435 |
+
The number of transformer blocks in the double blocks.
|
436 |
+
depth_single_blocks: int
|
437 |
+
The number of transformer blocks in the single blocks.
|
438 |
+
rope_dim_list: list
|
439 |
+
The dimension of the rotary embedding for t, h, w.
|
440 |
+
qkv_bias: bool
|
441 |
+
Whether to use bias in the qkv linear layer.
|
442 |
+
qk_norm: bool
|
443 |
+
Whether to use qk norm.
|
444 |
+
qk_norm_type: str
|
445 |
+
The type of qk norm.
|
446 |
+
guidance_embed: bool
|
447 |
+
Whether to use guidance embedding for distillation.
|
448 |
+
text_projection: str
|
449 |
+
The type of the text projection, default is single_refiner.
|
450 |
+
use_attention_mask: bool
|
451 |
+
Whether to use attention mask for text encoder.
|
452 |
+
dtype: torch.dtype
|
453 |
+
The dtype of the model.
|
454 |
+
device: torch.device
|
455 |
+
The device of the model.
|
456 |
+
"""
|
457 |
+
|
458 |
+
@register_to_config
|
459 |
+
def __init__(
|
460 |
+
self,
|
461 |
+
args: Any,
|
462 |
+
patch_size: list = [1, 2, 2],
|
463 |
+
in_channels: int = 4, # Should be VAE.config.latent_channels.
|
464 |
+
out_channels: int = None,
|
465 |
+
hidden_size: int = 3072,
|
466 |
+
heads_num: int = 24,
|
467 |
+
mlp_width_ratio: float = 4.0,
|
468 |
+
mlp_act_type: str = "gelu_tanh",
|
469 |
+
mm_double_blocks_depth: int = 20,
|
470 |
+
mm_single_blocks_depth: int = 40,
|
471 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
472 |
+
qkv_bias: bool = True,
|
473 |
+
qk_norm: bool = True,
|
474 |
+
qk_norm_type: str = "rms",
|
475 |
+
guidance_embed: bool = False, # For modulation.
|
476 |
+
text_projection: str = "single_refiner",
|
477 |
+
use_attention_mask: bool = True,
|
478 |
+
dtype: Optional[torch.dtype] = None,
|
479 |
+
device: Optional[torch.device] = None,
|
480 |
+
attention_mode: Optional[str] = "sdpa"
|
481 |
+
):
|
482 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
483 |
+
super().__init__()
|
484 |
+
|
485 |
+
self.patch_size = patch_size
|
486 |
+
self.in_channels = in_channels
|
487 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
488 |
+
self.unpatchify_channels = self.out_channels
|
489 |
+
self.guidance_embed = guidance_embed
|
490 |
+
self.rope_dim_list = rope_dim_list
|
491 |
+
self.attention_mode = attention_mode
|
492 |
+
|
493 |
+
# Text projection. Default to linear projection.
|
494 |
+
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
|
495 |
+
self.use_attention_mask = use_attention_mask
|
496 |
+
self.text_projection = text_projection
|
497 |
+
|
498 |
+
self.text_states_dim = args.text_states_dim
|
499 |
+
self.text_states_dim_2 = args.text_states_dim_2
|
500 |
+
|
501 |
+
if hidden_size % heads_num != 0:
|
502 |
+
raise ValueError(
|
503 |
+
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
|
504 |
+
)
|
505 |
+
pe_dim = hidden_size // heads_num
|
506 |
+
if sum(rope_dim_list) != pe_dim:
|
507 |
+
raise ValueError(
|
508 |
+
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
|
509 |
+
)
|
510 |
+
self.hidden_size = hidden_size
|
511 |
+
self.heads_num = heads_num
|
512 |
+
|
513 |
+
# image projection
|
514 |
+
self.img_in = PatchEmbed(
|
515 |
+
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
|
516 |
+
)
|
517 |
+
|
518 |
+
# text projection
|
519 |
+
if self.text_projection == "linear":
|
520 |
+
self.txt_in = TextProjection(
|
521 |
+
self.text_states_dim,
|
522 |
+
self.hidden_size,
|
523 |
+
get_activation_layer("silu"),
|
524 |
+
**factory_kwargs,
|
525 |
+
)
|
526 |
+
elif self.text_projection == "single_refiner":
|
527 |
+
self.txt_in = SingleTokenRefiner(
|
528 |
+
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
raise NotImplementedError(
|
532 |
+
f"Unsupported text_projection: {self.text_projection}"
|
533 |
+
)
|
534 |
+
|
535 |
+
# time modulation
|
536 |
+
self.time_in = TimestepEmbedder(
|
537 |
+
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
538 |
+
)
|
539 |
+
|
540 |
+
# text modulation
|
541 |
+
self.vector_in = MLPEmbedder(
|
542 |
+
self.text_states_dim_2, self.hidden_size, **factory_kwargs
|
543 |
+
)
|
544 |
+
|
545 |
+
# guidance modulation
|
546 |
+
self.guidance_in = (
|
547 |
+
TimestepEmbedder(
|
548 |
+
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
549 |
+
)
|
550 |
+
if guidance_embed
|
551 |
+
else None
|
552 |
+
)
|
553 |
+
|
554 |
+
# double blocks
|
555 |
+
self.double_blocks = nn.ModuleList(
|
556 |
+
[
|
557 |
+
MMDoubleStreamBlock(
|
558 |
+
self.hidden_size,
|
559 |
+
self.heads_num,
|
560 |
+
mlp_width_ratio=mlp_width_ratio,
|
561 |
+
mlp_act_type=mlp_act_type,
|
562 |
+
qk_norm=qk_norm,
|
563 |
+
qk_norm_type=qk_norm_type,
|
564 |
+
qkv_bias=qkv_bias,
|
565 |
+
attention_mode = attention_mode,
|
566 |
+
**factory_kwargs,
|
567 |
+
)
|
568 |
+
for _ in range(mm_double_blocks_depth)
|
569 |
+
]
|
570 |
+
)
|
571 |
+
|
572 |
+
# single blocks
|
573 |
+
self.single_blocks = nn.ModuleList(
|
574 |
+
[
|
575 |
+
MMSingleStreamBlock(
|
576 |
+
self.hidden_size,
|
577 |
+
self.heads_num,
|
578 |
+
mlp_width_ratio=mlp_width_ratio,
|
579 |
+
mlp_act_type=mlp_act_type,
|
580 |
+
qk_norm=qk_norm,
|
581 |
+
qk_norm_type=qk_norm_type,
|
582 |
+
attention_mode = attention_mode,
|
583 |
+
**factory_kwargs,
|
584 |
+
)
|
585 |
+
for _ in range(mm_single_blocks_depth)
|
586 |
+
]
|
587 |
+
)
|
588 |
+
|
589 |
+
self.final_layer = FinalLayer(
|
590 |
+
self.hidden_size,
|
591 |
+
self.patch_size,
|
592 |
+
self.out_channels,
|
593 |
+
get_activation_layer("silu"),
|
594 |
+
**factory_kwargs,
|
595 |
+
)
|
596 |
+
|
597 |
+
def enable_deterministic(self):
|
598 |
+
for block in self.double_blocks:
|
599 |
+
block.enable_deterministic()
|
600 |
+
for block in self.single_blocks:
|
601 |
+
block.enable_deterministic()
|
602 |
+
|
603 |
+
def disable_deterministic(self):
|
604 |
+
for block in self.double_blocks:
|
605 |
+
block.disable_deterministic()
|
606 |
+
for block in self.single_blocks:
|
607 |
+
block.disable_deterministic()
|
608 |
+
|
609 |
+
def forward(
|
610 |
+
self,
|
611 |
+
x: torch.Tensor,
|
612 |
+
t: torch.Tensor, # Should be in range(0, 1000).
|
613 |
+
text_states: torch.Tensor = None,
|
614 |
+
text_mask: torch.Tensor = None, # Now we don't use it.
|
615 |
+
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
|
616 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
617 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
618 |
+
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
|
619 |
+
return_dict: bool = True,
|
620 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
621 |
+
out = {}
|
622 |
+
img = x
|
623 |
+
txt = text_states
|
624 |
+
_, _, ot, oh, ow = x.shape
|
625 |
+
tt, th, tw = (
|
626 |
+
ot // self.patch_size[0],
|
627 |
+
oh // self.patch_size[1],
|
628 |
+
ow // self.patch_size[2],
|
629 |
+
)
|
630 |
+
|
631 |
+
# Prepare modulation vectors.
|
632 |
+
vec = self.time_in(t)
|
633 |
+
|
634 |
+
# text modulation
|
635 |
+
vec = vec + self.vector_in(text_states_2)
|
636 |
+
|
637 |
+
# guidance modulation
|
638 |
+
if self.guidance_embed:
|
639 |
+
if guidance is None:
|
640 |
+
raise ValueError(
|
641 |
+
"Didn't get guidance strength for guidance distilled model."
|
642 |
+
)
|
643 |
+
|
644 |
+
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
|
645 |
+
vec = vec + self.guidance_in(guidance)
|
646 |
+
|
647 |
+
# Embed image and text.
|
648 |
+
img = self.img_in(img)
|
649 |
+
if self.text_projection == "linear":
|
650 |
+
txt = self.txt_in(txt)
|
651 |
+
elif self.text_projection == "single_refiner":
|
652 |
+
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
|
653 |
+
else:
|
654 |
+
raise NotImplementedError(
|
655 |
+
f"Unsupported text_projection: {self.text_projection}"
|
656 |
+
)
|
657 |
+
|
658 |
+
txt_seq_len = txt.shape[1]
|
659 |
+
img_seq_len = img.shape[1]
|
660 |
+
|
661 |
+
# Compute cu_squlens and max_seqlen for flash attention
|
662 |
+
# cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
|
663 |
+
# cu_seqlens_kv = cu_seqlens_q
|
664 |
+
max_seqlen_q = img_seq_len + txt_seq_len
|
665 |
+
max_seqlen_kv = max_seqlen_q
|
666 |
+
|
667 |
+
# thanks to kijai (https://github.com/kijai/ComfyUI-HunyuanVideoWrapper/), for the code to support sdpa
|
668 |
+
if self.attention_mode == "sdpa":
|
669 |
+
cu_seqlens_q, cu_seqlens_kv = None, None
|
670 |
+
# Create a square boolean mask filled with False
|
671 |
+
attn_mask = torch.zeros((1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
|
672 |
+
|
673 |
+
# Calculate the valid attention regions
|
674 |
+
text_len = text_mask[0].sum().item()
|
675 |
+
total_len = text_len + img_seq_len
|
676 |
+
|
677 |
+
# Allow attention to all tokens up to total_len
|
678 |
+
attn_mask[0, :total_len, :total_len] = True
|
679 |
+
else:
|
680 |
+
attn_mask = None
|
681 |
+
# Compute cu_squlens for flash and sage attention
|
682 |
+
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
|
683 |
+
cu_seqlens_kv = cu_seqlens_q
|
684 |
+
|
685 |
+
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
686 |
+
|
687 |
+
if self.enable_teacache:
|
688 |
+
inp = img.clone()
|
689 |
+
vec_ = vec.clone()
|
690 |
+
txt_ = txt.clone()
|
691 |
+
(
|
692 |
+
img_mod1_shift,
|
693 |
+
img_mod1_scale,
|
694 |
+
img_mod1_gate,
|
695 |
+
img_mod2_shift,
|
696 |
+
img_mod2_scale,
|
697 |
+
img_mod2_gate,
|
698 |
+
) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
|
699 |
+
normed_inp = self.double_blocks[0].img_norm1(inp)
|
700 |
+
modulated_inp = modulate(
|
701 |
+
normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
|
702 |
+
)
|
703 |
+
if self.cnt == 0 or self.cnt == self.num_steps-1:
|
704 |
+
should_calc = True
|
705 |
+
self.accumulated_rel_l1_distance = 0
|
706 |
+
else:
|
707 |
+
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
708 |
+
rescale_func = np.poly1d(coefficients)
|
709 |
+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
710 |
+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
711 |
+
should_calc = False
|
712 |
+
else:
|
713 |
+
should_calc = True
|
714 |
+
self.accumulated_rel_l1_distance = 0
|
715 |
+
self.previous_modulated_input = modulated_inp
|
716 |
+
self.cnt += 1
|
717 |
+
if self.cnt == self.num_steps:
|
718 |
+
self.cnt = 0
|
719 |
+
|
720 |
+
if self.enable_teacache:
|
721 |
+
if not should_calc:
|
722 |
+
img += self.previous_residual
|
723 |
+
else:
|
724 |
+
ori_img = img.clone()
|
725 |
+
# --------------------- Pass through DiT blocks ------------------------
|
726 |
+
for _, block in enumerate(self.double_blocks):
|
727 |
+
double_block_args = [
|
728 |
+
img,
|
729 |
+
txt,
|
730 |
+
vec,
|
731 |
+
attn_mask,
|
732 |
+
cu_seqlens_q,
|
733 |
+
cu_seqlens_kv,
|
734 |
+
max_seqlen_q,
|
735 |
+
max_seqlen_kv,
|
736 |
+
freqs_cis,
|
737 |
+
]
|
738 |
+
|
739 |
+
img, txt = block(*double_block_args)
|
740 |
+
|
741 |
+
# Merge txt and img to pass through single stream blocks.
|
742 |
+
x = torch.cat((img, txt), 1)
|
743 |
+
if len(self.single_blocks) > 0:
|
744 |
+
for _, block in enumerate(self.single_blocks):
|
745 |
+
single_block_args = [
|
746 |
+
x,
|
747 |
+
vec,
|
748 |
+
txt_seq_len,
|
749 |
+
attn_mask,
|
750 |
+
cu_seqlens_q,
|
751 |
+
cu_seqlens_kv,
|
752 |
+
max_seqlen_q,
|
753 |
+
max_seqlen_kv,
|
754 |
+
(freqs_cos, freqs_sin),
|
755 |
+
]
|
756 |
+
|
757 |
+
x = block(*single_block_args)
|
758 |
+
|
759 |
+
img = x[:, :img_seq_len, ...]
|
760 |
+
self.previous_residual = img - ori_img
|
761 |
+
else:
|
762 |
+
# --------------------- Pass through DiT blocks ------------------------
|
763 |
+
for _, block in enumerate(self.double_blocks):
|
764 |
+
double_block_args = [
|
765 |
+
img,
|
766 |
+
txt,
|
767 |
+
vec,
|
768 |
+
attn_mask,
|
769 |
+
cu_seqlens_q,
|
770 |
+
cu_seqlens_kv,
|
771 |
+
max_seqlen_q,
|
772 |
+
max_seqlen_kv,
|
773 |
+
freqs_cis,
|
774 |
+
]
|
775 |
+
|
776 |
+
img, txt = block(*double_block_args)
|
777 |
+
|
778 |
+
# Merge txt and img to pass through single stream blocks.
|
779 |
+
x = torch.cat((img, txt), 1)
|
780 |
+
if len(self.single_blocks) > 0:
|
781 |
+
for _, block in enumerate(self.single_blocks):
|
782 |
+
single_block_args = [
|
783 |
+
x,
|
784 |
+
vec,
|
785 |
+
txt_seq_len,
|
786 |
+
attn_mask,
|
787 |
+
cu_seqlens_q,
|
788 |
+
cu_seqlens_kv,
|
789 |
+
max_seqlen_q,
|
790 |
+
max_seqlen_kv,
|
791 |
+
(freqs_cos, freqs_sin),
|
792 |
+
]
|
793 |
+
|
794 |
+
x = block(*single_block_args)
|
795 |
+
|
796 |
+
img = x[:, :img_seq_len, ...]
|
797 |
+
|
798 |
+
# ---------------------------- Final layer ------------------------------
|
799 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
800 |
+
|
801 |
+
img = self.unpatchify(img, tt, th, tw)
|
802 |
+
if return_dict:
|
803 |
+
out["x"] = img
|
804 |
+
return out
|
805 |
+
return img
|
806 |
+
|
807 |
+
def unpatchify(self, x, t, h, w):
|
808 |
+
"""
|
809 |
+
x: (N, T, patch_size**2 * C)
|
810 |
+
imgs: (N, H, W, C)
|
811 |
+
"""
|
812 |
+
c = self.unpatchify_channels
|
813 |
+
pt, ph, pw = self.patch_size
|
814 |
+
assert t * h * w == x.shape[1]
|
815 |
+
|
816 |
+
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
|
817 |
+
x = torch.einsum("nthwcopq->nctohpwq", x)
|
818 |
+
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
819 |
+
|
820 |
+
return imgs
|
821 |
+
|
822 |
+
def params_count(self):
|
823 |
+
counts = {
|
824 |
+
"double": sum(
|
825 |
+
[
|
826 |
+
sum(p.numel() for p in block.img_attn_qkv.parameters())
|
827 |
+
+ sum(p.numel() for p in block.img_attn_proj.parameters())
|
828 |
+
+ sum(p.numel() for p in block.img_mlp.parameters())
|
829 |
+
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
|
830 |
+
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
|
831 |
+
+ sum(p.numel() for p in block.txt_mlp.parameters())
|
832 |
+
for block in self.double_blocks
|
833 |
+
]
|
834 |
+
),
|
835 |
+
"single": sum(
|
836 |
+
[
|
837 |
+
sum(p.numel() for p in block.linear1.parameters())
|
838 |
+
+ sum(p.numel() for p in block.linear2.parameters())
|
839 |
+
for block in self.single_blocks
|
840 |
+
]
|
841 |
+
),
|
842 |
+
"total": sum(p.numel() for p in self.parameters()),
|
843 |
+
}
|
844 |
+
counts["attn+mlp"] = counts["double"] + counts["single"]
|
845 |
+
return counts
|
846 |
+
|
847 |
+
|
848 |
+
#################################################################################
|
849 |
+
# HunyuanVideo Configs #
|
850 |
+
#################################################################################
|
851 |
+
|
852 |
+
HUNYUAN_VIDEO_CONFIG = {
|
853 |
+
"HYVideo-T/2": {
|
854 |
+
"mm_double_blocks_depth": 20,
|
855 |
+
"mm_single_blocks_depth": 40,
|
856 |
+
"rope_dim_list": [16, 56, 56],
|
857 |
+
"hidden_size": 3072,
|
858 |
+
"heads_num": 24,
|
859 |
+
"mlp_width_ratio": 4,
|
860 |
+
},
|
861 |
+
"HYVideo-T/2-cfgdistill": {
|
862 |
+
"mm_double_blocks_depth": 20,
|
863 |
+
"mm_single_blocks_depth": 40,
|
864 |
+
"rope_dim_list": [16, 56, 56],
|
865 |
+
"hidden_size": 3072,
|
866 |
+
"heads_num": 24,
|
867 |
+
"mlp_width_ratio": 4,
|
868 |
+
"guidance_embed": True,
|
869 |
+
},
|
870 |
+
}
|
hyvideo/modules/modulate_layers.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class ModulateDiT(nn.Module):
|
8 |
+
"""Modulation layer for DiT."""
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
hidden_size: int,
|
12 |
+
factor: int,
|
13 |
+
act_layer: Callable,
|
14 |
+
dtype=None,
|
15 |
+
device=None,
|
16 |
+
):
|
17 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
18 |
+
super().__init__()
|
19 |
+
self.act = act_layer()
|
20 |
+
self.linear = nn.Linear(
|
21 |
+
hidden_size, factor * hidden_size, bias=True, **factory_kwargs
|
22 |
+
)
|
23 |
+
# Zero-initialize the modulation
|
24 |
+
nn.init.zeros_(self.linear.weight)
|
25 |
+
nn.init.zeros_(self.linear.bias)
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
return self.linear(self.act(x))
|
29 |
+
|
30 |
+
|
31 |
+
def modulate(x, shift=None, scale=None):
|
32 |
+
"""modulate by shift and scale
|
33 |
+
|
34 |
+
Args:
|
35 |
+
x (torch.Tensor): input tensor.
|
36 |
+
shift (torch.Tensor, optional): shift tensor. Defaults to None.
|
37 |
+
scale (torch.Tensor, optional): scale tensor. Defaults to None.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: the output tensor after modulate.
|
41 |
+
"""
|
42 |
+
if scale is None and shift is None:
|
43 |
+
return x
|
44 |
+
elif shift is None:
|
45 |
+
return x * (1 + scale.unsqueeze(1))
|
46 |
+
elif scale is None:
|
47 |
+
return x + shift.unsqueeze(1)
|
48 |
+
else:
|
49 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
50 |
+
|
51 |
+
|
52 |
+
def apply_gate(x, gate=None, tanh=False):
|
53 |
+
"""AI is creating summary for apply_gate
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x (torch.Tensor): input tensor.
|
57 |
+
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
58 |
+
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
torch.Tensor: the output tensor after apply gate.
|
62 |
+
"""
|
63 |
+
if gate is None:
|
64 |
+
return x
|
65 |
+
if tanh:
|
66 |
+
return x * gate.unsqueeze(1).tanh()
|
67 |
+
else:
|
68 |
+
return x * gate.unsqueeze(1)
|
69 |
+
|
70 |
+
|
71 |
+
def ckpt_wrapper(module):
|
72 |
+
def ckpt_forward(*inputs):
|
73 |
+
outputs = module(*inputs)
|
74 |
+
return outputs
|
75 |
+
|
76 |
+
return ckpt_forward
|
hyvideo/modules/norm_layers.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class RMSNorm(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim: int,
|
9 |
+
elementwise_affine=True,
|
10 |
+
eps: float = 1e-6,
|
11 |
+
device=None,
|
12 |
+
dtype=None,
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
Initialize the RMSNorm normalization layer.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
dim (int): The dimension of the input tensor.
|
19 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
20 |
+
|
21 |
+
Attributes:
|
22 |
+
eps (float): A small value added to the denominator for numerical stability.
|
23 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
24 |
+
|
25 |
+
"""
|
26 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
27 |
+
super().__init__()
|
28 |
+
self.eps = eps
|
29 |
+
if elementwise_affine:
|
30 |
+
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
31 |
+
|
32 |
+
def _norm(self, x):
|
33 |
+
"""
|
34 |
+
Apply the RMSNorm normalization to the input tensor.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): The input tensor.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: The normalized tensor.
|
41 |
+
|
42 |
+
"""
|
43 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
"""
|
47 |
+
Forward pass through the RMSNorm layer.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
x (torch.Tensor): The input tensor.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
54 |
+
|
55 |
+
"""
|
56 |
+
output = self._norm(x.float()).type_as(x)
|
57 |
+
if hasattr(self, "weight"):
|
58 |
+
output = output * self.weight
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
62 |
+
def get_norm_layer(norm_layer):
|
63 |
+
"""
|
64 |
+
Get the normalization layer.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
norm_layer (str): The type of normalization layer.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
norm_layer (nn.Module): The normalization layer.
|
71 |
+
"""
|
72 |
+
if norm_layer == "layer":
|
73 |
+
return nn.LayerNorm
|
74 |
+
elif norm_layer == "rms":
|
75 |
+
return RMSNorm
|
76 |
+
else:
|
77 |
+
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
hyvideo/modules/posemb_layers.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Union, Tuple, List
|
3 |
+
|
4 |
+
|
5 |
+
def _to_tuple(x, dim=2):
|
6 |
+
if isinstance(x, int):
|
7 |
+
return (x,) * dim
|
8 |
+
elif len(x) == dim:
|
9 |
+
return x
|
10 |
+
else:
|
11 |
+
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
12 |
+
|
13 |
+
|
14 |
+
def get_meshgrid_nd(start, *args, dim=2):
|
15 |
+
"""
|
16 |
+
Get n-D meshgrid with start, stop and num.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
20 |
+
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
21 |
+
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
22 |
+
n-tuples.
|
23 |
+
*args: See above.
|
24 |
+
dim (int): Dimension of the meshgrid. Defaults to 2.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
grid (np.ndarray): [dim, ...]
|
28 |
+
"""
|
29 |
+
if len(args) == 0:
|
30 |
+
# start is grid_size
|
31 |
+
num = _to_tuple(start, dim=dim)
|
32 |
+
start = (0,) * dim
|
33 |
+
stop = num
|
34 |
+
elif len(args) == 1:
|
35 |
+
# start is start, args[0] is stop, step is 1
|
36 |
+
start = _to_tuple(start, dim=dim)
|
37 |
+
stop = _to_tuple(args[0], dim=dim)
|
38 |
+
num = [stop[i] - start[i] for i in range(dim)]
|
39 |
+
elif len(args) == 2:
|
40 |
+
# start is start, args[0] is stop, args[1] is num
|
41 |
+
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
42 |
+
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
43 |
+
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
44 |
+
else:
|
45 |
+
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
46 |
+
|
47 |
+
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
48 |
+
axis_grid = []
|
49 |
+
for i in range(dim):
|
50 |
+
a, b, n = start[i], stop[i], num[i]
|
51 |
+
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
52 |
+
axis_grid.append(g)
|
53 |
+
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
54 |
+
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
55 |
+
|
56 |
+
return grid
|
57 |
+
|
58 |
+
|
59 |
+
#################################################################################
|
60 |
+
# Rotary Positional Embedding Functions #
|
61 |
+
#################################################################################
|
62 |
+
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
|
63 |
+
|
64 |
+
|
65 |
+
def reshape_for_broadcast(
|
66 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
67 |
+
x: torch.Tensor,
|
68 |
+
head_first=False,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
72 |
+
|
73 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
74 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
75 |
+
|
76 |
+
Notes:
|
77 |
+
When using FlashMHAModified, head_first should be False.
|
78 |
+
When using Attention, head_first should be True.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
82 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
83 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
torch.Tensor: Reshaped frequency tensor.
|
87 |
+
|
88 |
+
Raises:
|
89 |
+
AssertionError: If the frequency tensor doesn't match the expected shape.
|
90 |
+
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
91 |
+
"""
|
92 |
+
ndim = x.ndim
|
93 |
+
assert 0 <= 1 < ndim
|
94 |
+
|
95 |
+
if isinstance(freqs_cis, tuple):
|
96 |
+
# freqs_cis: (cos, sin) in real space
|
97 |
+
if head_first:
|
98 |
+
assert freqs_cis[0].shape == (
|
99 |
+
x.shape[-2],
|
100 |
+
x.shape[-1],
|
101 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
102 |
+
shape = [
|
103 |
+
d if i == ndim - 2 or i == ndim - 1 else 1
|
104 |
+
for i, d in enumerate(x.shape)
|
105 |
+
]
|
106 |
+
else:
|
107 |
+
assert freqs_cis[0].shape == (
|
108 |
+
x.shape[1],
|
109 |
+
x.shape[-1],
|
110 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
111 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
112 |
+
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
113 |
+
else:
|
114 |
+
# freqs_cis: values in complex space
|
115 |
+
if head_first:
|
116 |
+
assert freqs_cis.shape == (
|
117 |
+
x.shape[-2],
|
118 |
+
x.shape[-1],
|
119 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
120 |
+
shape = [
|
121 |
+
d if i == ndim - 2 or i == ndim - 1 else 1
|
122 |
+
for i, d in enumerate(x.shape)
|
123 |
+
]
|
124 |
+
else:
|
125 |
+
assert freqs_cis.shape == (
|
126 |
+
x.shape[1],
|
127 |
+
x.shape[-1],
|
128 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
129 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
130 |
+
return freqs_cis.view(*shape)
|
131 |
+
|
132 |
+
|
133 |
+
def rotate_half(x):
|
134 |
+
x_real, x_imag = (
|
135 |
+
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
136 |
+
) # [B, S, H, D//2]
|
137 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
138 |
+
|
139 |
+
|
140 |
+
def apply_rotary_emb(
|
141 |
+
xq: torch.Tensor,
|
142 |
+
xk: torch.Tensor,
|
143 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
144 |
+
head_first: bool = False,
|
145 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
146 |
+
"""
|
147 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
148 |
+
|
149 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
150 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
151 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
152 |
+
returned as real tensors.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
156 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
157 |
+
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
|
158 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
162 |
+
|
163 |
+
"""
|
164 |
+
xk_out = None
|
165 |
+
if isinstance(freqs_cis, tuple):
|
166 |
+
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
167 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
168 |
+
# real * cos - imag * sin
|
169 |
+
# imag * cos + real * sin
|
170 |
+
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
171 |
+
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
172 |
+
else:
|
173 |
+
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
174 |
+
xq_ = torch.view_as_complex(
|
175 |
+
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
176 |
+
) # [B, S, H, D//2]
|
177 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
178 |
+
xq.device
|
179 |
+
) # [S, D//2] --> [1, S, 1, D//2]
|
180 |
+
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
181 |
+
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
182 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
183 |
+
xk_ = torch.view_as_complex(
|
184 |
+
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
185 |
+
) # [B, S, H, D//2]
|
186 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
187 |
+
|
188 |
+
return xq_out, xk_out
|
189 |
+
|
190 |
+
|
191 |
+
def get_nd_rotary_pos_embed(
|
192 |
+
rope_dim_list,
|
193 |
+
start,
|
194 |
+
*args,
|
195 |
+
theta=10000.0,
|
196 |
+
use_real=False,
|
197 |
+
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
198 |
+
interpolation_factor: Union[float, List[float]] = 1.0,
|
199 |
+
):
|
200 |
+
"""
|
201 |
+
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
205 |
+
sum(rope_dim_list) should equal to head_dim of attention layer.
|
206 |
+
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
207 |
+
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
208 |
+
*args: See above.
|
209 |
+
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
210 |
+
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
211 |
+
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
212 |
+
part and an imaginary part separately.
|
213 |
+
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
pos_embed (torch.Tensor): [HW, D/2]
|
217 |
+
"""
|
218 |
+
|
219 |
+
grid = get_meshgrid_nd(
|
220 |
+
start, *args, dim=len(rope_dim_list)
|
221 |
+
) # [3, W, H, D] / [2, W, H]
|
222 |
+
|
223 |
+
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
224 |
+
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
225 |
+
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
226 |
+
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
227 |
+
assert len(theta_rescale_factor) == len(
|
228 |
+
rope_dim_list
|
229 |
+
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
230 |
+
|
231 |
+
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
232 |
+
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
233 |
+
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
234 |
+
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
235 |
+
assert len(interpolation_factor) == len(
|
236 |
+
rope_dim_list
|
237 |
+
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
238 |
+
|
239 |
+
# use 1/ndim of dimensions to encode grid_axis
|
240 |
+
embs = []
|
241 |
+
for i in range(len(rope_dim_list)):
|
242 |
+
emb = get_1d_rotary_pos_embed(
|
243 |
+
rope_dim_list[i],
|
244 |
+
grid[i].reshape(-1),
|
245 |
+
theta,
|
246 |
+
use_real=use_real,
|
247 |
+
theta_rescale_factor=theta_rescale_factor[i],
|
248 |
+
interpolation_factor=interpolation_factor[i],
|
249 |
+
) # 2 x [WHD, rope_dim_list[i]]
|
250 |
+
embs.append(emb)
|
251 |
+
|
252 |
+
if use_real:
|
253 |
+
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
254 |
+
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
255 |
+
return cos, sin
|
256 |
+
else:
|
257 |
+
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
258 |
+
return emb
|
259 |
+
|
260 |
+
|
261 |
+
def get_1d_rotary_pos_embed(
|
262 |
+
dim: int,
|
263 |
+
pos: Union[torch.FloatTensor, int],
|
264 |
+
theta: float = 10000.0,
|
265 |
+
use_real: bool = False,
|
266 |
+
theta_rescale_factor: float = 1.0,
|
267 |
+
interpolation_factor: float = 1.0,
|
268 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
269 |
+
"""
|
270 |
+
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
271 |
+
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
272 |
+
|
273 |
+
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
274 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
275 |
+
The returned tensor contains complex values in complex64 data type.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
dim (int): Dimension of the frequency tensor.
|
279 |
+
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
280 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
281 |
+
use_real (bool, optional): If True, return real part and imaginary part separately.
|
282 |
+
Otherwise, return complex numbers.
|
283 |
+
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
287 |
+
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
288 |
+
"""
|
289 |
+
if isinstance(pos, int):
|
290 |
+
pos = torch.arange(pos).float()
|
291 |
+
|
292 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
293 |
+
# has some connection to NTK literature
|
294 |
+
if theta_rescale_factor != 1.0:
|
295 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
296 |
+
|
297 |
+
freqs = 1.0 / (
|
298 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
299 |
+
) # [D/2]
|
300 |
+
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
301 |
+
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
302 |
+
if use_real:
|
303 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
304 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
305 |
+
return freqs_cos, freqs_sin
|
306 |
+
else:
|
307 |
+
freqs_cis = torch.polar(
|
308 |
+
torch.ones_like(freqs), freqs
|
309 |
+
) # complex64 # [S, D/2]
|
310 |
+
return freqs_cis
|
hyvideo/modules/token_refiner.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from .activation_layers import get_activation_layer
|
8 |
+
from .attenion import attention
|
9 |
+
from .norm_layers import get_norm_layer
|
10 |
+
from .embed_layers import TimestepEmbedder, TextProjection
|
11 |
+
from .attenion import attention
|
12 |
+
from .mlp_layers import MLP
|
13 |
+
from .modulate_layers import modulate, apply_gate
|
14 |
+
|
15 |
+
|
16 |
+
class IndividualTokenRefinerBlock(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
hidden_size,
|
20 |
+
heads_num,
|
21 |
+
mlp_width_ratio: str = 4.0,
|
22 |
+
mlp_drop_rate: float = 0.0,
|
23 |
+
act_type: str = "silu",
|
24 |
+
qk_norm: bool = False,
|
25 |
+
qk_norm_type: str = "layer",
|
26 |
+
qkv_bias: bool = True,
|
27 |
+
dtype: Optional[torch.dtype] = None,
|
28 |
+
device: Optional[torch.device] = None,
|
29 |
+
):
|
30 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
31 |
+
super().__init__()
|
32 |
+
self.heads_num = heads_num
|
33 |
+
head_dim = hidden_size // heads_num
|
34 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
35 |
+
|
36 |
+
self.norm1 = nn.LayerNorm(
|
37 |
+
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
38 |
+
)
|
39 |
+
self.self_attn_qkv = nn.Linear(
|
40 |
+
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
|
41 |
+
)
|
42 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
43 |
+
self.self_attn_q_norm = (
|
44 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
45 |
+
if qk_norm
|
46 |
+
else nn.Identity()
|
47 |
+
)
|
48 |
+
self.self_attn_k_norm = (
|
49 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
50 |
+
if qk_norm
|
51 |
+
else nn.Identity()
|
52 |
+
)
|
53 |
+
self.self_attn_proj = nn.Linear(
|
54 |
+
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
|
55 |
+
)
|
56 |
+
|
57 |
+
self.norm2 = nn.LayerNorm(
|
58 |
+
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
|
59 |
+
)
|
60 |
+
act_layer = get_activation_layer(act_type)
|
61 |
+
self.mlp = MLP(
|
62 |
+
in_channels=hidden_size,
|
63 |
+
hidden_channels=mlp_hidden_dim,
|
64 |
+
act_layer=act_layer,
|
65 |
+
drop=mlp_drop_rate,
|
66 |
+
**factory_kwargs,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.adaLN_modulation = nn.Sequential(
|
70 |
+
act_layer(),
|
71 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
72 |
+
)
|
73 |
+
# Zero-initialize the modulation
|
74 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
75 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
76 |
+
|
77 |
+
def forward(
|
78 |
+
self,
|
79 |
+
x: torch.Tensor,
|
80 |
+
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
81 |
+
attn_mask: torch.Tensor = None,
|
82 |
+
):
|
83 |
+
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
84 |
+
|
85 |
+
norm_x = self.norm1(x)
|
86 |
+
qkv = self.self_attn_qkv(norm_x)
|
87 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
88 |
+
# Apply QK-Norm if needed
|
89 |
+
q = self.self_attn_q_norm(q).to(v)
|
90 |
+
k = self.self_attn_k_norm(k).to(v)
|
91 |
+
|
92 |
+
# Self-Attention
|
93 |
+
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
94 |
+
|
95 |
+
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
96 |
+
|
97 |
+
# FFN Layer
|
98 |
+
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
99 |
+
|
100 |
+
return x
|
101 |
+
|
102 |
+
|
103 |
+
class IndividualTokenRefiner(nn.Module):
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
hidden_size,
|
107 |
+
heads_num,
|
108 |
+
depth,
|
109 |
+
mlp_width_ratio: float = 4.0,
|
110 |
+
mlp_drop_rate: float = 0.0,
|
111 |
+
act_type: str = "silu",
|
112 |
+
qk_norm: bool = False,
|
113 |
+
qk_norm_type: str = "layer",
|
114 |
+
qkv_bias: bool = True,
|
115 |
+
dtype: Optional[torch.dtype] = None,
|
116 |
+
device: Optional[torch.device] = None,
|
117 |
+
):
|
118 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
119 |
+
super().__init__()
|
120 |
+
self.blocks = nn.ModuleList(
|
121 |
+
[
|
122 |
+
IndividualTokenRefinerBlock(
|
123 |
+
hidden_size=hidden_size,
|
124 |
+
heads_num=heads_num,
|
125 |
+
mlp_width_ratio=mlp_width_ratio,
|
126 |
+
mlp_drop_rate=mlp_drop_rate,
|
127 |
+
act_type=act_type,
|
128 |
+
qk_norm=qk_norm,
|
129 |
+
qk_norm_type=qk_norm_type,
|
130 |
+
qkv_bias=qkv_bias,
|
131 |
+
**factory_kwargs,
|
132 |
+
)
|
133 |
+
for _ in range(depth)
|
134 |
+
]
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(
|
138 |
+
self,
|
139 |
+
x: torch.Tensor,
|
140 |
+
c: torch.LongTensor,
|
141 |
+
mask: Optional[torch.Tensor] = None,
|
142 |
+
):
|
143 |
+
self_attn_mask = None
|
144 |
+
if mask is not None:
|
145 |
+
batch_size = mask.shape[0]
|
146 |
+
seq_len = mask.shape[1]
|
147 |
+
mask = mask.to(x.device)
|
148 |
+
# batch_size x 1 x seq_len x seq_len
|
149 |
+
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
|
150 |
+
1, 1, seq_len, 1
|
151 |
+
)
|
152 |
+
# batch_size x 1 x seq_len x seq_len
|
153 |
+
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
154 |
+
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
|
155 |
+
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
156 |
+
# avoids self-attention weight being NaN for padding tokens
|
157 |
+
self_attn_mask[:, :, :, 0] = True
|
158 |
+
|
159 |
+
for block in self.blocks:
|
160 |
+
x = block(x, c, self_attn_mask)
|
161 |
+
return x
|
162 |
+
|
163 |
+
|
164 |
+
class SingleTokenRefiner(nn.Module):
|
165 |
+
"""
|
166 |
+
A single token refiner block for llm text embedding refine.
|
167 |
+
"""
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
in_channels,
|
171 |
+
hidden_size,
|
172 |
+
heads_num,
|
173 |
+
depth,
|
174 |
+
mlp_width_ratio: float = 4.0,
|
175 |
+
mlp_drop_rate: float = 0.0,
|
176 |
+
act_type: str = "silu",
|
177 |
+
qk_norm: bool = False,
|
178 |
+
qk_norm_type: str = "layer",
|
179 |
+
qkv_bias: bool = True,
|
180 |
+
attn_mode: str = "torch",
|
181 |
+
dtype: Optional[torch.dtype] = None,
|
182 |
+
device: Optional[torch.device] = None,
|
183 |
+
):
|
184 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
185 |
+
super().__init__()
|
186 |
+
self.attn_mode = attn_mode
|
187 |
+
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
|
188 |
+
|
189 |
+
self.input_embedder = nn.Linear(
|
190 |
+
in_channels, hidden_size, bias=True, **factory_kwargs
|
191 |
+
)
|
192 |
+
|
193 |
+
act_layer = get_activation_layer(act_type)
|
194 |
+
# Build timestep embedding layer
|
195 |
+
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
196 |
+
# Build context embedding layer
|
197 |
+
self.c_embedder = TextProjection(
|
198 |
+
in_channels, hidden_size, act_layer, **factory_kwargs
|
199 |
+
)
|
200 |
+
|
201 |
+
self.individual_token_refiner = IndividualTokenRefiner(
|
202 |
+
hidden_size=hidden_size,
|
203 |
+
heads_num=heads_num,
|
204 |
+
depth=depth,
|
205 |
+
mlp_width_ratio=mlp_width_ratio,
|
206 |
+
mlp_drop_rate=mlp_drop_rate,
|
207 |
+
act_type=act_type,
|
208 |
+
qk_norm=qk_norm,
|
209 |
+
qk_norm_type=qk_norm_type,
|
210 |
+
qkv_bias=qkv_bias,
|
211 |
+
**factory_kwargs,
|
212 |
+
)
|
213 |
+
|
214 |
+
def forward(
|
215 |
+
self,
|
216 |
+
x: torch.Tensor,
|
217 |
+
t: torch.LongTensor,
|
218 |
+
mask: Optional[torch.LongTensor] = None,
|
219 |
+
):
|
220 |
+
timestep_aware_representations = self.t_embedder(t)
|
221 |
+
|
222 |
+
if mask is None:
|
223 |
+
context_aware_representations = x.mean(dim=1)
|
224 |
+
else:
|
225 |
+
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
|
226 |
+
context_aware_representations = (x * mask_float).sum(
|
227 |
+
dim=1
|
228 |
+
) / mask_float.sum(dim=1)
|
229 |
+
context_aware_representations = self.c_embedder(context_aware_representations)
|
230 |
+
c = timestep_aware_representations + context_aware_representations
|
231 |
+
|
232 |
+
x = self.input_embedder(x)
|
233 |
+
|
234 |
+
x = self.individual_token_refiner(x, c, mask)
|
235 |
+
|
236 |
+
return x
|
hyvideo/prompt_rewrite.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
normal_mode_prompt = """Normal mode - Video Recaption Task:
|
2 |
+
|
3 |
+
You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
|
4 |
+
|
5 |
+
0. Preserve ALL information, including style words and technical terms.
|
6 |
+
|
7 |
+
1. If the input is in Chinese, translate the entire description to English.
|
8 |
+
|
9 |
+
2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
|
10 |
+
|
11 |
+
3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
|
12 |
+
|
13 |
+
4. Output ALL must be in English.
|
14 |
+
|
15 |
+
Given Input:
|
16 |
+
input: "{input}"
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
master_mode_prompt = """Master mode - Video Recaption Task:
|
21 |
+
|
22 |
+
You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
|
23 |
+
|
24 |
+
0. Preserve ALL information, including style words and technical terms.
|
25 |
+
|
26 |
+
1. If the input is in Chinese, translate the entire description to English.
|
27 |
+
|
28 |
+
2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
|
29 |
+
|
30 |
+
3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
|
31 |
+
|
32 |
+
4. Output ALL must be in English.
|
33 |
+
|
34 |
+
Given Input:
|
35 |
+
input: "{input}"
|
36 |
+
"""
|
37 |
+
|
38 |
+
def get_rewrite_prompt(ori_prompt, mode="Normal"):
|
39 |
+
if mode == "Normal":
|
40 |
+
prompt = normal_mode_prompt.format(input=ori_prompt)
|
41 |
+
elif mode == "Master":
|
42 |
+
prompt = master_mode_prompt.format(input=ori_prompt)
|
43 |
+
else:
|
44 |
+
raise Exception("Only supports Normal and Normal", mode)
|
45 |
+
return prompt
|
46 |
+
|
47 |
+
ori_prompt = "一只小狗在草地上奔跑。"
|
48 |
+
normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
|
49 |
+
master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")
|
50 |
+
|
51 |
+
# Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt.
|
hyvideo/text_encoder/__init__.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
|
8 |
+
from transformers.utils import ModelOutput
|
9 |
+
|
10 |
+
from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
|
11 |
+
from ..constants import PRECISION_TO_TYPE
|
12 |
+
|
13 |
+
|
14 |
+
def use_default(value, default):
|
15 |
+
return value if value is not None else default
|
16 |
+
|
17 |
+
|
18 |
+
def load_text_encoder(
|
19 |
+
text_encoder_type,
|
20 |
+
text_encoder_precision=None,
|
21 |
+
text_encoder_path=None,
|
22 |
+
logger=None,
|
23 |
+
device=None,
|
24 |
+
):
|
25 |
+
if text_encoder_path is None:
|
26 |
+
text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
|
27 |
+
if logger is not None:
|
28 |
+
logger.info(
|
29 |
+
f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}"
|
30 |
+
)
|
31 |
+
|
32 |
+
if text_encoder_type == "clipL":
|
33 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
|
34 |
+
text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
|
35 |
+
elif text_encoder_type == "llm":
|
36 |
+
text_encoder = AutoModel.from_pretrained(
|
37 |
+
text_encoder_path, low_cpu_mem_usage=True
|
38 |
+
)
|
39 |
+
text_encoder.final_layer_norm = text_encoder.norm
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
42 |
+
# from_pretrained will ensure that the model is in eval mode.
|
43 |
+
|
44 |
+
if text_encoder_precision is not None:
|
45 |
+
text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
|
46 |
+
|
47 |
+
text_encoder.requires_grad_(False)
|
48 |
+
|
49 |
+
if logger is not None:
|
50 |
+
logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
|
51 |
+
|
52 |
+
if device is not None:
|
53 |
+
text_encoder = text_encoder.to(device)
|
54 |
+
|
55 |
+
return text_encoder, text_encoder_path
|
56 |
+
|
57 |
+
|
58 |
+
def load_tokenizer(
|
59 |
+
tokenizer_type, tokenizer_path=None, padding_side="right", logger=None
|
60 |
+
):
|
61 |
+
if tokenizer_path is None:
|
62 |
+
tokenizer_path = TOKENIZER_PATH[tokenizer_type]
|
63 |
+
if logger is not None:
|
64 |
+
logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
|
65 |
+
|
66 |
+
if tokenizer_type == "clipL":
|
67 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
|
68 |
+
elif tokenizer_type == "llm":
|
69 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
70 |
+
tokenizer_path, padding_side=padding_side
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
|
74 |
+
|
75 |
+
return tokenizer, tokenizer_path
|
76 |
+
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class TextEncoderModelOutput(ModelOutput):
|
80 |
+
"""
|
81 |
+
Base class for model's outputs that also contains a pooling of the last hidden states.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
85 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
86 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
87 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
88 |
+
hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
|
89 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
90 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
91 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
92 |
+
text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
|
93 |
+
List of decoded texts.
|
94 |
+
"""
|
95 |
+
|
96 |
+
hidden_state: torch.FloatTensor = None
|
97 |
+
attention_mask: Optional[torch.LongTensor] = None
|
98 |
+
hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
|
99 |
+
text_outputs: Optional[list] = None
|
100 |
+
|
101 |
+
|
102 |
+
class TextEncoder(nn.Module):
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
text_encoder_type: str,
|
106 |
+
max_length: int,
|
107 |
+
text_encoder_precision: Optional[str] = None,
|
108 |
+
text_encoder_path: Optional[str] = None,
|
109 |
+
tokenizer_type: Optional[str] = None,
|
110 |
+
tokenizer_path: Optional[str] = None,
|
111 |
+
output_key: Optional[str] = None,
|
112 |
+
use_attention_mask: bool = True,
|
113 |
+
input_max_length: Optional[int] = None,
|
114 |
+
prompt_template: Optional[dict] = None,
|
115 |
+
prompt_template_video: Optional[dict] = None,
|
116 |
+
hidden_state_skip_layer: Optional[int] = None,
|
117 |
+
apply_final_norm: bool = False,
|
118 |
+
reproduce: bool = False,
|
119 |
+
logger=None,
|
120 |
+
device=None,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
self.text_encoder_type = text_encoder_type
|
124 |
+
self.max_length = max_length
|
125 |
+
self.precision = text_encoder_precision
|
126 |
+
self.model_path = text_encoder_path
|
127 |
+
self.tokenizer_type = (
|
128 |
+
tokenizer_type if tokenizer_type is not None else text_encoder_type
|
129 |
+
)
|
130 |
+
self.tokenizer_path = (
|
131 |
+
tokenizer_path if tokenizer_path is not None else None # text_encoder_path
|
132 |
+
)
|
133 |
+
self.use_attention_mask = use_attention_mask
|
134 |
+
if prompt_template_video is not None:
|
135 |
+
assert (
|
136 |
+
use_attention_mask is True
|
137 |
+
), "Attention mask is True required when training videos."
|
138 |
+
self.input_max_length = (
|
139 |
+
input_max_length if input_max_length is not None else max_length
|
140 |
+
)
|
141 |
+
self.prompt_template = prompt_template
|
142 |
+
self.prompt_template_video = prompt_template_video
|
143 |
+
self.hidden_state_skip_layer = hidden_state_skip_layer
|
144 |
+
self.apply_final_norm = apply_final_norm
|
145 |
+
self.reproduce = reproduce
|
146 |
+
self.logger = logger
|
147 |
+
|
148 |
+
self.use_template = self.prompt_template is not None
|
149 |
+
if self.use_template:
|
150 |
+
assert (
|
151 |
+
isinstance(self.prompt_template, dict)
|
152 |
+
and "template" in self.prompt_template
|
153 |
+
), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
|
154 |
+
assert "{}" in str(self.prompt_template["template"]), (
|
155 |
+
"`prompt_template['template']` must contain a placeholder `{}` for the input text, "
|
156 |
+
f"got {self.prompt_template['template']}"
|
157 |
+
)
|
158 |
+
|
159 |
+
self.use_video_template = self.prompt_template_video is not None
|
160 |
+
if self.use_video_template:
|
161 |
+
if self.prompt_template_video is not None:
|
162 |
+
assert (
|
163 |
+
isinstance(self.prompt_template_video, dict)
|
164 |
+
and "template" in self.prompt_template_video
|
165 |
+
), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
|
166 |
+
assert "{}" in str(self.prompt_template_video["template"]), (
|
167 |
+
"`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
|
168 |
+
f"got {self.prompt_template_video['template']}"
|
169 |
+
)
|
170 |
+
|
171 |
+
if "t5" in text_encoder_type:
|
172 |
+
self.output_key = output_key or "last_hidden_state"
|
173 |
+
elif "clip" in text_encoder_type:
|
174 |
+
self.output_key = output_key or "pooler_output"
|
175 |
+
elif "llm" in text_encoder_type or "glm" in text_encoder_type:
|
176 |
+
self.output_key = output_key or "last_hidden_state"
|
177 |
+
else:
|
178 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
179 |
+
|
180 |
+
if "llm" in text_encoder_type:
|
181 |
+
|
182 |
+
from mmgp import offload
|
183 |
+
|
184 |
+
self.model= offload.fast_load_transformers_model(self.model_path) #, pinInMemory = True, partialPinning = True
|
185 |
+
self.model.final_layer_norm = self.model.norm
|
186 |
+
|
187 |
+
else:
|
188 |
+
self.model, self.model_path = load_text_encoder(
|
189 |
+
text_encoder_type=self.text_encoder_type,
|
190 |
+
text_encoder_precision=self.precision,
|
191 |
+
text_encoder_path=self.model_path,
|
192 |
+
logger=self.logger,
|
193 |
+
device=device,
|
194 |
+
)
|
195 |
+
|
196 |
+
self.dtype = self.model.dtype
|
197 |
+
self.device = self.model.device
|
198 |
+
|
199 |
+
self.tokenizer, self.tokenizer_path = load_tokenizer(
|
200 |
+
tokenizer_type=self.tokenizer_type,
|
201 |
+
tokenizer_path=self.tokenizer_path,
|
202 |
+
padding_side="right",
|
203 |
+
logger=self.logger,
|
204 |
+
)
|
205 |
+
|
206 |
+
def __repr__(self):
|
207 |
+
return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
|
208 |
+
|
209 |
+
@staticmethod
|
210 |
+
def apply_text_to_template(text, template, prevent_empty_text=True):
|
211 |
+
"""
|
212 |
+
Apply text to template.
|
213 |
+
|
214 |
+
Args:
|
215 |
+
text (str): Input text.
|
216 |
+
template (str or list): Template string or list of chat conversation.
|
217 |
+
prevent_empty_text (bool): If Ture, we will prevent the user text from being empty
|
218 |
+
by adding a space. Defaults to True.
|
219 |
+
"""
|
220 |
+
if isinstance(template, str):
|
221 |
+
# Will send string to tokenizer. Used for llm
|
222 |
+
return template.format(text)
|
223 |
+
else:
|
224 |
+
raise TypeError(f"Unsupported template type: {type(template)}")
|
225 |
+
|
226 |
+
def text2tokens(self, text, data_type="image"):
|
227 |
+
"""
|
228 |
+
Tokenize the input text.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
text (str or list): Input text.
|
232 |
+
"""
|
233 |
+
tokenize_input_type = "str"
|
234 |
+
if self.use_template:
|
235 |
+
if data_type == "image":
|
236 |
+
prompt_template = self.prompt_template["template"]
|
237 |
+
elif data_type == "video":
|
238 |
+
prompt_template = self.prompt_template_video["template"]
|
239 |
+
else:
|
240 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
241 |
+
if isinstance(text, (list, tuple)):
|
242 |
+
text = [
|
243 |
+
self.apply_text_to_template(one_text, prompt_template)
|
244 |
+
for one_text in text
|
245 |
+
]
|
246 |
+
if isinstance(text[0], list):
|
247 |
+
tokenize_input_type = "list"
|
248 |
+
elif isinstance(text, str):
|
249 |
+
text = self.apply_text_to_template(text, prompt_template)
|
250 |
+
if isinstance(text, list):
|
251 |
+
tokenize_input_type = "list"
|
252 |
+
else:
|
253 |
+
raise TypeError(f"Unsupported text type: {type(text)}")
|
254 |
+
|
255 |
+
kwargs = dict(
|
256 |
+
truncation=True,
|
257 |
+
max_length=self.max_length,
|
258 |
+
padding="max_length",
|
259 |
+
return_tensors="pt",
|
260 |
+
)
|
261 |
+
if tokenize_input_type == "str":
|
262 |
+
return self.tokenizer(
|
263 |
+
text,
|
264 |
+
return_length=False,
|
265 |
+
return_overflowing_tokens=False,
|
266 |
+
return_attention_mask=True,
|
267 |
+
**kwargs,
|
268 |
+
)
|
269 |
+
elif tokenize_input_type == "list":
|
270 |
+
return self.tokenizer.apply_chat_template(
|
271 |
+
text,
|
272 |
+
add_generation_prompt=True,
|
273 |
+
tokenize=True,
|
274 |
+
return_dict=True,
|
275 |
+
**kwargs,
|
276 |
+
)
|
277 |
+
else:
|
278 |
+
raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
|
279 |
+
|
280 |
+
def encode(
|
281 |
+
self,
|
282 |
+
batch_encoding,
|
283 |
+
use_attention_mask=None,
|
284 |
+
output_hidden_states=False,
|
285 |
+
do_sample=None,
|
286 |
+
hidden_state_skip_layer=None,
|
287 |
+
return_texts=False,
|
288 |
+
data_type="image",
|
289 |
+
device=None,
|
290 |
+
):
|
291 |
+
"""
|
292 |
+
Args:
|
293 |
+
batch_encoding (dict): Batch encoding from tokenizer.
|
294 |
+
use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
|
295 |
+
Defaults to None.
|
296 |
+
output_hidden_states (bool): Whether to output hidden states. If False, return the value of
|
297 |
+
self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
|
298 |
+
output_hidden_states will be set True. Defaults to False.
|
299 |
+
do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
|
300 |
+
When self.produce is False, do_sample is set to True by default.
|
301 |
+
hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
|
302 |
+
If None, self.output_key will be used. Defaults to None.
|
303 |
+
return_texts (bool): Whether to return the decoded texts. Defaults to False.
|
304 |
+
"""
|
305 |
+
device = self.model.device if device is None else device
|
306 |
+
use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
|
307 |
+
hidden_state_skip_layer = use_default(
|
308 |
+
hidden_state_skip_layer, self.hidden_state_skip_layer
|
309 |
+
)
|
310 |
+
do_sample = use_default(do_sample, not self.reproduce)
|
311 |
+
attention_mask = (
|
312 |
+
batch_encoding["attention_mask"].to(device) if use_attention_mask else None
|
313 |
+
)
|
314 |
+
outputs = self.model(
|
315 |
+
input_ids=batch_encoding["input_ids"].to(device),
|
316 |
+
attention_mask=attention_mask,
|
317 |
+
output_hidden_states=output_hidden_states
|
318 |
+
or hidden_state_skip_layer is not None,
|
319 |
+
)
|
320 |
+
if hidden_state_skip_layer is not None:
|
321 |
+
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
322 |
+
# Real last hidden state already has layer norm applied. So here we only apply it
|
323 |
+
# for intermediate layers.
|
324 |
+
if hidden_state_skip_layer > 0 and self.apply_final_norm:
|
325 |
+
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
|
326 |
+
else:
|
327 |
+
last_hidden_state = outputs[self.output_key]
|
328 |
+
|
329 |
+
# Remove hidden states of instruction tokens, only keep prompt tokens.
|
330 |
+
if self.use_template:
|
331 |
+
if data_type == "image":
|
332 |
+
crop_start = self.prompt_template.get("crop_start", -1)
|
333 |
+
elif data_type == "video":
|
334 |
+
crop_start = self.prompt_template_video.get("crop_start", -1)
|
335 |
+
else:
|
336 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
337 |
+
if crop_start > 0:
|
338 |
+
last_hidden_state = last_hidden_state[:, crop_start:]
|
339 |
+
attention_mask = (
|
340 |
+
attention_mask[:, crop_start:] if use_attention_mask else None
|
341 |
+
)
|
342 |
+
|
343 |
+
if output_hidden_states:
|
344 |
+
return TextEncoderModelOutput(
|
345 |
+
last_hidden_state, attention_mask, outputs.hidden_states
|
346 |
+
)
|
347 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask)
|
348 |
+
|
349 |
+
def forward(
|
350 |
+
self,
|
351 |
+
text,
|
352 |
+
use_attention_mask=None,
|
353 |
+
output_hidden_states=False,
|
354 |
+
do_sample=False,
|
355 |
+
hidden_state_skip_layer=None,
|
356 |
+
return_texts=False,
|
357 |
+
):
|
358 |
+
batch_encoding = self.text2tokens(text)
|
359 |
+
return self.encode(
|
360 |
+
batch_encoding,
|
361 |
+
use_attention_mask=use_attention_mask,
|
362 |
+
output_hidden_states=output_hidden_states,
|
363 |
+
do_sample=do_sample,
|
364 |
+
hidden_state_skip_layer=hidden_state_skip_layer,
|
365 |
+
return_texts=return_texts,
|
366 |
+
)
|
hyvideo/utils/__init__.py
ADDED
File without changes
|
hyvideo/utils/data_utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
|
5 |
+
def align_to(value, alignment):
|
6 |
+
"""align hight, width according to alignment
|
7 |
+
|
8 |
+
Args:
|
9 |
+
value (int): height or width
|
10 |
+
alignment (int): target alignment factor
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
int: the aligned value
|
14 |
+
"""
|
15 |
+
return int(math.ceil(value / alignment) * alignment)
|
hyvideo/utils/file_utils.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from einops import rearrange
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
import numpy as np
|
8 |
+
import imageio
|
9 |
+
|
10 |
+
CODE_SUFFIXES = {
|
11 |
+
".py", # Python codes
|
12 |
+
".sh", # Shell scripts
|
13 |
+
".yaml",
|
14 |
+
".yml", # Configuration files
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
def safe_dir(path):
|
19 |
+
"""
|
20 |
+
Create a directory (or the parent directory of a file) if it does not exist.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
path (str or Path): Path to the directory.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
path (Path): Path object of the directory.
|
27 |
+
"""
|
28 |
+
path = Path(path)
|
29 |
+
path.mkdir(exist_ok=True, parents=True)
|
30 |
+
return path
|
31 |
+
|
32 |
+
|
33 |
+
def safe_file(path):
|
34 |
+
"""
|
35 |
+
Create the parent directory of a file if it does not exist.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
path (str or Path): Path to the file.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
path (Path): Path object of the file.
|
42 |
+
"""
|
43 |
+
path = Path(path)
|
44 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
45 |
+
return path
|
46 |
+
|
47 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
|
48 |
+
"""save videos by video tensor
|
49 |
+
copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
|
50 |
+
|
51 |
+
Args:
|
52 |
+
videos (torch.Tensor): video tensor predicted by the model
|
53 |
+
path (str): path to save video
|
54 |
+
rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False.
|
55 |
+
n_rows (int, optional): Defaults to 1.
|
56 |
+
fps (int, optional): video save fps. Defaults to 8.
|
57 |
+
"""
|
58 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
59 |
+
outputs = []
|
60 |
+
for x in videos:
|
61 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
62 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
63 |
+
if rescale:
|
64 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
65 |
+
x = torch.clamp(x, 0, 1)
|
66 |
+
x = (x * 255).numpy().astype(np.uint8)
|
67 |
+
outputs.append(x)
|
68 |
+
|
69 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
70 |
+
imageio.mimsave(path, outputs, fps=fps)
|
hyvideo/utils/helpers.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
|
3 |
+
from itertools import repeat
|
4 |
+
|
5 |
+
|
6 |
+
def _ntuple(n):
|
7 |
+
def parse(x):
|
8 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
9 |
+
x = tuple(x)
|
10 |
+
if len(x) == 1:
|
11 |
+
x = tuple(repeat(x[0], n))
|
12 |
+
return x
|
13 |
+
return tuple(repeat(x, n))
|
14 |
+
return parse
|
15 |
+
|
16 |
+
|
17 |
+
to_1tuple = _ntuple(1)
|
18 |
+
to_2tuple = _ntuple(2)
|
19 |
+
to_3tuple = _ntuple(3)
|
20 |
+
to_4tuple = _ntuple(4)
|
21 |
+
|
22 |
+
|
23 |
+
def as_tuple(x):
|
24 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
25 |
+
return tuple(x)
|
26 |
+
if x is None or isinstance(x, (int, float, str)):
|
27 |
+
return (x,)
|
28 |
+
else:
|
29 |
+
raise ValueError(f"Unknown type {type(x)}")
|
30 |
+
|
31 |
+
|
32 |
+
def as_list_of_2tuple(x):
|
33 |
+
x = as_tuple(x)
|
34 |
+
if len(x) == 1:
|
35 |
+
x = (x[0], x[0])
|
36 |
+
assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
|
37 |
+
lst = []
|
38 |
+
for i in range(0, len(x), 2):
|
39 |
+
lst.append((x[i], x[i + 1]))
|
40 |
+
return lst
|
hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from transformers import (
|
4 |
+
AutoProcessor,
|
5 |
+
LlavaForConditionalGeneration,
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
def preprocess_text_encoder_tokenizer(args):
|
10 |
+
|
11 |
+
processor = AutoProcessor.from_pretrained(args.input_dir)
|
12 |
+
model = LlavaForConditionalGeneration.from_pretrained(
|
13 |
+
args.input_dir,
|
14 |
+
torch_dtype=torch.float16,
|
15 |
+
low_cpu_mem_usage=True,
|
16 |
+
).to(0)
|
17 |
+
|
18 |
+
model.language_model.save_pretrained(
|
19 |
+
f"{args.output_dir}"
|
20 |
+
)
|
21 |
+
processor.tokenizer.save_pretrained(
|
22 |
+
f"{args.output_dir}"
|
23 |
+
)
|
24 |
+
|
25 |
+
if __name__ == "__main__":
|
26 |
+
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
parser.add_argument(
|
29 |
+
"--input_dir",
|
30 |
+
type=str,
|
31 |
+
required=True,
|
32 |
+
help="The path to the llava-llama-3-8b-v1_1-transformers.",
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--output_dir",
|
36 |
+
type=str,
|
37 |
+
default="",
|
38 |
+
help="The output path of the llava-llama-3-8b-text-encoder-tokenizer."
|
39 |
+
"if '', the parent dir of output will be the same as input dir.",
|
40 |
+
)
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
if len(args.output_dir) == 0:
|
44 |
+
args.output_dir = "/".join(args.input_dir.split("/")[:-1])
|
45 |
+
|
46 |
+
preprocess_text_encoder_tokenizer(args)
|
hyvideo/vae/__init__.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
6 |
+
from ..constants import VAE_PATH, PRECISION_TO_TYPE
|
7 |
+
|
8 |
+
def load_vae(vae_type: str="884-16c-hy",
|
9 |
+
vae_precision: str=None,
|
10 |
+
sample_size: tuple=None,
|
11 |
+
vae_path: str=None,
|
12 |
+
logger=None,
|
13 |
+
device=None
|
14 |
+
):
|
15 |
+
"""the fucntion to load the 3D VAE model
|
16 |
+
|
17 |
+
Args:
|
18 |
+
vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
|
19 |
+
vae_precision (str, optional): the precision to load vae. Defaults to None.
|
20 |
+
sample_size (tuple, optional): the tiling size. Defaults to None.
|
21 |
+
vae_path (str, optional): the path to vae. Defaults to None.
|
22 |
+
logger (_type_, optional): logger. Defaults to None.
|
23 |
+
device (_type_, optional): device to load vae. Defaults to None.
|
24 |
+
"""
|
25 |
+
if vae_path is None:
|
26 |
+
vae_path = VAE_PATH[vae_type]
|
27 |
+
|
28 |
+
if logger is not None:
|
29 |
+
logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
|
30 |
+
config = AutoencoderKLCausal3D.load_config(vae_path)
|
31 |
+
if sample_size:
|
32 |
+
vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
|
33 |
+
else:
|
34 |
+
vae = AutoencoderKLCausal3D.from_config(config)
|
35 |
+
|
36 |
+
vae_ckpt = Path(vae_path) / "pytorch_model.pt"
|
37 |
+
assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
|
38 |
+
|
39 |
+
ckpt = torch.load(vae_ckpt, weights_only=True, map_location=vae.device)
|
40 |
+
if "state_dict" in ckpt:
|
41 |
+
ckpt = ckpt["state_dict"]
|
42 |
+
if any(k.startswith("vae.") for k in ckpt.keys()):
|
43 |
+
ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
|
44 |
+
vae.load_state_dict(ckpt)
|
45 |
+
|
46 |
+
spatial_compression_ratio = vae.config.spatial_compression_ratio
|
47 |
+
time_compression_ratio = vae.config.time_compression_ratio
|
48 |
+
|
49 |
+
if vae_precision is not None:
|
50 |
+
vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
|
51 |
+
|
52 |
+
vae.requires_grad_(False)
|
53 |
+
|
54 |
+
if logger is not None:
|
55 |
+
logger.info(f"VAE to dtype: {vae.dtype}")
|
56 |
+
|
57 |
+
if device is not None:
|
58 |
+
vae = vae.to(device)
|
59 |
+
|
60 |
+
vae.eval()
|
61 |
+
|
62 |
+
return vae, vae_path, spatial_compression_ratio, time_compression_ratio
|
hyvideo/vae/autoencoder_kl_causal_3d.py
ADDED
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
from typing import Dict, Optional, Tuple, Union
|
20 |
+
from dataclasses import dataclass
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
|
27 |
+
try:
|
28 |
+
# This diffusers is modified and packed in the mirror.
|
29 |
+
from diffusers.loaders import FromOriginalVAEMixin
|
30 |
+
except ImportError:
|
31 |
+
# Use this to be compatible with the original diffusers.
|
32 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
|
33 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
34 |
+
from diffusers.models.attention_processor import (
|
35 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
36 |
+
CROSS_ATTENTION_PROCESSORS,
|
37 |
+
Attention,
|
38 |
+
AttentionProcessor,
|
39 |
+
AttnAddedKVProcessor,
|
40 |
+
AttnProcessor,
|
41 |
+
)
|
42 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
43 |
+
from diffusers.models.modeling_utils import ModelMixin
|
44 |
+
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class DecoderOutput2(BaseOutput):
|
49 |
+
sample: torch.FloatTensor
|
50 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
51 |
+
|
52 |
+
|
53 |
+
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
54 |
+
r"""
|
55 |
+
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
|
56 |
+
|
57 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
58 |
+
for all models (such as downloading or saving).
|
59 |
+
"""
|
60 |
+
|
61 |
+
_supports_gradient_checkpointing = True
|
62 |
+
|
63 |
+
@register_to_config
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
in_channels: int = 3,
|
67 |
+
out_channels: int = 3,
|
68 |
+
down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
|
69 |
+
up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
|
70 |
+
block_out_channels: Tuple[int] = (64,),
|
71 |
+
layers_per_block: int = 1,
|
72 |
+
act_fn: str = "silu",
|
73 |
+
latent_channels: int = 4,
|
74 |
+
norm_num_groups: int = 32,
|
75 |
+
sample_size: int = 32,
|
76 |
+
sample_tsize: int = 64,
|
77 |
+
scaling_factor: float = 0.18215,
|
78 |
+
force_upcast: float = True,
|
79 |
+
spatial_compression_ratio: int = 8,
|
80 |
+
time_compression_ratio: int = 4,
|
81 |
+
mid_block_add_attention: bool = True,
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.time_compression_ratio = time_compression_ratio
|
86 |
+
|
87 |
+
self.encoder = EncoderCausal3D(
|
88 |
+
in_channels=in_channels,
|
89 |
+
out_channels=latent_channels,
|
90 |
+
down_block_types=down_block_types,
|
91 |
+
block_out_channels=block_out_channels,
|
92 |
+
layers_per_block=layers_per_block,
|
93 |
+
act_fn=act_fn,
|
94 |
+
norm_num_groups=norm_num_groups,
|
95 |
+
double_z=True,
|
96 |
+
time_compression_ratio=time_compression_ratio,
|
97 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
98 |
+
mid_block_add_attention=mid_block_add_attention,
|
99 |
+
)
|
100 |
+
|
101 |
+
self.decoder = DecoderCausal3D(
|
102 |
+
in_channels=latent_channels,
|
103 |
+
out_channels=out_channels,
|
104 |
+
up_block_types=up_block_types,
|
105 |
+
block_out_channels=block_out_channels,
|
106 |
+
layers_per_block=layers_per_block,
|
107 |
+
norm_num_groups=norm_num_groups,
|
108 |
+
act_fn=act_fn,
|
109 |
+
time_compression_ratio=time_compression_ratio,
|
110 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
111 |
+
mid_block_add_attention=mid_block_add_attention,
|
112 |
+
)
|
113 |
+
|
114 |
+
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
|
115 |
+
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
|
116 |
+
|
117 |
+
self.use_slicing = False
|
118 |
+
self.use_spatial_tiling = False
|
119 |
+
self.use_temporal_tiling = False
|
120 |
+
|
121 |
+
# only relevant if vae tiling is enabled
|
122 |
+
self.tile_sample_min_tsize = sample_tsize
|
123 |
+
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
|
124 |
+
|
125 |
+
self.tile_sample_min_size = self.config.sample_size
|
126 |
+
sample_size = (
|
127 |
+
self.config.sample_size[0]
|
128 |
+
if isinstance(self.config.sample_size, (list, tuple))
|
129 |
+
else self.config.sample_size
|
130 |
+
)
|
131 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
132 |
+
self.tile_overlap_factor = 0.25
|
133 |
+
|
134 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
135 |
+
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
|
136 |
+
module.gradient_checkpointing = value
|
137 |
+
|
138 |
+
def enable_temporal_tiling(self, use_tiling: bool = True):
|
139 |
+
self.use_temporal_tiling = use_tiling
|
140 |
+
|
141 |
+
def disable_temporal_tiling(self):
|
142 |
+
self.enable_temporal_tiling(False)
|
143 |
+
|
144 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
145 |
+
self.use_spatial_tiling = use_tiling
|
146 |
+
|
147 |
+
def disable_spatial_tiling(self):
|
148 |
+
self.enable_spatial_tiling(False)
|
149 |
+
|
150 |
+
def enable_tiling(self, use_tiling: bool = True):
|
151 |
+
r"""
|
152 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
153 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
154 |
+
processing larger videos.
|
155 |
+
"""
|
156 |
+
self.enable_spatial_tiling(use_tiling)
|
157 |
+
self.enable_temporal_tiling(use_tiling)
|
158 |
+
|
159 |
+
def disable_tiling(self):
|
160 |
+
r"""
|
161 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
162 |
+
decoding in one step.
|
163 |
+
"""
|
164 |
+
self.disable_spatial_tiling()
|
165 |
+
self.disable_temporal_tiling()
|
166 |
+
|
167 |
+
def enable_slicing(self):
|
168 |
+
r"""
|
169 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
170 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
171 |
+
"""
|
172 |
+
self.use_slicing = True
|
173 |
+
|
174 |
+
def disable_slicing(self):
|
175 |
+
r"""
|
176 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
177 |
+
decoding in one step.
|
178 |
+
"""
|
179 |
+
self.use_slicing = False
|
180 |
+
|
181 |
+
@property
|
182 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
183 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
184 |
+
r"""
|
185 |
+
Returns:
|
186 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
187 |
+
indexed by its weight name.
|
188 |
+
"""
|
189 |
+
# set recursively
|
190 |
+
processors = {}
|
191 |
+
|
192 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
193 |
+
if hasattr(module, "get_processor"):
|
194 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
195 |
+
|
196 |
+
for sub_name, child in module.named_children():
|
197 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
198 |
+
|
199 |
+
return processors
|
200 |
+
|
201 |
+
for name, module in self.named_children():
|
202 |
+
fn_recursive_add_processors(name, module, processors)
|
203 |
+
|
204 |
+
return processors
|
205 |
+
|
206 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
207 |
+
def set_attn_processor(
|
208 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
209 |
+
):
|
210 |
+
r"""
|
211 |
+
Sets the attention processor to use to compute attention.
|
212 |
+
|
213 |
+
Parameters:
|
214 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
215 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
216 |
+
for **all** `Attention` layers.
|
217 |
+
|
218 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
219 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
220 |
+
|
221 |
+
"""
|
222 |
+
count = len(self.attn_processors.keys())
|
223 |
+
|
224 |
+
if isinstance(processor, dict) and len(processor) != count:
|
225 |
+
raise ValueError(
|
226 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
227 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
228 |
+
)
|
229 |
+
|
230 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
231 |
+
if hasattr(module, "set_processor"):
|
232 |
+
if not isinstance(processor, dict):
|
233 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
234 |
+
else:
|
235 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
236 |
+
|
237 |
+
for sub_name, child in module.named_children():
|
238 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
239 |
+
|
240 |
+
for name, module in self.named_children():
|
241 |
+
fn_recursive_attn_processor(name, module, processor)
|
242 |
+
|
243 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
244 |
+
def set_default_attn_processor(self):
|
245 |
+
"""
|
246 |
+
Disables custom attention processors and sets the default attention implementation.
|
247 |
+
"""
|
248 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
249 |
+
processor = AttnAddedKVProcessor()
|
250 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
251 |
+
processor = AttnProcessor()
|
252 |
+
else:
|
253 |
+
raise ValueError(
|
254 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
255 |
+
)
|
256 |
+
|
257 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
258 |
+
|
259 |
+
@apply_forward_hook
|
260 |
+
def encode(
|
261 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
262 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
263 |
+
"""
|
264 |
+
Encode a batch of images/videos into latents.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
268 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
269 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
The latent representations of the encoded images/videos. If `return_dict` is True, a
|
273 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
274 |
+
"""
|
275 |
+
assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
|
276 |
+
|
277 |
+
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
|
278 |
+
return self.temporal_tiled_encode(x, return_dict=return_dict)
|
279 |
+
|
280 |
+
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
281 |
+
return self.spatial_tiled_encode(x, return_dict=return_dict)
|
282 |
+
|
283 |
+
if self.use_slicing and x.shape[0] > 1:
|
284 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
285 |
+
h = torch.cat(encoded_slices)
|
286 |
+
else:
|
287 |
+
h = self.encoder(x)
|
288 |
+
|
289 |
+
moments = self.quant_conv(h)
|
290 |
+
posterior = DiagonalGaussianDistribution(moments)
|
291 |
+
|
292 |
+
if not return_dict:
|
293 |
+
return (posterior,)
|
294 |
+
|
295 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
296 |
+
|
297 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
298 |
+
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
|
299 |
+
|
300 |
+
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
|
301 |
+
return self.temporal_tiled_decode(z, return_dict=return_dict)
|
302 |
+
|
303 |
+
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
304 |
+
return self.spatial_tiled_decode(z, return_dict=return_dict)
|
305 |
+
|
306 |
+
z = self.post_quant_conv(z)
|
307 |
+
dec = self.decoder(z)
|
308 |
+
|
309 |
+
if not return_dict:
|
310 |
+
return (dec,)
|
311 |
+
|
312 |
+
return DecoderOutput(sample=dec)
|
313 |
+
|
314 |
+
@apply_forward_hook
|
315 |
+
def decode(
|
316 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
317 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
318 |
+
"""
|
319 |
+
Decode a batch of images/videos.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
323 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
324 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
328 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
329 |
+
returned.
|
330 |
+
|
331 |
+
"""
|
332 |
+
if self.use_slicing and z.shape[0] > 1:
|
333 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
334 |
+
decoded = torch.cat(decoded_slices)
|
335 |
+
else:
|
336 |
+
decoded = self._decode(z).sample
|
337 |
+
|
338 |
+
if not return_dict:
|
339 |
+
return (decoded,)
|
340 |
+
|
341 |
+
return DecoderOutput(sample=decoded)
|
342 |
+
|
343 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
344 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
345 |
+
for y in range(blend_extent):
|
346 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
347 |
+
return b
|
348 |
+
|
349 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
350 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
351 |
+
for x in range(blend_extent):
|
352 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
353 |
+
return b
|
354 |
+
|
355 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
356 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
357 |
+
for x in range(blend_extent):
|
358 |
+
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
|
359 |
+
return b
|
360 |
+
|
361 |
+
def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
|
362 |
+
r"""Encode a batch of images/videos using a tiled encoder.
|
363 |
+
|
364 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
365 |
+
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
|
366 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
367 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
368 |
+
output, but they should be much less noticeable.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
372 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
373 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
377 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
378 |
+
`tuple` is returned.
|
379 |
+
"""
|
380 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
381 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
382 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
383 |
+
|
384 |
+
# Split video into tiles and encode them separately.
|
385 |
+
rows = []
|
386 |
+
for i in range(0, x.shape[-2], overlap_size):
|
387 |
+
row = []
|
388 |
+
for j in range(0, x.shape[-1], overlap_size):
|
389 |
+
tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
|
390 |
+
tile = self.encoder(tile)
|
391 |
+
tile = self.quant_conv(tile)
|
392 |
+
row.append(tile)
|
393 |
+
rows.append(row)
|
394 |
+
result_rows = []
|
395 |
+
for i, row in enumerate(rows):
|
396 |
+
result_row = []
|
397 |
+
for j, tile in enumerate(row):
|
398 |
+
# blend the above tile and the left tile
|
399 |
+
# to the current tile and add the current tile to the result row
|
400 |
+
if i > 0:
|
401 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
402 |
+
if j > 0:
|
403 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
404 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
405 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
406 |
+
|
407 |
+
moments = torch.cat(result_rows, dim=-2)
|
408 |
+
if return_moments:
|
409 |
+
return moments
|
410 |
+
|
411 |
+
posterior = DiagonalGaussianDistribution(moments)
|
412 |
+
if not return_dict:
|
413 |
+
return (posterior,)
|
414 |
+
|
415 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
416 |
+
|
417 |
+
def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
418 |
+
r"""
|
419 |
+
Decode a batch of images/videos using a tiled decoder.
|
420 |
+
|
421 |
+
Args:
|
422 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
423 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
424 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
428 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
429 |
+
returned.
|
430 |
+
"""
|
431 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
432 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
433 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
434 |
+
|
435 |
+
# Split z into overlapping tiles and decode them separately.
|
436 |
+
# The tiles have an overlap to avoid seams between tiles.
|
437 |
+
rows = []
|
438 |
+
for i in range(0, z.shape[-2], overlap_size):
|
439 |
+
row = []
|
440 |
+
for j in range(0, z.shape[-1], overlap_size):
|
441 |
+
tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
|
442 |
+
tile = self.post_quant_conv(tile)
|
443 |
+
decoded = self.decoder(tile)
|
444 |
+
row.append(decoded)
|
445 |
+
rows.append(row)
|
446 |
+
result_rows = []
|
447 |
+
for i, row in enumerate(rows):
|
448 |
+
result_row = []
|
449 |
+
for j, tile in enumerate(row):
|
450 |
+
# blend the above tile and the left tile
|
451 |
+
# to the current tile and add the current tile to the result row
|
452 |
+
if i > 0:
|
453 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
454 |
+
if j > 0:
|
455 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
456 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
457 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
458 |
+
|
459 |
+
dec = torch.cat(result_rows, dim=-2)
|
460 |
+
if not return_dict:
|
461 |
+
return (dec,)
|
462 |
+
|
463 |
+
return DecoderOutput(sample=dec)
|
464 |
+
|
465 |
+
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
466 |
+
|
467 |
+
B, C, T, H, W = x.shape
|
468 |
+
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
|
469 |
+
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
|
470 |
+
t_limit = self.tile_latent_min_tsize - blend_extent
|
471 |
+
|
472 |
+
# Split the video into tiles and encode them separately.
|
473 |
+
row = []
|
474 |
+
for i in range(0, T, overlap_size):
|
475 |
+
tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
|
476 |
+
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
|
477 |
+
tile = self.spatial_tiled_encode(tile, return_moments=True)
|
478 |
+
else:
|
479 |
+
tile = self.encoder(tile)
|
480 |
+
tile = self.quant_conv(tile)
|
481 |
+
if i > 0:
|
482 |
+
tile = tile[:, :, 1:, :, :]
|
483 |
+
row.append(tile)
|
484 |
+
result_row = []
|
485 |
+
for i, tile in enumerate(row):
|
486 |
+
if i > 0:
|
487 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
488 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
489 |
+
else:
|
490 |
+
result_row.append(tile[:, :, :t_limit + 1, :, :])
|
491 |
+
|
492 |
+
moments = torch.cat(result_row, dim=2)
|
493 |
+
posterior = DiagonalGaussianDistribution(moments)
|
494 |
+
|
495 |
+
if not return_dict:
|
496 |
+
return (posterior,)
|
497 |
+
|
498 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
499 |
+
|
500 |
+
def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
501 |
+
# Split z into overlapping tiles and decode them separately.
|
502 |
+
|
503 |
+
B, C, T, H, W = z.shape
|
504 |
+
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
|
505 |
+
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
|
506 |
+
t_limit = self.tile_sample_min_tsize - blend_extent
|
507 |
+
|
508 |
+
row = []
|
509 |
+
for i in range(0, T, overlap_size):
|
510 |
+
tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
|
511 |
+
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
|
512 |
+
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
|
513 |
+
else:
|
514 |
+
tile = self.post_quant_conv(tile)
|
515 |
+
decoded = self.decoder(tile)
|
516 |
+
if i > 0:
|
517 |
+
decoded = decoded[:, :, 1:, :, :]
|
518 |
+
row.append(decoded)
|
519 |
+
result_row = []
|
520 |
+
for i, tile in enumerate(row):
|
521 |
+
if i > 0:
|
522 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
523 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
524 |
+
else:
|
525 |
+
result_row.append(tile[:, :, :t_limit + 1, :, :])
|
526 |
+
|
527 |
+
dec = torch.cat(result_row, dim=2)
|
528 |
+
if not return_dict:
|
529 |
+
return (dec,)
|
530 |
+
|
531 |
+
return DecoderOutput(sample=dec)
|
532 |
+
|
533 |
+
def forward(
|
534 |
+
self,
|
535 |
+
sample: torch.FloatTensor,
|
536 |
+
sample_posterior: bool = False,
|
537 |
+
return_dict: bool = True,
|
538 |
+
return_posterior: bool = False,
|
539 |
+
generator: Optional[torch.Generator] = None,
|
540 |
+
) -> Union[DecoderOutput2, torch.FloatTensor]:
|
541 |
+
r"""
|
542 |
+
Args:
|
543 |
+
sample (`torch.FloatTensor`): Input sample.
|
544 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
545 |
+
Whether to sample from the posterior.
|
546 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
547 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
548 |
+
"""
|
549 |
+
x = sample
|
550 |
+
posterior = self.encode(x).latent_dist
|
551 |
+
if sample_posterior:
|
552 |
+
z = posterior.sample(generator=generator)
|
553 |
+
else:
|
554 |
+
z = posterior.mode()
|
555 |
+
dec = self.decode(z).sample
|
556 |
+
|
557 |
+
if not return_dict:
|
558 |
+
if return_posterior:
|
559 |
+
return (dec, posterior)
|
560 |
+
else:
|
561 |
+
return (dec,)
|
562 |
+
if return_posterior:
|
563 |
+
return DecoderOutput2(sample=dec, posterior=posterior)
|
564 |
+
else:
|
565 |
+
return DecoderOutput2(sample=dec)
|
566 |
+
|
567 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
568 |
+
def fuse_qkv_projections(self):
|
569 |
+
"""
|
570 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
571 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
572 |
+
|
573 |
+
<Tip warning={true}>
|
574 |
+
|
575 |
+
This API is 🧪 experimental.
|
576 |
+
|
577 |
+
</Tip>
|
578 |
+
"""
|
579 |
+
self.original_attn_processors = None
|
580 |
+
|
581 |
+
for _, attn_processor in self.attn_processors.items():
|
582 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
583 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
584 |
+
|
585 |
+
self.original_attn_processors = self.attn_processors
|
586 |
+
|
587 |
+
for module in self.modules():
|
588 |
+
if isinstance(module, Attention):
|
589 |
+
module.fuse_projections(fuse=True)
|
590 |
+
|
591 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
592 |
+
def unfuse_qkv_projections(self):
|
593 |
+
"""Disables the fused QKV projection if enabled.
|
594 |
+
|
595 |
+
<Tip warning={true}>
|
596 |
+
|
597 |
+
This API is 🧪 experimental.
|
598 |
+
|
599 |
+
</Tip>
|
600 |
+
|
601 |
+
"""
|
602 |
+
if self.original_attn_processors is not None:
|
603 |
+
self.set_attn_processor(self.original_attn_processors)
|
hyvideo/vae/unet_causal_3d_blocks.py
ADDED
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
|
20 |
+
from typing import Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch import nn
|
25 |
+
from einops import rearrange
|
26 |
+
|
27 |
+
from diffusers.utils import logging
|
28 |
+
from diffusers.models.activations import get_activation
|
29 |
+
from diffusers.models.attention_processor import SpatialNorm
|
30 |
+
from diffusers.models.attention_processor import Attention
|
31 |
+
from diffusers.models.normalization import AdaGroupNorm
|
32 |
+
from diffusers.models.normalization import RMSNorm
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
|
37 |
+
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
|
38 |
+
seq_len = n_frame * n_hw
|
39 |
+
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
40 |
+
for i in range(seq_len):
|
41 |
+
i_frame = i // n_hw
|
42 |
+
mask[i, : (i_frame + 1) * n_hw] = 0
|
43 |
+
if batch_size is not None:
|
44 |
+
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
45 |
+
return mask
|
46 |
+
|
47 |
+
|
48 |
+
class CausalConv3d(nn.Module):
|
49 |
+
"""
|
50 |
+
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
|
51 |
+
This maintains temporal causality in video generation tasks.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
chan_in,
|
57 |
+
chan_out,
|
58 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
59 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
60 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
61 |
+
pad_mode='replicate',
|
62 |
+
**kwargs
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
|
66 |
+
self.pad_mode = pad_mode
|
67 |
+
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
|
68 |
+
self.time_causal_padding = padding
|
69 |
+
|
70 |
+
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
74 |
+
return self.conv(x)
|
75 |
+
|
76 |
+
|
77 |
+
class UpsampleCausal3D(nn.Module):
|
78 |
+
"""
|
79 |
+
A 3D upsampling layer with an optional convolution.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
channels: int,
|
85 |
+
use_conv: bool = False,
|
86 |
+
use_conv_transpose: bool = False,
|
87 |
+
out_channels: Optional[int] = None,
|
88 |
+
name: str = "conv",
|
89 |
+
kernel_size: Optional[int] = None,
|
90 |
+
padding=1,
|
91 |
+
norm_type=None,
|
92 |
+
eps=None,
|
93 |
+
elementwise_affine=None,
|
94 |
+
bias=True,
|
95 |
+
interpolate=True,
|
96 |
+
upsample_factor=(2, 2, 2),
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
self.channels = channels
|
100 |
+
self.out_channels = out_channels or channels
|
101 |
+
self.use_conv = use_conv
|
102 |
+
self.use_conv_transpose = use_conv_transpose
|
103 |
+
self.name = name
|
104 |
+
self.interpolate = interpolate
|
105 |
+
self.upsample_factor = upsample_factor
|
106 |
+
|
107 |
+
if norm_type == "ln_norm":
|
108 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
109 |
+
elif norm_type == "rms_norm":
|
110 |
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
111 |
+
elif norm_type is None:
|
112 |
+
self.norm = None
|
113 |
+
else:
|
114 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
115 |
+
|
116 |
+
conv = None
|
117 |
+
if use_conv_transpose:
|
118 |
+
raise NotImplementedError
|
119 |
+
elif use_conv:
|
120 |
+
if kernel_size is None:
|
121 |
+
kernel_size = 3
|
122 |
+
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
123 |
+
|
124 |
+
if name == "conv":
|
125 |
+
self.conv = conv
|
126 |
+
else:
|
127 |
+
self.Conv2d_0 = conv
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
hidden_states: torch.FloatTensor,
|
132 |
+
output_size: Optional[int] = None,
|
133 |
+
scale: float = 1.0,
|
134 |
+
) -> torch.FloatTensor:
|
135 |
+
assert hidden_states.shape[1] == self.channels
|
136 |
+
|
137 |
+
if self.norm is not None:
|
138 |
+
raise NotImplementedError
|
139 |
+
|
140 |
+
if self.use_conv_transpose:
|
141 |
+
return self.conv(hidden_states)
|
142 |
+
|
143 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
144 |
+
dtype = hidden_states.dtype
|
145 |
+
if dtype == torch.bfloat16:
|
146 |
+
hidden_states = hidden_states.to(torch.float32)
|
147 |
+
|
148 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
149 |
+
if hidden_states.shape[0] >= 64:
|
150 |
+
hidden_states = hidden_states.contiguous()
|
151 |
+
|
152 |
+
# if `output_size` is passed we force the interpolation output
|
153 |
+
# size and do not make use of `scale_factor=2`
|
154 |
+
if self.interpolate:
|
155 |
+
B, C, T, H, W = hidden_states.shape
|
156 |
+
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
157 |
+
if output_size is None:
|
158 |
+
if T > 1:
|
159 |
+
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
160 |
+
|
161 |
+
first_h = first_h.squeeze(2)
|
162 |
+
first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
|
163 |
+
first_h = first_h.unsqueeze(2)
|
164 |
+
else:
|
165 |
+
raise NotImplementedError
|
166 |
+
|
167 |
+
if T > 1:
|
168 |
+
hidden_states = torch.cat((first_h, other_h), dim=2)
|
169 |
+
else:
|
170 |
+
hidden_states = first_h
|
171 |
+
|
172 |
+
# If the input is bfloat16, we cast back to bfloat16
|
173 |
+
if dtype == torch.bfloat16:
|
174 |
+
hidden_states = hidden_states.to(dtype)
|
175 |
+
|
176 |
+
if self.use_conv:
|
177 |
+
if self.name == "conv":
|
178 |
+
hidden_states = self.conv(hidden_states)
|
179 |
+
else:
|
180 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
181 |
+
|
182 |
+
return hidden_states
|
183 |
+
|
184 |
+
|
185 |
+
class DownsampleCausal3D(nn.Module):
|
186 |
+
"""
|
187 |
+
A 3D downsampling layer with an optional convolution.
|
188 |
+
"""
|
189 |
+
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
channels: int,
|
193 |
+
use_conv: bool = False,
|
194 |
+
out_channels: Optional[int] = None,
|
195 |
+
padding: int = 1,
|
196 |
+
name: str = "conv",
|
197 |
+
kernel_size=3,
|
198 |
+
norm_type=None,
|
199 |
+
eps=None,
|
200 |
+
elementwise_affine=None,
|
201 |
+
bias=True,
|
202 |
+
stride=2,
|
203 |
+
):
|
204 |
+
super().__init__()
|
205 |
+
self.channels = channels
|
206 |
+
self.out_channels = out_channels or channels
|
207 |
+
self.use_conv = use_conv
|
208 |
+
self.padding = padding
|
209 |
+
stride = stride
|
210 |
+
self.name = name
|
211 |
+
|
212 |
+
if norm_type == "ln_norm":
|
213 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
214 |
+
elif norm_type == "rms_norm":
|
215 |
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
216 |
+
elif norm_type is None:
|
217 |
+
self.norm = None
|
218 |
+
else:
|
219 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
220 |
+
|
221 |
+
if use_conv:
|
222 |
+
conv = CausalConv3d(
|
223 |
+
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
raise NotImplementedError
|
227 |
+
|
228 |
+
if name == "conv":
|
229 |
+
self.Conv2d_0 = conv
|
230 |
+
self.conv = conv
|
231 |
+
elif name == "Conv2d_0":
|
232 |
+
self.conv = conv
|
233 |
+
else:
|
234 |
+
self.conv = conv
|
235 |
+
|
236 |
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
237 |
+
assert hidden_states.shape[1] == self.channels
|
238 |
+
|
239 |
+
if self.norm is not None:
|
240 |
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
241 |
+
|
242 |
+
assert hidden_states.shape[1] == self.channels
|
243 |
+
|
244 |
+
hidden_states = self.conv(hidden_states)
|
245 |
+
|
246 |
+
return hidden_states
|
247 |
+
|
248 |
+
|
249 |
+
class ResnetBlockCausal3D(nn.Module):
|
250 |
+
r"""
|
251 |
+
A Resnet block.
|
252 |
+
"""
|
253 |
+
|
254 |
+
def __init__(
|
255 |
+
self,
|
256 |
+
*,
|
257 |
+
in_channels: int,
|
258 |
+
out_channels: Optional[int] = None,
|
259 |
+
conv_shortcut: bool = False,
|
260 |
+
dropout: float = 0.0,
|
261 |
+
temb_channels: int = 512,
|
262 |
+
groups: int = 32,
|
263 |
+
groups_out: Optional[int] = None,
|
264 |
+
pre_norm: bool = True,
|
265 |
+
eps: float = 1e-6,
|
266 |
+
non_linearity: str = "swish",
|
267 |
+
skip_time_act: bool = False,
|
268 |
+
# default, scale_shift, ada_group, spatial
|
269 |
+
time_embedding_norm: str = "default",
|
270 |
+
kernel: Optional[torch.FloatTensor] = None,
|
271 |
+
output_scale_factor: float = 1.0,
|
272 |
+
use_in_shortcut: Optional[bool] = None,
|
273 |
+
up: bool = False,
|
274 |
+
down: bool = False,
|
275 |
+
conv_shortcut_bias: bool = True,
|
276 |
+
conv_3d_out_channels: Optional[int] = None,
|
277 |
+
):
|
278 |
+
super().__init__()
|
279 |
+
self.pre_norm = pre_norm
|
280 |
+
self.pre_norm = True
|
281 |
+
self.in_channels = in_channels
|
282 |
+
out_channels = in_channels if out_channels is None else out_channels
|
283 |
+
self.out_channels = out_channels
|
284 |
+
self.use_conv_shortcut = conv_shortcut
|
285 |
+
self.up = up
|
286 |
+
self.down = down
|
287 |
+
self.output_scale_factor = output_scale_factor
|
288 |
+
self.time_embedding_norm = time_embedding_norm
|
289 |
+
self.skip_time_act = skip_time_act
|
290 |
+
|
291 |
+
linear_cls = nn.Linear
|
292 |
+
|
293 |
+
if groups_out is None:
|
294 |
+
groups_out = groups
|
295 |
+
|
296 |
+
if self.time_embedding_norm == "ada_group":
|
297 |
+
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
|
298 |
+
elif self.time_embedding_norm == "spatial":
|
299 |
+
self.norm1 = SpatialNorm(in_channels, temb_channels)
|
300 |
+
else:
|
301 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
302 |
+
|
303 |
+
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
304 |
+
|
305 |
+
if temb_channels is not None:
|
306 |
+
if self.time_embedding_norm == "default":
|
307 |
+
self.time_emb_proj = linear_cls(temb_channels, out_channels)
|
308 |
+
elif self.time_embedding_norm == "scale_shift":
|
309 |
+
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
|
310 |
+
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
311 |
+
self.time_emb_proj = None
|
312 |
+
else:
|
313 |
+
raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
|
314 |
+
else:
|
315 |
+
self.time_emb_proj = None
|
316 |
+
|
317 |
+
if self.time_embedding_norm == "ada_group":
|
318 |
+
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
|
319 |
+
elif self.time_embedding_norm == "spatial":
|
320 |
+
self.norm2 = SpatialNorm(out_channels, temb_channels)
|
321 |
+
else:
|
322 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
323 |
+
|
324 |
+
self.dropout = torch.nn.Dropout(dropout)
|
325 |
+
conv_3d_out_channels = conv_3d_out_channels or out_channels
|
326 |
+
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
|
327 |
+
|
328 |
+
self.nonlinearity = get_activation(non_linearity)
|
329 |
+
|
330 |
+
self.upsample = self.downsample = None
|
331 |
+
if self.up:
|
332 |
+
self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
|
333 |
+
elif self.down:
|
334 |
+
self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
|
335 |
+
|
336 |
+
self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
|
337 |
+
|
338 |
+
self.conv_shortcut = None
|
339 |
+
if self.use_in_shortcut:
|
340 |
+
self.conv_shortcut = CausalConv3d(
|
341 |
+
in_channels,
|
342 |
+
conv_3d_out_channels,
|
343 |
+
kernel_size=1,
|
344 |
+
stride=1,
|
345 |
+
bias=conv_shortcut_bias,
|
346 |
+
)
|
347 |
+
|
348 |
+
def forward(
|
349 |
+
self,
|
350 |
+
input_tensor: torch.FloatTensor,
|
351 |
+
temb: torch.FloatTensor,
|
352 |
+
scale: float = 1.0,
|
353 |
+
) -> torch.FloatTensor:
|
354 |
+
hidden_states = input_tensor
|
355 |
+
|
356 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
357 |
+
hidden_states = self.norm1(hidden_states, temb)
|
358 |
+
else:
|
359 |
+
hidden_states = self.norm1(hidden_states)
|
360 |
+
|
361 |
+
hidden_states = self.nonlinearity(hidden_states)
|
362 |
+
|
363 |
+
if self.upsample is not None:
|
364 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
365 |
+
if hidden_states.shape[0] >= 64:
|
366 |
+
input_tensor = input_tensor.contiguous()
|
367 |
+
hidden_states = hidden_states.contiguous()
|
368 |
+
input_tensor = (
|
369 |
+
self.upsample(input_tensor, scale=scale)
|
370 |
+
)
|
371 |
+
hidden_states = (
|
372 |
+
self.upsample(hidden_states, scale=scale)
|
373 |
+
)
|
374 |
+
elif self.downsample is not None:
|
375 |
+
input_tensor = (
|
376 |
+
self.downsample(input_tensor, scale=scale)
|
377 |
+
)
|
378 |
+
hidden_states = (
|
379 |
+
self.downsample(hidden_states, scale=scale)
|
380 |
+
)
|
381 |
+
|
382 |
+
hidden_states = self.conv1(hidden_states)
|
383 |
+
|
384 |
+
if self.time_emb_proj is not None:
|
385 |
+
if not self.skip_time_act:
|
386 |
+
temb = self.nonlinearity(temb)
|
387 |
+
temb = (
|
388 |
+
self.time_emb_proj(temb, scale)[:, :, None, None]
|
389 |
+
)
|
390 |
+
|
391 |
+
if temb is not None and self.time_embedding_norm == "default":
|
392 |
+
hidden_states = hidden_states + temb
|
393 |
+
|
394 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
395 |
+
hidden_states = self.norm2(hidden_states, temb)
|
396 |
+
else:
|
397 |
+
hidden_states = self.norm2(hidden_states)
|
398 |
+
|
399 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
400 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
401 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
402 |
+
|
403 |
+
hidden_states = self.nonlinearity(hidden_states)
|
404 |
+
|
405 |
+
hidden_states = self.dropout(hidden_states)
|
406 |
+
hidden_states = self.conv2(hidden_states)
|
407 |
+
|
408 |
+
if self.conv_shortcut is not None:
|
409 |
+
input_tensor = (
|
410 |
+
self.conv_shortcut(input_tensor)
|
411 |
+
)
|
412 |
+
|
413 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
414 |
+
|
415 |
+
return output_tensor
|
416 |
+
|
417 |
+
|
418 |
+
def get_down_block3d(
|
419 |
+
down_block_type: str,
|
420 |
+
num_layers: int,
|
421 |
+
in_channels: int,
|
422 |
+
out_channels: int,
|
423 |
+
temb_channels: int,
|
424 |
+
add_downsample: bool,
|
425 |
+
downsample_stride: int,
|
426 |
+
resnet_eps: float,
|
427 |
+
resnet_act_fn: str,
|
428 |
+
transformer_layers_per_block: int = 1,
|
429 |
+
num_attention_heads: Optional[int] = None,
|
430 |
+
resnet_groups: Optional[int] = None,
|
431 |
+
cross_attention_dim: Optional[int] = None,
|
432 |
+
downsample_padding: Optional[int] = None,
|
433 |
+
dual_cross_attention: bool = False,
|
434 |
+
use_linear_projection: bool = False,
|
435 |
+
only_cross_attention: bool = False,
|
436 |
+
upcast_attention: bool = False,
|
437 |
+
resnet_time_scale_shift: str = "default",
|
438 |
+
attention_type: str = "default",
|
439 |
+
resnet_skip_time_act: bool = False,
|
440 |
+
resnet_out_scale_factor: float = 1.0,
|
441 |
+
cross_attention_norm: Optional[str] = None,
|
442 |
+
attention_head_dim: Optional[int] = None,
|
443 |
+
downsample_type: Optional[str] = None,
|
444 |
+
dropout: float = 0.0,
|
445 |
+
):
|
446 |
+
# If attn head dim is not defined, we default it to the number of heads
|
447 |
+
if attention_head_dim is None:
|
448 |
+
logger.warn(
|
449 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
450 |
+
)
|
451 |
+
attention_head_dim = num_attention_heads
|
452 |
+
|
453 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
454 |
+
if down_block_type == "DownEncoderBlockCausal3D":
|
455 |
+
return DownEncoderBlockCausal3D(
|
456 |
+
num_layers=num_layers,
|
457 |
+
in_channels=in_channels,
|
458 |
+
out_channels=out_channels,
|
459 |
+
dropout=dropout,
|
460 |
+
add_downsample=add_downsample,
|
461 |
+
downsample_stride=downsample_stride,
|
462 |
+
resnet_eps=resnet_eps,
|
463 |
+
resnet_act_fn=resnet_act_fn,
|
464 |
+
resnet_groups=resnet_groups,
|
465 |
+
downsample_padding=downsample_padding,
|
466 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
467 |
+
)
|
468 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
469 |
+
|
470 |
+
|
471 |
+
def get_up_block3d(
|
472 |
+
up_block_type: str,
|
473 |
+
num_layers: int,
|
474 |
+
in_channels: int,
|
475 |
+
out_channels: int,
|
476 |
+
prev_output_channel: int,
|
477 |
+
temb_channels: int,
|
478 |
+
add_upsample: bool,
|
479 |
+
upsample_scale_factor: Tuple,
|
480 |
+
resnet_eps: float,
|
481 |
+
resnet_act_fn: str,
|
482 |
+
resolution_idx: Optional[int] = None,
|
483 |
+
transformer_layers_per_block: int = 1,
|
484 |
+
num_attention_heads: Optional[int] = None,
|
485 |
+
resnet_groups: Optional[int] = None,
|
486 |
+
cross_attention_dim: Optional[int] = None,
|
487 |
+
dual_cross_attention: bool = False,
|
488 |
+
use_linear_projection: bool = False,
|
489 |
+
only_cross_attention: bool = False,
|
490 |
+
upcast_attention: bool = False,
|
491 |
+
resnet_time_scale_shift: str = "default",
|
492 |
+
attention_type: str = "default",
|
493 |
+
resnet_skip_time_act: bool = False,
|
494 |
+
resnet_out_scale_factor: float = 1.0,
|
495 |
+
cross_attention_norm: Optional[str] = None,
|
496 |
+
attention_head_dim: Optional[int] = None,
|
497 |
+
upsample_type: Optional[str] = None,
|
498 |
+
dropout: float = 0.0,
|
499 |
+
) -> nn.Module:
|
500 |
+
# If attn head dim is not defined, we default it to the number of heads
|
501 |
+
if attention_head_dim is None:
|
502 |
+
logger.warn(
|
503 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
504 |
+
)
|
505 |
+
attention_head_dim = num_attention_heads
|
506 |
+
|
507 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
508 |
+
if up_block_type == "UpDecoderBlockCausal3D":
|
509 |
+
return UpDecoderBlockCausal3D(
|
510 |
+
num_layers=num_layers,
|
511 |
+
in_channels=in_channels,
|
512 |
+
out_channels=out_channels,
|
513 |
+
resolution_idx=resolution_idx,
|
514 |
+
dropout=dropout,
|
515 |
+
add_upsample=add_upsample,
|
516 |
+
upsample_scale_factor=upsample_scale_factor,
|
517 |
+
resnet_eps=resnet_eps,
|
518 |
+
resnet_act_fn=resnet_act_fn,
|
519 |
+
resnet_groups=resnet_groups,
|
520 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
521 |
+
temb_channels=temb_channels,
|
522 |
+
)
|
523 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
524 |
+
|
525 |
+
|
526 |
+
class UNetMidBlockCausal3D(nn.Module):
|
527 |
+
"""
|
528 |
+
A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
|
529 |
+
"""
|
530 |
+
|
531 |
+
def __init__(
|
532 |
+
self,
|
533 |
+
in_channels: int,
|
534 |
+
temb_channels: int,
|
535 |
+
dropout: float = 0.0,
|
536 |
+
num_layers: int = 1,
|
537 |
+
resnet_eps: float = 1e-6,
|
538 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
539 |
+
resnet_act_fn: str = "swish",
|
540 |
+
resnet_groups: int = 32,
|
541 |
+
attn_groups: Optional[int] = None,
|
542 |
+
resnet_pre_norm: bool = True,
|
543 |
+
add_attention: bool = True,
|
544 |
+
attention_head_dim: int = 1,
|
545 |
+
output_scale_factor: float = 1.0,
|
546 |
+
):
|
547 |
+
super().__init__()
|
548 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
549 |
+
self.add_attention = add_attention
|
550 |
+
|
551 |
+
if attn_groups is None:
|
552 |
+
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
|
553 |
+
|
554 |
+
# there is always at least one resnet
|
555 |
+
resnets = [
|
556 |
+
ResnetBlockCausal3D(
|
557 |
+
in_channels=in_channels,
|
558 |
+
out_channels=in_channels,
|
559 |
+
temb_channels=temb_channels,
|
560 |
+
eps=resnet_eps,
|
561 |
+
groups=resnet_groups,
|
562 |
+
dropout=dropout,
|
563 |
+
time_embedding_norm=resnet_time_scale_shift,
|
564 |
+
non_linearity=resnet_act_fn,
|
565 |
+
output_scale_factor=output_scale_factor,
|
566 |
+
pre_norm=resnet_pre_norm,
|
567 |
+
)
|
568 |
+
]
|
569 |
+
attentions = []
|
570 |
+
|
571 |
+
if attention_head_dim is None:
|
572 |
+
logger.warn(
|
573 |
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
574 |
+
)
|
575 |
+
attention_head_dim = in_channels
|
576 |
+
|
577 |
+
for _ in range(num_layers):
|
578 |
+
if self.add_attention:
|
579 |
+
attentions.append(
|
580 |
+
Attention(
|
581 |
+
in_channels,
|
582 |
+
heads=in_channels // attention_head_dim,
|
583 |
+
dim_head=attention_head_dim,
|
584 |
+
rescale_output_factor=output_scale_factor,
|
585 |
+
eps=resnet_eps,
|
586 |
+
norm_num_groups=attn_groups,
|
587 |
+
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
|
588 |
+
residual_connection=True,
|
589 |
+
bias=True,
|
590 |
+
upcast_softmax=True,
|
591 |
+
_from_deprecated_attn_block=True,
|
592 |
+
)
|
593 |
+
)
|
594 |
+
else:
|
595 |
+
attentions.append(None)
|
596 |
+
|
597 |
+
resnets.append(
|
598 |
+
ResnetBlockCausal3D(
|
599 |
+
in_channels=in_channels,
|
600 |
+
out_channels=in_channels,
|
601 |
+
temb_channels=temb_channels,
|
602 |
+
eps=resnet_eps,
|
603 |
+
groups=resnet_groups,
|
604 |
+
dropout=dropout,
|
605 |
+
time_embedding_norm=resnet_time_scale_shift,
|
606 |
+
non_linearity=resnet_act_fn,
|
607 |
+
output_scale_factor=output_scale_factor,
|
608 |
+
pre_norm=resnet_pre_norm,
|
609 |
+
)
|
610 |
+
)
|
611 |
+
|
612 |
+
self.attentions = nn.ModuleList(attentions)
|
613 |
+
self.resnets = nn.ModuleList(resnets)
|
614 |
+
|
615 |
+
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
616 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
617 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
618 |
+
if attn is not None:
|
619 |
+
B, C, T, H, W = hidden_states.shape
|
620 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
621 |
+
attention_mask = prepare_causal_attention_mask(
|
622 |
+
T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
|
623 |
+
)
|
624 |
+
hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
|
625 |
+
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
626 |
+
hidden_states = resnet(hidden_states, temb)
|
627 |
+
|
628 |
+
return hidden_states
|
629 |
+
|
630 |
+
|
631 |
+
class DownEncoderBlockCausal3D(nn.Module):
|
632 |
+
def __init__(
|
633 |
+
self,
|
634 |
+
in_channels: int,
|
635 |
+
out_channels: int,
|
636 |
+
dropout: float = 0.0,
|
637 |
+
num_layers: int = 1,
|
638 |
+
resnet_eps: float = 1e-6,
|
639 |
+
resnet_time_scale_shift: str = "default",
|
640 |
+
resnet_act_fn: str = "swish",
|
641 |
+
resnet_groups: int = 32,
|
642 |
+
resnet_pre_norm: bool = True,
|
643 |
+
output_scale_factor: float = 1.0,
|
644 |
+
add_downsample: bool = True,
|
645 |
+
downsample_stride: int = 2,
|
646 |
+
downsample_padding: int = 1,
|
647 |
+
):
|
648 |
+
super().__init__()
|
649 |
+
resnets = []
|
650 |
+
|
651 |
+
for i in range(num_layers):
|
652 |
+
in_channels = in_channels if i == 0 else out_channels
|
653 |
+
resnets.append(
|
654 |
+
ResnetBlockCausal3D(
|
655 |
+
in_channels=in_channels,
|
656 |
+
out_channels=out_channels,
|
657 |
+
temb_channels=None,
|
658 |
+
eps=resnet_eps,
|
659 |
+
groups=resnet_groups,
|
660 |
+
dropout=dropout,
|
661 |
+
time_embedding_norm=resnet_time_scale_shift,
|
662 |
+
non_linearity=resnet_act_fn,
|
663 |
+
output_scale_factor=output_scale_factor,
|
664 |
+
pre_norm=resnet_pre_norm,
|
665 |
+
)
|
666 |
+
)
|
667 |
+
|
668 |
+
self.resnets = nn.ModuleList(resnets)
|
669 |
+
|
670 |
+
if add_downsample:
|
671 |
+
self.downsamplers = nn.ModuleList(
|
672 |
+
[
|
673 |
+
DownsampleCausal3D(
|
674 |
+
out_channels,
|
675 |
+
use_conv=True,
|
676 |
+
out_channels=out_channels,
|
677 |
+
padding=downsample_padding,
|
678 |
+
name="op",
|
679 |
+
stride=downsample_stride,
|
680 |
+
)
|
681 |
+
]
|
682 |
+
)
|
683 |
+
else:
|
684 |
+
self.downsamplers = None
|
685 |
+
|
686 |
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
687 |
+
for resnet in self.resnets:
|
688 |
+
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
689 |
+
|
690 |
+
if self.downsamplers is not None:
|
691 |
+
for downsampler in self.downsamplers:
|
692 |
+
hidden_states = downsampler(hidden_states, scale)
|
693 |
+
|
694 |
+
return hidden_states
|
695 |
+
|
696 |
+
|
697 |
+
class UpDecoderBlockCausal3D(nn.Module):
|
698 |
+
def __init__(
|
699 |
+
self,
|
700 |
+
in_channels: int,
|
701 |
+
out_channels: int,
|
702 |
+
resolution_idx: Optional[int] = None,
|
703 |
+
dropout: float = 0.0,
|
704 |
+
num_layers: int = 1,
|
705 |
+
resnet_eps: float = 1e-6,
|
706 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
707 |
+
resnet_act_fn: str = "swish",
|
708 |
+
resnet_groups: int = 32,
|
709 |
+
resnet_pre_norm: bool = True,
|
710 |
+
output_scale_factor: float = 1.0,
|
711 |
+
add_upsample: bool = True,
|
712 |
+
upsample_scale_factor=(2, 2, 2),
|
713 |
+
temb_channels: Optional[int] = None,
|
714 |
+
):
|
715 |
+
super().__init__()
|
716 |
+
resnets = []
|
717 |
+
|
718 |
+
for i in range(num_layers):
|
719 |
+
input_channels = in_channels if i == 0 else out_channels
|
720 |
+
|
721 |
+
resnets.append(
|
722 |
+
ResnetBlockCausal3D(
|
723 |
+
in_channels=input_channels,
|
724 |
+
out_channels=out_channels,
|
725 |
+
temb_channels=temb_channels,
|
726 |
+
eps=resnet_eps,
|
727 |
+
groups=resnet_groups,
|
728 |
+
dropout=dropout,
|
729 |
+
time_embedding_norm=resnet_time_scale_shift,
|
730 |
+
non_linearity=resnet_act_fn,
|
731 |
+
output_scale_factor=output_scale_factor,
|
732 |
+
pre_norm=resnet_pre_norm,
|
733 |
+
)
|
734 |
+
)
|
735 |
+
|
736 |
+
self.resnets = nn.ModuleList(resnets)
|
737 |
+
|
738 |
+
if add_upsample:
|
739 |
+
self.upsamplers = nn.ModuleList(
|
740 |
+
[
|
741 |
+
UpsampleCausal3D(
|
742 |
+
out_channels,
|
743 |
+
use_conv=True,
|
744 |
+
out_channels=out_channels,
|
745 |
+
upsample_factor=upsample_scale_factor,
|
746 |
+
)
|
747 |
+
]
|
748 |
+
)
|
749 |
+
else:
|
750 |
+
self.upsamplers = None
|
751 |
+
|
752 |
+
self.resolution_idx = resolution_idx
|
753 |
+
|
754 |
+
def forward(
|
755 |
+
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
756 |
+
) -> torch.FloatTensor:
|
757 |
+
for resnet in self.resnets:
|
758 |
+
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
759 |
+
|
760 |
+
if self.upsamplers is not None:
|
761 |
+
for upsampler in self.upsamplers:
|
762 |
+
hidden_states = upsampler(hidden_states)
|
763 |
+
|
764 |
+
return hidden_states
|
hyvideo/vae/vae.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
9 |
+
from diffusers.utils.torch_utils import randn_tensor
|
10 |
+
from diffusers.models.attention_processor import SpatialNorm
|
11 |
+
from .unet_causal_3d_blocks import (
|
12 |
+
CausalConv3d,
|
13 |
+
UNetMidBlockCausal3D,
|
14 |
+
get_down_block3d,
|
15 |
+
get_up_block3d,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class DecoderOutput(BaseOutput):
|
21 |
+
r"""
|
22 |
+
Output of decoding method.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
26 |
+
The decoded output sample from the last layer of the model.
|
27 |
+
"""
|
28 |
+
|
29 |
+
sample: torch.FloatTensor
|
30 |
+
|
31 |
+
|
32 |
+
class EncoderCausal3D(nn.Module):
|
33 |
+
r"""
|
34 |
+
The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
in_channels: int = 3,
|
40 |
+
out_channels: int = 3,
|
41 |
+
down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
|
42 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
43 |
+
layers_per_block: int = 2,
|
44 |
+
norm_num_groups: int = 32,
|
45 |
+
act_fn: str = "silu",
|
46 |
+
double_z: bool = True,
|
47 |
+
mid_block_add_attention=True,
|
48 |
+
time_compression_ratio: int = 4,
|
49 |
+
spatial_compression_ratio: int = 8,
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
self.layers_per_block = layers_per_block
|
53 |
+
|
54 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
55 |
+
self.mid_block = None
|
56 |
+
self.down_blocks = nn.ModuleList([])
|
57 |
+
|
58 |
+
# down
|
59 |
+
output_channel = block_out_channels[0]
|
60 |
+
for i, down_block_type in enumerate(down_block_types):
|
61 |
+
input_channel = output_channel
|
62 |
+
output_channel = block_out_channels[i]
|
63 |
+
is_final_block = i == len(block_out_channels) - 1
|
64 |
+
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
65 |
+
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
66 |
+
|
67 |
+
if time_compression_ratio == 4:
|
68 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
69 |
+
add_time_downsample = bool(
|
70 |
+
i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
|
71 |
+
and not is_final_block
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
|
75 |
+
|
76 |
+
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
77 |
+
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
78 |
+
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
79 |
+
down_block = get_down_block3d(
|
80 |
+
down_block_type,
|
81 |
+
num_layers=self.layers_per_block,
|
82 |
+
in_channels=input_channel,
|
83 |
+
out_channels=output_channel,
|
84 |
+
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
85 |
+
downsample_stride=downsample_stride,
|
86 |
+
resnet_eps=1e-6,
|
87 |
+
downsample_padding=0,
|
88 |
+
resnet_act_fn=act_fn,
|
89 |
+
resnet_groups=norm_num_groups,
|
90 |
+
attention_head_dim=output_channel,
|
91 |
+
temb_channels=None,
|
92 |
+
)
|
93 |
+
self.down_blocks.append(down_block)
|
94 |
+
|
95 |
+
# mid
|
96 |
+
self.mid_block = UNetMidBlockCausal3D(
|
97 |
+
in_channels=block_out_channels[-1],
|
98 |
+
resnet_eps=1e-6,
|
99 |
+
resnet_act_fn=act_fn,
|
100 |
+
output_scale_factor=1,
|
101 |
+
resnet_time_scale_shift="default",
|
102 |
+
attention_head_dim=block_out_channels[-1],
|
103 |
+
resnet_groups=norm_num_groups,
|
104 |
+
temb_channels=None,
|
105 |
+
add_attention=mid_block_add_attention,
|
106 |
+
)
|
107 |
+
|
108 |
+
# out
|
109 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
110 |
+
self.conv_act = nn.SiLU()
|
111 |
+
|
112 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
113 |
+
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
|
114 |
+
|
115 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
116 |
+
r"""The forward method of the `EncoderCausal3D` class."""
|
117 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
|
118 |
+
|
119 |
+
sample = self.conv_in(sample)
|
120 |
+
|
121 |
+
# down
|
122 |
+
for down_block in self.down_blocks:
|
123 |
+
sample = down_block(sample)
|
124 |
+
|
125 |
+
# middle
|
126 |
+
sample = self.mid_block(sample)
|
127 |
+
|
128 |
+
# post-process
|
129 |
+
sample = self.conv_norm_out(sample)
|
130 |
+
sample = self.conv_act(sample)
|
131 |
+
sample = self.conv_out(sample)
|
132 |
+
|
133 |
+
return sample
|
134 |
+
|
135 |
+
|
136 |
+
class DecoderCausal3D(nn.Module):
|
137 |
+
r"""
|
138 |
+
The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
in_channels: int = 3,
|
144 |
+
out_channels: int = 3,
|
145 |
+
up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
|
146 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
147 |
+
layers_per_block: int = 2,
|
148 |
+
norm_num_groups: int = 32,
|
149 |
+
act_fn: str = "silu",
|
150 |
+
norm_type: str = "group", # group, spatial
|
151 |
+
mid_block_add_attention=True,
|
152 |
+
time_compression_ratio: int = 4,
|
153 |
+
spatial_compression_ratio: int = 8,
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
self.layers_per_block = layers_per_block
|
157 |
+
|
158 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
159 |
+
self.mid_block = None
|
160 |
+
self.up_blocks = nn.ModuleList([])
|
161 |
+
|
162 |
+
temb_channels = in_channels if norm_type == "spatial" else None
|
163 |
+
|
164 |
+
# mid
|
165 |
+
self.mid_block = UNetMidBlockCausal3D(
|
166 |
+
in_channels=block_out_channels[-1],
|
167 |
+
resnet_eps=1e-6,
|
168 |
+
resnet_act_fn=act_fn,
|
169 |
+
output_scale_factor=1,
|
170 |
+
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
171 |
+
attention_head_dim=block_out_channels[-1],
|
172 |
+
resnet_groups=norm_num_groups,
|
173 |
+
temb_channels=temb_channels,
|
174 |
+
add_attention=mid_block_add_attention,
|
175 |
+
)
|
176 |
+
|
177 |
+
# up
|
178 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
179 |
+
output_channel = reversed_block_out_channels[0]
|
180 |
+
for i, up_block_type in enumerate(up_block_types):
|
181 |
+
prev_output_channel = output_channel
|
182 |
+
output_channel = reversed_block_out_channels[i]
|
183 |
+
is_final_block = i == len(block_out_channels) - 1
|
184 |
+
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
185 |
+
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
186 |
+
|
187 |
+
if time_compression_ratio == 4:
|
188 |
+
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
189 |
+
add_time_upsample = bool(
|
190 |
+
i >= len(block_out_channels) - 1 - num_time_upsample_layers
|
191 |
+
and not is_final_block
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
|
195 |
+
|
196 |
+
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
197 |
+
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
198 |
+
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
199 |
+
up_block = get_up_block3d(
|
200 |
+
up_block_type,
|
201 |
+
num_layers=self.layers_per_block + 1,
|
202 |
+
in_channels=prev_output_channel,
|
203 |
+
out_channels=output_channel,
|
204 |
+
prev_output_channel=None,
|
205 |
+
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
206 |
+
upsample_scale_factor=upsample_scale_factor,
|
207 |
+
resnet_eps=1e-6,
|
208 |
+
resnet_act_fn=act_fn,
|
209 |
+
resnet_groups=norm_num_groups,
|
210 |
+
attention_head_dim=output_channel,
|
211 |
+
temb_channels=temb_channels,
|
212 |
+
resnet_time_scale_shift=norm_type,
|
213 |
+
)
|
214 |
+
self.up_blocks.append(up_block)
|
215 |
+
prev_output_channel = output_channel
|
216 |
+
|
217 |
+
# out
|
218 |
+
if norm_type == "spatial":
|
219 |
+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
220 |
+
else:
|
221 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
222 |
+
self.conv_act = nn.SiLU()
|
223 |
+
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
224 |
+
|
225 |
+
self.gradient_checkpointing = False
|
226 |
+
|
227 |
+
def forward(
|
228 |
+
self,
|
229 |
+
sample: torch.FloatTensor,
|
230 |
+
latent_embeds: Optional[torch.FloatTensor] = None,
|
231 |
+
) -> torch.FloatTensor:
|
232 |
+
r"""The forward method of the `DecoderCausal3D` class."""
|
233 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
|
234 |
+
|
235 |
+
sample = self.conv_in(sample)
|
236 |
+
|
237 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
238 |
+
if self.training and self.gradient_checkpointing:
|
239 |
+
|
240 |
+
def create_custom_forward(module):
|
241 |
+
def custom_forward(*inputs):
|
242 |
+
return module(*inputs)
|
243 |
+
|
244 |
+
return custom_forward
|
245 |
+
|
246 |
+
if is_torch_version(">=", "1.11.0"):
|
247 |
+
# middle
|
248 |
+
sample = torch.utils.checkpoint.checkpoint(
|
249 |
+
create_custom_forward(self.mid_block),
|
250 |
+
sample,
|
251 |
+
latent_embeds,
|
252 |
+
use_reentrant=False,
|
253 |
+
)
|
254 |
+
sample = sample.to(upscale_dtype)
|
255 |
+
|
256 |
+
# up
|
257 |
+
for up_block in self.up_blocks:
|
258 |
+
sample = torch.utils.checkpoint.checkpoint(
|
259 |
+
create_custom_forward(up_block),
|
260 |
+
sample,
|
261 |
+
latent_embeds,
|
262 |
+
use_reentrant=False,
|
263 |
+
)
|
264 |
+
else:
|
265 |
+
# middle
|
266 |
+
sample = torch.utils.checkpoint.checkpoint(
|
267 |
+
create_custom_forward(self.mid_block), sample, latent_embeds
|
268 |
+
)
|
269 |
+
sample = sample.to(upscale_dtype)
|
270 |
+
|
271 |
+
# up
|
272 |
+
for up_block in self.up_blocks:
|
273 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
274 |
+
else:
|
275 |
+
# middle
|
276 |
+
sample = self.mid_block(sample, latent_embeds)
|
277 |
+
sample = sample.to(upscale_dtype)
|
278 |
+
|
279 |
+
# up
|
280 |
+
for up_block in self.up_blocks:
|
281 |
+
sample = up_block(sample, latent_embeds)
|
282 |
+
|
283 |
+
# post-process
|
284 |
+
if latent_embeds is None:
|
285 |
+
sample = self.conv_norm_out(sample)
|
286 |
+
else:
|
287 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
288 |
+
sample = self.conv_act(sample)
|
289 |
+
sample = self.conv_out(sample)
|
290 |
+
|
291 |
+
return sample
|
292 |
+
|
293 |
+
|
294 |
+
class DiagonalGaussianDistribution(object):
|
295 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
296 |
+
if parameters.ndim == 3:
|
297 |
+
dim = 2 # (B, L, C)
|
298 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
299 |
+
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
|
300 |
+
else:
|
301 |
+
raise NotImplementedError
|
302 |
+
self.parameters = parameters
|
303 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
304 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
305 |
+
self.deterministic = deterministic
|
306 |
+
self.std = torch.exp(0.5 * self.logvar)
|
307 |
+
self.var = torch.exp(self.logvar)
|
308 |
+
if self.deterministic:
|
309 |
+
self.var = self.std = torch.zeros_like(
|
310 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
311 |
+
)
|
312 |
+
|
313 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
314 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
315 |
+
sample = randn_tensor(
|
316 |
+
self.mean.shape,
|
317 |
+
generator=generator,
|
318 |
+
device=self.parameters.device,
|
319 |
+
dtype=self.parameters.dtype,
|
320 |
+
)
|
321 |
+
x = self.mean + self.std * sample
|
322 |
+
return x
|
323 |
+
|
324 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
325 |
+
if self.deterministic:
|
326 |
+
return torch.Tensor([0.0])
|
327 |
+
else:
|
328 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
329 |
+
if other is None:
|
330 |
+
return 0.5 * torch.sum(
|
331 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
332 |
+
dim=reduce_dim,
|
333 |
+
)
|
334 |
+
else:
|
335 |
+
return 0.5 * torch.sum(
|
336 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
337 |
+
+ self.var / other.var
|
338 |
+
- 1.0
|
339 |
+
- self.logvar
|
340 |
+
+ other.logvar,
|
341 |
+
dim=reduce_dim,
|
342 |
+
)
|
343 |
+
|
344 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
345 |
+
if self.deterministic:
|
346 |
+
return torch.Tensor([0.0])
|
347 |
+
logtwopi = np.log(2.0 * np.pi)
|
348 |
+
return 0.5 * torch.sum(
|
349 |
+
logtwopi + self.logvar +
|
350 |
+
torch.pow(sample - self.mean, 2) / self.var,
|
351 |
+
dim=dims,
|
352 |
+
)
|
353 |
+
|
354 |
+
def mode(self) -> torch.Tensor:
|
355 |
+
return self.mean
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.5.1
|
2 |
+
torchvision==0.20.1
|
3 |
+
torchaudio==2.5.1
|
4 |
+
opencv-python==4.9.0.80
|
5 |
+
diffusers==0.30.2
|
6 |
+
transformers==4.39.3
|
7 |
+
tokenizers==0.15.2
|
8 |
+
accelerate==1.1.1
|
9 |
+
pandas==2.0.3
|
10 |
+
numpy==1.24.4
|
11 |
+
einops==0.7.0
|
12 |
+
tqdm==4.66.2
|
13 |
+
loguru==0.7.2
|
14 |
+
imageio==2.34.0
|
15 |
+
imageio-ffmpeg==0.5.1
|
16 |
+
safetensors==0.4.3
|
17 |
+
mmgp==3.0.3
|
18 |
+
gradio==5.8.0
|
19 |
+
moviepy==1.0.3
|
20 |
+
#flash-attn==2.7.2.post1
|
21 |
+
#sageattention==1.0.6
|
requirements_xdit.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python==4.9.0.80
|
2 |
+
diffusers==0.31.0
|
3 |
+
transformers==4.46.3
|
4 |
+
tokenizers==0.20.3
|
5 |
+
accelerate==1.1.1
|
6 |
+
pandas==2.0.3
|
7 |
+
numpy==1.24.4
|
8 |
+
einops==0.7.0
|
9 |
+
tqdm==4.66.2
|
10 |
+
loguru==0.7.2
|
11 |
+
imageio==2.34.0
|
12 |
+
imageio-ffmpeg==0.5.1
|
13 |
+
safetensors==0.4.3
|
14 |
+
ninja
|
15 |
+
flash-attn==2.6.3
|
16 |
+
xfuser==0.4.0
|
sample_video.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
from loguru import logger
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
from hyvideo.utils.file_utils import save_videos_grid
|
8 |
+
from hyvideo.config import parse_args
|
9 |
+
from hyvideo.inference import HunyuanVideoSampler
|
10 |
+
|
11 |
+
|
12 |
+
def main():
|
13 |
+
args = parse_args()
|
14 |
+
print(args)
|
15 |
+
models_root_path = Path(args.model_base)
|
16 |
+
if not models_root_path.exists():
|
17 |
+
raise ValueError(f"`models_root` not exists: {models_root_path}")
|
18 |
+
|
19 |
+
# Create save folder to save the samples
|
20 |
+
save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
|
21 |
+
if not os.path.exists(args.save_path):
|
22 |
+
os.makedirs(save_path, exist_ok=True)
|
23 |
+
|
24 |
+
models_root_path = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors"
|
25 |
+
text_encoder_filename = "ckpts/hunyuan-video-t2v-720p/text_encoder/llava-llama-3-8b_fp16.safetensors"
|
26 |
+
import json
|
27 |
+
with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "r", encoding="utf-8") as reader:
|
28 |
+
text = reader.read()
|
29 |
+
vae_config= json.loads(text)
|
30 |
+
# reduce time window used by the VAE for temporal splitting (former time windows is too large for 24 GB)
|
31 |
+
if vae_config["sample_tsize"] == 64:
|
32 |
+
vae_config["sample_tsize"] = 32
|
33 |
+
with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "w", encoding="utf-8") as writer:
|
34 |
+
writer.write(json.dumps(vae_config))
|
35 |
+
|
36 |
+
# Load models
|
37 |
+
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path,text_encoder_filename, args=args)
|
38 |
+
|
39 |
+
from mmgp import offload, profile_type
|
40 |
+
pipe = hunyuan_video_sampler.pipeline
|
41 |
+
offload.profile(pipe, profile_no= profile_type.HighRAM_LowVRAM_Fast)
|
42 |
+
|
43 |
+
# Get the updated args
|
44 |
+
args = hunyuan_video_sampler.args
|
45 |
+
|
46 |
+
# Start sampling
|
47 |
+
# TODO: batch inference check
|
48 |
+
outputs = hunyuan_video_sampler.predict(
|
49 |
+
prompt=args.prompt,
|
50 |
+
height=args.video_size[0],
|
51 |
+
width=args.video_size[1],
|
52 |
+
video_length=args.video_length,
|
53 |
+
seed=args.seed,
|
54 |
+
negative_prompt=args.neg_prompt,
|
55 |
+
infer_steps=args.infer_steps,
|
56 |
+
guidance_scale=args.cfg_scale,
|
57 |
+
num_videos_per_prompt=args.num_videos,
|
58 |
+
flow_shift=args.flow_shift,
|
59 |
+
batch_size=args.batch_size,
|
60 |
+
embedded_guidance_scale=args.embedded_cfg_scale
|
61 |
+
)
|
62 |
+
samples = outputs['samples']
|
63 |
+
|
64 |
+
# Save samples
|
65 |
+
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
|
66 |
+
for i, sample in enumerate(samples):
|
67 |
+
sample = samples[i].unsqueeze(0)
|
68 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
|
69 |
+
save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
|
70 |
+
save_videos_grid(sample, save_path, fps=24)
|
71 |
+
logger.info(f'Sample save to: {save_path}')
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
main()
|