Rodneyontherock1067 committed on
Commit 42d94eb · verified · 1 Parent(s): ccac407

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +2 -0
  3. LICENSE.txt +77 -0
  4. Notice +233 -0
  5. README.md +172 -12
  6. README_zh.md +425 -0
  7. assets/3dvae.png +0 -0
  8. assets/WECHAT.md +7 -0
  9. assets/backbone.png +3 -0
  10. assets/hunyuanvideo.pdf +3 -0
  11. assets/logo.png +0 -0
  12. assets/overall.png +3 -0
  13. assets/text_encoder.png +3 -0
  14. assets/wechat.jpg +0 -0
  15. docker/Dockerfile_xDiT +41 -0
  16. environment.yml +8 -0
  17. gradio_server.py +376 -0
  18. hyvideo/__init__.py +0 -0
  19. hyvideo/config.py +406 -0
  20. hyvideo/constants.py +90 -0
  21. hyvideo/diffusion/__init__.py +2 -0
  22. hyvideo/diffusion/pipelines/__init__.py +1 -0
  23. hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +1103 -0
  24. hyvideo/diffusion/schedulers/__init__.py +1 -0
  25. hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py +257 -0
  26. hyvideo/inference.py +682 -0
  27. hyvideo/modules/__init__.py +26 -0
  28. hyvideo/modules/activation_layers.py +23 -0
  29. hyvideo/modules/attenion.py +257 -0
  30. hyvideo/modules/embed_layers.py +157 -0
  31. hyvideo/modules/mlp_layers.py +118 -0
  32. hyvideo/modules/models.py +870 -0
  33. hyvideo/modules/modulate_layers.py +76 -0
  34. hyvideo/modules/norm_layers.py +77 -0
  35. hyvideo/modules/posemb_layers.py +310 -0
  36. hyvideo/modules/token_refiner.py +236 -0
  37. hyvideo/prompt_rewrite.py +51 -0
  38. hyvideo/text_encoder/__init__.py +366 -0
  39. hyvideo/utils/__init__.py +0 -0
  40. hyvideo/utils/data_utils.py +15 -0
  41. hyvideo/utils/file_utils.py +70 -0
  42. hyvideo/utils/helpers.py +40 -0
  43. hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py +46 -0
  44. hyvideo/vae/__init__.py +62 -0
  45. hyvideo/vae/autoencoder_kl_causal_3d.py +603 -0
  46. hyvideo/vae/unet_causal_3d_blocks.py +764 -0
  47. hyvideo/vae/vae.py +355 -0
  48. requirements.txt +21 -0
  49. requirements_xdit.txt +16 -0
  50. sample_video.py +74 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/backbone.png filter=lfs diff=lfs merge=lfs -text
+ assets/hunyuanvideo.pdf filter=lfs diff=lfs merge=lfs -text
+ assets/overall.png filter=lfs diff=lfs merge=lfs -text
+ assets/text_encoder.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ /ckpts/**/
LICENSE.txt ADDED
@@ -0,0 +1,77 @@
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
+ Tencent HunyuanVideo Release Date: December 3, 2024
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
+ 1. DEFINITIONS.
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
+ i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo released at [https://github.com/Tencent/HunyuanVideo].
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
+ n. “including” shall mean including but not limited to.
+ 2. GRANT OF RIGHTS.
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
+ 3. DISTRIBUTION.
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
+ 4. ADDITIONAL COMMERCIAL TERMS.
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
+ 5. RULES OF USE.
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
+ 6. INTELLECTUAL PROPERTY.
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+ 8. SURVIVAL AND TERMINATION.
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
+ 9. GOVERNING LAW AND JURISDICTION.
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
+
+ EXHIBIT A
+ ACCEPTABLE USE POLICY
+
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
+ Last modified: November 5, 2024
+
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
+ 1. Outside the Territory;
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
+ 3. To harm Yourself or others;
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
+ 9. To intentionally defame, disparage or otherwise harass others;
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
+ 13. To impersonate another individual without consent, authorization, or legal right;
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+ 19. For military purposes;
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
Notice ADDED
@@ -0,0 +1,233 @@
+ Usage and Legal Notices:
+
+ Tencent is pleased to support the open source community by making Tencent HunyuanVideo available.
+
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ Tencent HunyuanVideo is licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT except for the third-party components listed below. Tencent HunyuanVideo does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
+
+ For avoidance of doubts, Tencent HunyuanVideo means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing may be made publicly available by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+
+ Other dependencies and licenses:
+
+
+ Open Source Model Licensed under the Apache License Version 2.0:
+ The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+ --------------------------------------------------------------------
+ 1. diffusers
+ Copyright (c) diffusers original author and authors
+ Please note this software has been modified by Tencent in this distribution.
+
+ 2. transformers
+ Copyright (c) transformers original author and authors
+
+ 3. safetensors
+ Copyright (c) safetensors original author and authors
+
+ 4. flux
+ Copyright (c) flux original author and authors
+
+
+ Terms of the Apache License Version 2.0:
+ --------------------------------------------------------------------
+ Apache License
+
+ Version 2.0, January 2004
+
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
+
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
+
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+
+
+ Open Source Software Licensed under the BSD 2-Clause License:
+ --------------------------------------------------------------------
+ 1. imageio
+ Copyright (c) 2014-2022, imageio developers
+ All rights reserved.
+
+
+ Terms of the BSD 2-Clause License:
+ --------------------------------------------------------------------
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+ Open Source Software Licensed under the BSD 3-Clause License:
+ --------------------------------------------------------------------
+ 1. torchvision
+ Copyright (c) Soumith Chintala 2016,
+ All rights reserved.
+
+ 2. flash-attn
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+ All rights reserved.
+
+
+ Terms of the BSD 3-Clause License:
+ --------------------------------------------------------------------
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. torch
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+
+ A copy of the BSD 3-Clause is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/pytorch/pytorch/tree/v2.1.1/third_party
+
+
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. pandas
+ Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team
+ All rights reserved.
+
+ Copyright (c) 2011-2023, Open source contributors.
+
+
+ A copy of the BSD 3-Clause is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/pandas-dev/pandas/tree/v2.0.3/LICENSES
+
+
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. numpy
+ Copyright (c) 2005-2022, NumPy Developers.
+ All rights reserved.
+
+
+ A copy of the BSD 3-Clause is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/numpy/numpy/blob/v1.24.4/LICENSES_bundled.txt
+
+
+ Open Source Software Licensed under the MIT License:
+ --------------------------------------------------------------------
+ 1. einops
+ Copyright (c) 2018 Alex Rogozhnikov
+
+ 2. loguru
+ Copyright (c) 2017
+
+
+ Terms of the MIT License:
+ --------------------------------------------------------------------
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+
+ Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. tqdm
+ Copyright (c) 2013 noamraph
+
+
+ A copy of the MIT is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/tqdm/tqdm/blob/v4.66.2/LICENCE
+
+
+
+ Open Source Model Licensed under the MIT License:
+ --------------------------------------------------------------------
+ 1. clip-large
+ Copyright (c) 2021 OpenAI
+
+
+ A copy of the MIT is included in this file.
+
+
+ --------------------------------------------------------------------
+ We may also use other third-party components:
+
+ 1. llava-llama3
+
+ Copyright (c) llava-llama3 original author and authors
+
+ URL: https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers#model
README.md CHANGED
@@ -1,12 +1,172 @@
- ---
- title: Hvgp
- emoji: 🐠
- colorFrom: yellow
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.10.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <!-- ## **HunyuanVideo** -->
+
+ [Read in Chinese](./README_zh.md)
+
+
+ # HunyuanVideo: A Systematic Framework For Large Video Generation Model
+ <div align="center">
+ <a href="https://github.com/Tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
+ <a href="https://aivideo.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Web&color=green&logo=github-pages"></a> &ensp;
+ <a href="https://video.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Playground&message=Web&color=green&logo=github-pages"></a> &ensp;
+ <a href="https://arxiv.org/abs/2412.03603"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv:HunyuanVideo&color=red&logo=arxiv"></a> &ensp;
+ <a href="https://huggingface.co/tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo&message=HuggingFace&color=yellow"></a> &ensp; &ensp;
+ <a href="https://huggingface.co/tencent/HunyuanVideo-PromptRewrite"><img src="https://img.shields.io/static/v1?label=HunyuanVideo-PromptRewrite&message=HuggingFace&color=yellow"></a> &ensp; &ensp;
+
+ [![Replicate](https://replicate.com/zsxkib/hunyuan-video/badge)](https://replicate.com/zsxkib/hunyuan-video)
+ </div>
+
+ 06/04/2025: Version 2.1 Integrated TeaCache (https://github.com/ali-vilab/TeaCache) for even faster generations\
+ 01/04/2025: Version 2.0 Full leverage of mmgp 3.0 (faster and even lower RAM requirements! + support for compilation on Linux and WSL)\
+ 22/12/2024: Version 1.0 First release\
+
+ *GPU Poor version by **DeepBeepMeep**. This great video generator can now run smoothly on a 12 GB to 24 GB GPU.*
+
+ This version has the following improvements over the original Hunyuan Video model:
+ - Greatly reduced RAM and VRAM requirements
+ - Much faster generation thanks to compilation and fast loading / unloading
+ - Five profiles that let you run the model at a decent speed on a low-end consumer config (32 GB of RAM and 12 GB of VRAM) and at a very good speed on a high-end consumer config (48 GB of RAM and 24 GB of VRAM)
+ - Automatic downloading of the required model files
+ - Improved Gradio interface with a progress bar and more options
+ - Easy switching between the Hunyuan and Fast Hunyuan models, and between quantized and non-quantized weights
+ - Much simpler installation
+
+ This fork by DeepBeepMeep integrates the mmgp module into gradio_server.py.
+
+ It illustrates how fast, properly working CPU offloading can be set up on an existing model by changing only a few lines of code in the core model.
+
+ For more information on how to use the mmgp module, please go to: https://github.com/deepbeepmeep/mmgp
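+
+ As a rough, hedged sketch of what such an integration looks like (illustrative only: `load_hunyuan_video_pipeline` is a made-up placeholder standing in for the pipeline construction done in gradio_server.py, and the `offload.profile` call follows the pattern documented in the mmgp repository linked above):
+
+ ```python
+ from mmgp import offload, profile_type  # pip install mmgp
+
+ # Placeholder for building the pipeline the way gradio_server.py does;
+ # only the mmgp call below is the point of this sketch.
+ pipe = load_hunyuan_video_pipeline()  # hypothetical, not a real function
+
+ # One call wires CPU offloading into the existing model, mirroring the
+ # RAM/VRAM profiles described below (here: profile 2, HighRAM_LowVRAM).
+ offload.profile(pipe, profile_type.HighRAM_LowVRAM)
+ ```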
+
+ You will find the original Hunyuan Video repository here: https://github.com/Tencent/HunyuanVideo
+
+ ## **Abstract**
+ We present HunyuanVideo, a novel open-source video foundation model that exhibits performance in video generation comparable to, if not superior to, leading closed-source models. To train the HunyuanVideo model, we adopt several key technologies for model learning, including data curation, image-video joint model training, and an efficient infrastructure designed to facilitate large-scale model training and inference. Additionally, through an effective strategy for scaling the model architecture and dataset, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models.
+
+ We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion diversity, text-video alignment, and generation stability. According to professional human evaluation results, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and 3 top-performing Chinese video generative models. By releasing the code and weights of the foundation model and its applications, we aim to bridge the gap between closed-source and open-source video foundation models. This initiative will empower everyone in the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem.
+
+
+ ### Installation Guide for Linux and Windows
+
+ We provide an `environment.yml` file for setting up a Conda environment.
+ Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
+
+ We recommend CUDA version 12.4 or 11.8 for the manual installation.
+
+ ```shell
+ # 1. Prepare the conda environment
+ conda env create -f environment.yml
+
+ # 2. Activate the environment
+ conda activate HunyuanVideo
+
+ # 3. Install pip dependencies
+ python -m pip install -r requirements.txt
+
+ # 4.1 Optional: Flash attention support (easy to install on Linux but much harder on Windows)
+ python -m pip install flash-attn==2.7.2.post1
+
+ # 4.2 Optional: Sage attention support (30% faster; easy to install on Linux but much harder on Windows)
+ python -m pip install sageattention==1.0.6
+ ```
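+
+ If you are unsure whether the optional attention packages were picked up, a quick probe from Python (illustrative, not part of this repo) tells you:
+
+ ```python
+ import importlib.util
+
+ # Check which optional attention backends are importable in this environment.
+ for module in ("flash_attn", "sageattention"):
+     found = importlib.util.find_spec(module) is not None
+     print(f"{module}: {'available' if found else 'not installed'}")
+ ```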
+
+ ### Profiles
+ You can choose between five profiles depending on your hardware:
+ - HighRAM_HighVRAM (1): at least 48 GB of RAM and 24 GB of VRAM: the fastest profile, well suited for an RTX 3090 / RTX 4090, but it consumes much more VRAM; adapted for fast, shorter videos
+ - HighRAM_LowVRAM (2): at least 48 GB of RAM and 12 GB of VRAM: a bit slower; better suited for an RTX 3070/3080/4070/4080, or for an RTX 3090 / RTX 4090 with large picture batches or long videos
+ - LowRAM_HighVRAM (3): at least 32 GB of RAM and 24 GB of VRAM: adapted for an RTX 3090 / RTX 4090 with limited RAM, but at the cost of VRAM (shorter videos)
+ - LowRAM_LowVRAM (4): at least 32 GB of RAM and 12 GB of VRAM: if you have little VRAM or want to generate longer videos
+ - VerylowRAM_LowVRAM (5): at least 24 GB of RAM and 10 GB of VRAM: if you don't have much of either, it won't be fast, but it may work
+
+ Profiles 2 (HighRAM_LowVRAM) and 4 (LowRAM_LowVRAM) are the most recommended, since they are versatile (they support long videos at a slight performance cost).\
+ However, a safe approach is to start from profile 5 (the default) and then move progressively to profile 4 and then to profile 2, as long as the app remains responsive and doesn't trigger any out-of-memory error.
+
+
+ ### Run a Gradio Server on port 7860 (recommended)
+ ```bash
+ python3 gradio_server.py
+ ```
+
+ You can configure a RAM / VRAM profile by expanding the *Video Engine Configuration* section in the web interface.\
+ If you have chosen a configuration not supported by your system by mistake, you can force the safe profile 5 while loading the app:
+ ```bash
+ python3 gradio_server.py --profile 5
+ ```
+
+
+ ### Run through the command line
+ ```bash
+ cd HunyuanVideo
+
+ python3 sample_video.py \
+ --video-size 720 1280 \
+ --video-length 129 \
+ --infer-steps 50 \
+ --prompt "A cat walks on the grass, realistic style." \
+ --flow-reverse \
+ --save-path ./results
+ ```
+
+ Please note that, for now, the profile and the models to use need to be set inside the *sample_video.py* file itself.
+
+ ### More Configurations
+
+ We list some more useful configurations for easy usage:
+
+ | Argument | Default | Description |
+ |:----------------------:|:---------:|:-----------------------------------------:|
+ | `--prompt` | None | The text prompt for video generation |
+ | `--video-size` | 720 1280 | The size of the generated video |
+ | `--video-length` | 129 | The length of the generated video (in frames) |
+ | `--infer-steps` | 50 | The number of sampling steps |
+ | `--embedded-cfg-scale` | 6.0 | Embedded classifier-free guidance scale |
+ | `--flow-shift` | 7.0 | Shift factor for flow matching schedulers |
+ | `--flow-reverse` | False | If reversed, learning/sampling proceeds from t=1 -> t=0 |
+ | `--seed` | None | The random seed for video generation; if None, a random seed is initialized |
+ | `--use-cpu-offload` | False | Use CPU offloading when loading the model to save memory; necessary for high-resolution video generation |
+ | `--save-path` | ./results | Path to save the generated video |
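+
+ For context, the `--flow-shift` factor warps the sampling timesteps toward the high-noise end. Below is a hedged sketch of the time-shift rule commonly used by flow-matching schedulers of this family; `hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py` is the authoritative implementation:
+
+ ```python
+ import numpy as np
+
+ # Hedged sketch of the usual flow-matching time shift (not copied from the repo).
+ shift = 7.0                                 # --flow-shift
+ t = np.linspace(1.0, 0.0, num=50)           # 50 inference steps, t: 1 -> 0
+ t_shifted = shift * t / (1.0 + (shift - 1.0) * t)
+ print(t_shifted[:3])                        # early steps cluster near t = 1
+ ```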
+
+
+ ## 🔗 BibTeX
+ If you find [HunyuanVideo](https://arxiv.org/abs/2412.03603) useful for your research and applications, please cite using this BibTeX:
+
+ ```BibTeX
+ @misc{kong2024hunyuanvideo,
+ title={HunyuanVideo: A Systematic Framework For Large Video Generative Models},
+ author={Weijie Kong, Qi Tian, Zijian Zhang, Rox Min, Zuozhuo Dai, Jin Zhou, Jiangfeng Xiong, Xin Li, Bo Wu, Jianwei Zhang, Kathrina Wu, Qin Lin, Aladdin Wang, Andong Wang, Changlin Li, Duojun Huang, Fang Yang, Hao Tan, Hongmei Wang, Jacob Song, Jiawang Bai, Jianbing Wu, Jinbao Xue, Joey Wang, Junkun Yuan, Kai Wang, Mengyang Liu, Pengyu Li, Shuai Li, Weiyan Wang, Wenqing Yu, Xinchi Deng, Yang Li, Yanxin Long, Yi Chen, Yutao Cui, Yuanbo Peng, Zhentao Yu, Zhiyu He, Zhiyong Xu, Zixiang Zhou, Zunnan Xu, Yangyu Tao, Qinglin Lu, Songtao Liu, Dax Zhou, Hongfa Wang, Yong Yang, Di Wang, Yuhong Liu, and Jie Jiang, along with Caesar Zhong},
+ year={2024},
+ eprint={2412.03603},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2412.03603},
+ }
+ ```
+
+
+ ## 🧩 Projects that use HunyuanVideo
+
+ If you develop or use HunyuanVideo in your projects, please let us know.
+
+ - ComfyUI (with support for F8 inference and video-to-video generation): [ComfyUI-HunyuanVideoWrapper](https://github.com/kijai/ComfyUI-HunyuanVideoWrapper) by [Kijai](https://github.com/kijai)
+
+
+ ## Acknowledgements
+
+ We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories for their open research and exploration.
+ We also thank the Tencent Hunyuan Multimodal team for their help with the text encoder.
+
+ ## Star History
+ <a href="https://star-history.com/#Tencent/HunyuanVideo&Date">
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date&theme=dark" />
+ <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
+ <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
+ </picture>
+ </a>
README_zh.md ADDED
@@ -0,0 +1,425 @@
+ <!-- ## **HunyuanVideo** -->
+
+ [English](./README.md)
+
+ <p align="center">
+ <img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/logo.png" height=100>
+ </p>
+
+ # HunyuanVideo: A Systematic Framework For Large Video Generation Model
+
+ <div align="center">
+ <a href="https://github.com/Tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
+ <a href="https://aivideo.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Web&color=green&logo=github-pages"></a> &ensp;
+ <a href="https://video.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Playground&message=Web&color=green&logo=github-pages"></a> &ensp;
+ <a href="https://arxiv.org/abs/2412.03603"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv:HunyuanVideo&color=red&logo=arxiv"></a> &ensp;
+ <a href="https://huggingface.co/tencent/HunyuanVideo"><img src="https://img.shields.io/static/v1?label=HunyuanVideo&message=HuggingFace&color=yellow"></a> &ensp; &ensp;
+ <a href="https://huggingface.co/tencent/HunyuanVideo-PromptRewrite"><img src="https://img.shields.io/static/v1?label=HunyuanVideo-PromptRewrite&message=HuggingFace&color=yellow"></a> &ensp; &ensp;
+
+ [![Replicate](https://replicate.com/zsxkib/hunyuan-video/badge)](https://replicate.com/zsxkib/hunyuan-video)
+ </div>
+ <p align="center">
+ 👋 Join our <a href="assets/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/GpARqvrh" target="_blank">Discord</a>
+ </p>
+
+
+
+ -----
+
+ This repository contains the PyTorch model definitions, pre-trained weights, and inference/sampling code for the HunyuanVideo project. See our [project page](https://aivideo.hunyuan.tencent.com) for more.
+
+ > [**HunyuanVideo: A Systematic Framework For Large Video Generation Model**](https://arxiv.org/abs/2412.03603) <br>
+
+ ## 🎥 Demo
+ <div align="center">
+ <video src="https://github.com/user-attachments/assets/f37925a3-7d42-40c9-8a9b-5a010c7198e2" width="50%">
+ </div>
+
+ Note: due to GitHub's policy restrictions, the video above is heavily compressed. A high-quality version can be downloaded [here](https://aivideo.hunyuan.tencent.com/download/HunyuanVideo/material).
+
+ ## 🔥🔥🔥 Updates!!
+ * December 3, 2024: 🚀 Released the multi-GPU parallel inference code for HunyuanVideo, powered by [xDiT](https://github.com/xdit-project/xDiT).
+ * December 3, 2024: 🤗 Released the inference code and model weights for HunyuanVideo text-to-video generation.
+
+ ## 📑 Open-Source Plan
+
+ - HunyuanVideo (text-to-video model)
+ - [x] Inference code
+ - [x] Model weights
+ - [x] Multi-GPU sequence-parallel inference (the more GPUs, the faster the inference)
+ - [x] Web Demo (Gradio)
+ - [ ] Penguin Video benchmark
+ - [ ] ComfyUI
+ - [ ] Diffusers
+ - [ ] Multi-GPU PipeFusion parallel inference (lower VRAM requirements)
+ - HunyuanVideo (image-to-video model)
+ - [ ] Inference code
+ - [ ] Model weights
+
+ ## Table of Contents
+ - [HunyuanVideo: A Systematic Framework For Large Video Generation Model](#hunyuanvideo-a-systematic-framework-for-large-video-generation-model)
+ - [🎥 Demo](#-demo)
+ - [🔥🔥🔥 Updates!!](#-updates)
+ - [📑 Open-Source Plan](#-open-source-plan)
+ - [Table of Contents](#table-of-contents)
+ - [**Abstract**](#abstract)
+ - [**HunyuanVideo Architecture**](#hunyuanvideo-architecture)
+ - [🎉 **Highlights**](#-highlights)
+ - [**Unified image and video generation architecture**](#unified-image-and-video-generation-architecture)
+ - [**MLLM text encoder**](#mllm-text-encoder)
+ - [**3D VAE**](#3d-vae)
+ - [**Prompt rewriting**](#prompt-rewriting)
+ - [📈 Evaluation](#-evaluation)
+ - [📜 Requirements](#-requirements)
+ - [🛠️ Dependencies and Installation](#️-dependencies-and-installation)
+ - [Installation guide for Linux](#installation-guide-for-linux)
+ - [🧱 Download Pretrained Models](#-download-pretrained-models)
+ - [🔑 Inference](#-inference)
+ - [Using the command line](#using-the-command-line)
+ - [Running a Gradio server](#running-a-gradio-server)
+ - [More configurations](#more-configurations)
+ - [🚀 Parallel inference on multiple GPUs with xDiT](#-parallel-inference-on-multiple-gpus-with-xdit)
+ - [Installing xDiT-compatible dependencies](#installing-xdit-compatible-dependencies)
+ - [Using the command line](#using-the-command-line-1)
+ - [🔗 BibTeX](#-bibtex)
+ - [🧩 Projects that use HunyuanVideo](#-projects-that-use-hunyuanvideo)
+ - [Acknowledgements](#acknowledgements)
+ - [Star History](#star-history)
+ ---
+
+ ## **Abstract**
+ HunyuanVideo is a novel open-source video generation foundation model whose generation quality is comparable to, or better than, that of leading closed-source models. To train HunyuanVideo, we adopt a comprehensive framework that integrates data curation, joint image-video model training, and an efficient infrastructure to support large-scale model training and inference. Furthermore, through an effective strategy for scaling the model architecture and dataset, we successfully trained a video generation model with over 13 billion parameters, making it one of the largest open-source video generation models.
+
+ We ran extensive experiments on the model design to ensure high visual quality, diverse motion, text-video alignment, and generation stability. According to evaluations by professional human raters, HunyuanVideo outperforms previous state-of-the-art models on the overall metric, including Runway Gen-3, Luma 1.6, and the 3 best-performing Chinese video generation models. **By open-sourcing the code and weights of the foundation model and its application models, we aim to close the gap between closed-source and open-source video foundation models and help everyone in the community experiment with their own ideas, fostering a more dynamic and vibrant video generation ecosystem.**
+
+
+ ## **HunyuanVideo Architecture**
+
+ HunyuanVideo is a latent-space model. During training, a 3D VAE compresses features along the temporal and spatial dimensions. The text prompt is encoded by a large language model and fed to the model as a condition; guided by it, the model denoises Gaussian noise over multiple steps to produce a latent representation of a video. At inference time, the 3D VAE decoder finally decodes this latent representation back into a video.
+ <p align="center">
+ <img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/overall.png" height=300>
+ </p>
+
+ ## 🎉 **Highlights**
+ ### **Unified image and video generation architecture**
+
+ HunyuanVideo adopts a Transformer design with full attention for video generation. Specifically, we use a "dual-stream to single-stream" hybrid model design. In the dual-stream phase, video and text tokens are processed independently by parallel Transformer blocks, so each modality can learn its own modulation mechanism without interfering with the other. In the single-stream phase, we concatenate the video and text tokens and feed them into subsequent Transformer blocks for effective multimodal information fusion. This design captures the complex interactions between visual and semantic information and improves overall model performance.
+ <p align="center">
+ <img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/backbone.png" height=350>
+ </p>
110
+
111
+ ### **MLLM Text Encoder**
112
+ Previous video generation models commonly use pre-trained CLIP and T5-XXL as text encoders, where CLIP uses a Transformer encoder and T5 an encoder-decoder structure. HunyuanVideo instead uses a pre-trained Multimodal Large Language Model (MLLM) as its text encoder, which has the following advantages:
113
+ * Compared with T5, an MLLM instruction-tuned on image-text data has better image-text alignment in feature space, which eases the alignment burden on the diffusion model;
114
+ * Compared with CLIP, an MLLM is stronger at detailed image description and complex reasoning;
115
+ * An MLLM can act as a zero-shot learner by following system instructions, helping the text features focus on key information.
116
+
117
+ However, MLLMs are based on causal attention, whereas T5-XXL uses bidirectional attention, which provides better text guidance for diffusion models. We therefore introduce an extra token refiner to enhance the text features.
118
+ <p align="center">
119
+ <img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/text_encoder.png" height=275>
120
+ </p>
121
+
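+ Conceptually, the refiner lets the causally encoded text tokens attend to each other without a causal mask before they condition the DiT. A minimal stand-in could look like the following; it is not the repository's actual `token_refiner.py` (whose block also takes timestep conditioning), and the sizes are assumptions:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ # One bidirectional self-attention block over the text features (hypothetical sizes).
+ refiner = nn.TransformerEncoderLayer(
+     d_model=4096, nhead=32, dim_feedforward=4 * 4096,
+     batch_first=True, norm_first=True,
+ )
+
+ llm_features = torch.randn(1, 256, 4096)  # causally encoded MLLM hidden states
+ refined = refiner(llm_features)           # no causal mask: every token sees all others
+ print(refined.shape)                      # torch.Size([1, 256, 4096])
+ ```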
122
+ ### **3D VAE**
123
+ Our VAE uses CausalConv3D as the encoder and decoder of HunyuanVideo, compressing videos 4x along the temporal dimension and 8x along the spatial dimensions into 16 latent channels. This significantly reduces the number of tokens for the subsequent Transformer and lets us train the video generation model at the original resolution and frame rate.
124
+ <p align="center">
125
+ <img src="https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/assets/3dvae.png" height=150>
126
+ </p>
127
+
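+ These compression ratios translate directly into the DiT's sequence length. A quick back-of-the-envelope check (the 1x2x2 patchify factor inside the transformer is an assumption, not stated in this section):
+
+ ```python
+ # 3D VAE: 4x temporal (causal: T -> (T - 1) / 4 + 1), 8x spatial, 16 latent channels
+ frames, height, width = 129, 720, 1280
+
+ lat_t = (frames - 1) // 4 + 1             # 33 latent frames
+ lat_h, lat_w = height // 8, width // 8    # 90 x 160 latent grid
+
+ tokens = lat_t * (lat_h // 2) * (lat_w // 2)  # assumed 1x2x2 patch embedding
+ print(lat_t, lat_h, lat_w, tokens)            # 33 90 160 118800
+ ```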
128
+ ### **Prompt Rewrite**
129
+ To handle the variability and inconsistency of user-provided prompts, we fine-tuned the [Hunyuan-Large model](https://github.com/Tencent/Tencent-Hunyuan-Large) as our prompt-rewrite model, which rewrites user prompts into phrasing the model prefers.
130
+
131
+ We provide two rewrite modes: Normal mode and Director mode; the prompts for both are available [here](hyvideo/prompt_rewrite.py). Normal mode improves the model's understanding of user intent, so the given instructions are interpreted more accurately. Director mode enriches descriptions of composition, lighting, and camera movement, and tends to produce videos of higher visual quality. Note that this enhancement can occasionally lose some semantic details.
132
+
133
+ The prompt-rewrite model can be deployed and run directly with [Hunyuan-Large](https://github.com/Tencent/Tencent-Hunyuan-Large). Its weights are open-sourced [here](https://huggingface.co/Tencent/HunyuanVideo-PromptRewrite).
134
+
135
+ ## 📈 Comparisons
136
+
137
+ To evaluate HunyuanVideo, we selected four closed-source video generation models for comparison. We used 1,533 prompts in total and generated an equal number of video samples, one per prompt, in a single run. For a fair comparison, we ran inference only once to avoid any cherry-picking, kept the default settings of every baseline, and used consistent video resolutions. Videos were judged on three criteria: text alignment, motion quality, and visual quality. More than 60 professional evaluators rated the results; HunyuanVideo performed best on the overall metric and was particularly strong on motion quality.
138
+
139
+ <p align="center">
140
+ <table>
141
+ <thead>
142
+ <tr>
143
+ <th rowspan="2">模型</th> <th rowspan="2">是否开源</th> <th>时长</th> <th>文本对齐</th> <th>运动质量</th> <th rowspan="2">视觉质量</th> <th rowspan="2">综合评价</th> <th rowspan="2">排序</th>
144
+ </tr>
145
+ </thead>
146
+ <tbody>
147
+ <tr>
148
+ <td>HunyuanVideo (Ours)</td> <td> ✔ </td> <td>5s</td> <td>61.8%</td> <td>66.5%</td> <td>95.7%</td> <td>41.3%</td> <td>1</td>
149
+ </tr>
150
+ <tr>
151
+ <td>Chinese Model A (API)</td> <td> &#10008; </td> <td>5s</td> <td>62.6%</td> <td>61.7%</td> <td>95.6%</td> <td>37.7%</td> <td>2</td>
152
+ </tr>
153
+ <tr>
154
+ <td>Chinese Model B (Web)</td> <td> &#10008;</td> <td>5s</td> <td>60.1%</td> <td>62.9%</td> <td>97.7%</td> <td>37.5%</td> <td>3</td>
155
+ </tr>
156
+ <tr>
157
+ <td>GEN-3 alpha (Web)</td> <td>&#10008;</td> <td>6s</td> <td>47.7%</td> <td>54.7%</td> <td>97.5%</td> <td>27.4%</td> <td>4</td>
158
+ </tr>
159
+ <tr>
160
+ <td>Luma1.6 (API)</td><td>&#10008;</td> <td>5s</td> <td>57.6%</td> <td>44.2%</td> <td>94.1%</td> <td>24.8%</td> <td>5</td>
161
+ </tr>
162
+ </tbody>
163
+ </table>
164
+ </p>
165
+
166
+ ## 📜 Requirements
167
+
168
+ The table below lists the recommended configuration for running the HunyuanVideo model for text-to-video generation (batch size = 1):
169
+
170
+ |     Model      | Resolution<br/>(height/width/frames) | Peak GPU Memory |
171
+ |:--------------:|:--------------------------------:|:----------------:|
172
+ | HunyuanVideo   |        720px1280px129f           |       60GB       |
173
+ | HunyuanVideo   |        544px960px129f            |       45GB       |
174
+
175
+ * An NVIDIA GPU with CUDA support is required.
176
+ * The model has been tested on a single 80G GPU.
177
+ * Minimum GPU memory: 60GB for 720px1280px129f and 45GB for 544px960px129f.
178
+ * Tested operating system: Linux
179
+
180
+ ## 🛠️ Dependencies and Installation
181
+
182
+ First, clone the git repository:
183
+ ```shell
184
+ git clone https://github.com/tencent/HunyuanVideo
185
+ cd HunyuanVideo
186
+ ```
187
+
188
+ ### Installation Guide for Linux
189
+
190
+ We provide an `environment.yml` file for setting up a Conda environment. See [here](https://docs.anaconda.com/free/miniconda/index.html) for Conda's installation instructions.
191
+
192
+ We recommend CUDA 12.4 or 11.8 for inference.
193
+
194
+ ```shell
195
+ # 1. Prepare conda environment
196
+ conda env create -f environment.yml
197
+
198
+ # 2. Activate the environment
199
+ conda activate HunyuanVideo
200
+
201
+ # 3. Install pip dependencies
202
+ python -m pip install -r requirements.txt
203
+
204
+ # 4. Install flash attention v2 for acceleration (requires CUDA 11.8 or above)
205
+ python -m pip install ninja
206
+ python -m pip install git+https://github.com/Dao-AILab/[email protected]
207
+ ```
208
+
209
+ If you encounter a floating point exception (core dump) on specific GPU models, try the following fixes:
210
+
211
+ ```shell
212
+ # Option 1: make sure CUDA 12.4, CUBLAS >= 12.4.5.8 and CUDNN >= 9.00 are installed correctly (or simply use our CUDA 12 docker image)
213
+ pip install nvidia-cublas-cu12==12.4.5.8
214
+ export LD_LIBRARY_PATH=/opt/conda/lib/python3.8/site-packages/nvidia/cublas/lib/
215
+
216
+ # Option 2: force PyTorch and all the other packages built against CUDA 11.8
217
+ pip uninstall -r requirements.txt  # make sure all dependencies are uninstalled
218
+ pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu118
219
+ pip install -r requirements.txt
220
+ python -m pip install git+https://github.com/Dao-AILab/[email protected]
221
+ ```
222
+
223
+ Alternatively, we provide a pre-built Docker image; pull and run it with the following commands.
224
+ ```shell
225
+ # For CUDA 12.4 (updated to avoid the floating point exception)
226
+ docker pull hunyuanvideo/hunyuanvideo:cuda_12
227
+ docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
228
+
229
+ # For CUDA 11.8
230
+ docker pull hunyuanvideo/hunyuanvideo:cuda_11
231
+ docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_11
232
+ ```
233
+
234
+ ## 🧱 Download Pretrained Models
235
+
236
+ See [here](ckpts/README.md) for how to download the pretrained models.
237
+
238
+ ## 🔑 Inference
239
+ The supported height/width/frame settings are listed in the table below.
240
+
241
+ | Resolution | h/w=9:16 | h/w=16:9 | h/w=4:3 | h/w=3:4 | h/w=1:1 |
242
+ |:---------------------:|:----------------------------:|:---------------:|:---------------:|:---------------:|:---------------:|
243
+ | 540p | 544px960px129f | 960px544px129f | 624px832px129f | 832px624px129f | 720px720px129f |
244
+ | 720p (recommended) | 720px1280px129f | 1280px720px129f | 1104px832px129f | 832px1104px129f | 960px960px129f |
245
+
246
+ ### Using Command Line
247
+
248
+ ```bash
249
+ cd HunyuanVideo
250
+
251
+ python3 sample_video.py \
252
+ --video-size 720 1280 \
253
+ --video-length 129 \
254
+ --infer-steps 50 \
255
+ --prompt "A cat walks on the grass, realistic style." \
256
+ --flow-reverse \
257
+ --use-cpu-offload \
258
+ --save-path ./results
259
+ ```
260
+
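+ To sweep several prompts, a thin Python wrapper around the same CLI works just as well. A sketch (it assumes you run it from the repository root with the environment above activated):
+
+ ```python
+ import subprocess
+
+ prompts = [
+     "A cat walks on the grass, realistic style.",
+     "A close-up of ocean waves at sunset, cinematic lighting.",
+ ]
+
+ for seed, prompt in enumerate(prompts):
+     subprocess.run(
+         [
+             "python3", "sample_video.py",
+             "--video-size", "720", "1280",
+             "--video-length", "129",
+             "--infer-steps", "50",
+             "--prompt", prompt,
+             "--seed", str(seed),
+             "--flow-reverse",
+             "--use-cpu-offload",
+             "--save-path", "./results",
+         ],
+         check=True,  # stop the sweep if one generation fails
+     )
+ ```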
261
+ ### Run a Gradio Server
262
+ ```bash
263
+ python3 gradio_server.py --flow-reverse
264
+
265
+ # set SERVER_NAME and SERVER_PORT manually
266
+ # SERVER_NAME=0.0.0.0 SERVER_PORT=8081 python3 gradio_server.py --flow-reverse
267
+ ```
268
+
269
+ ### More Configurations
270
+
271
+ More key configuration options are listed below:
272
+
273
+ |        Argument        |  Default  |                Description                |
274
+ |:----------------------:|:---------:|:-----------------------------------------:|
275
+ |       `--prompt`       |   None    |   The text prompt for video generation    |
276
+ |     `--video-size`     | 720 1280  |  Height and width of the generated video  |
277
+ |    `--video-length`    |    129    |   Number of frames of the generated video |
278
+ |    `--infer-steps`     |    50     |         Number of sampling steps          |
279
+ | `--embedded-cfg-scale` |    6.0    | Embedded guidance scale (strength of text control) |
280
+ |     `--flow-shift`     |    7.0    | Shift factor of the inference timestep schedule; larger values spend more steps in the high-noise region (see the sketch after this table) |
281
+ |    `--flow-reverse`    |   False   | If reverse, learning/sampling from t=1 -> t=0 |
282
+ |     `--neg-prompt`     |   None    | Negative prompt |
283
+ |        `--seed`        |     0     |   Random seed   |
284
+ |  `--use-cpu-offload`   |   False   | Enable CPU offloading to save GPU memory |
285
+ |     `--save-path`      | ./results | Path to save the generated videos |
286
+
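+ To see what `--flow-shift` does to the schedule, the sketch below applies the commonly used flow-matching shift to a uniform sigma grid. The exact scheduler lives in `hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py`; the formula here is the standard SD3-style shift and is assumed to match it:
+
+ ```python
+ def shifted_sigmas(num_steps: int, shift: float):
+     # uniform sigmas on (0, 1], then warp them; a larger shift pushes mass toward sigma = 1
+     sigmas = [(i + 1) / num_steps for i in range(num_steps)]
+     return [shift * s / (1 + (shift - 1) * s) for s in sigmas]
+
+ for shift in (1.0, 7.0):
+     high_noise = sum(s > 0.5 for s in shifted_sigmas(50, shift))
+     print(f"shift={shift}: {high_noise}/50 steps with sigma > 0.5")
+ # shift=1.0: 25/50 steps with sigma > 0.5
+ # shift=7.0: 44/50 steps with sigma > 0.5
+ ```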
287
+
288
+ ## 🚀 Parallel Inference on Multiple GPUs by xDiT
289
+
290
+ [xDiT](https://github.com/xdit-project/xDiT) is a scalable inference engine for Diffusion Transformers (DiTs) on multi-GPU clusters.
291
+ It has provided low-latency parallel inference solutions for a variety of DiT models, including mochi-1, CogVideoX, Flux.1 and SD3. This repository adopts the [Unified Sequence Parallelism (USP)](https://arxiv.org/abs/2405.07719) APIs for parallel inference of the HunyuanVideo model.
292
+
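+ The key invariant is `nproc_per_node = ulysses-degree x ring-degree`: Ulysses parallelism shards attention heads across GPUs, while ring attention shards the token sequence. A tiny illustrative helper (not part of the repository) that builds a launch command under that invariant:
+
+ ```python
+ import shlex
+
+ def build_cmd(nproc: int, ulysses: int, ring: int, prompt: str) -> str:
+     # the two parallel degrees must multiply to the number of launched processes
+     assert ulysses * ring == nproc, "ulysses-degree x ring-degree must equal nproc_per_node"
+     args = [
+         "torchrun", f"--nproc_per_node={nproc}", "sample_video_parallel.py",
+         "--video-size", "1280", "720", "--video-length", "129",
+         "--infer-steps", "50", "--prompt", prompt, "--flow-reverse",
+         "--ulysses-degree", str(ulysses), "--ring-degree", str(ring),
+         "--save-path", "./results",
+     ]
+     return " ".join(shlex.quote(a) for a in args)
+
+ print(build_cmd(8, ulysses=4, ring=2, prompt="A cat walks on the grass, realistic style."))
+ ```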
293
+ ### 安装与 xDiT 兼容的依赖项
294
+
295
+ ```
296
+ # 1. 创建一个空白的 conda 环境
297
+ conda create -n hunyuanxdit python==3.10.9
298
+ conda activate hunyuanxdit
299
+
300
+ # 2. 使用 CUDA 11.8 安装 PyTorch 组件
301
+ conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia
302
+
303
+ # 3. 安装 pip 依赖项
304
+ python -m pip install -r requirements_xdit.txt
305
+ ```
306
+
307
+ You can skip the steps above and directly pull the pre-built Docker image, which is built from [docker/Dockerfile_xDiT](./docker/Dockerfile_xDiT):
308
+
309
+ ```
310
+ docker pull thufeifeibear/hunyuanvideo:latest
311
+ ```
312
+
313
+ ### Using Command Line
314
+
315
+ For example, the following command runs inference on 8 GPUs:
316
+
317
+ ```bash
318
+ cd HunyuanVideo
319
+
320
+ torchrun --nproc_per_node=8 sample_video_parallel.py \
321
+ --video-size 1280 720 \
322
+ --video-length 129 \
323
+ --infer-steps 50 \
324
+ --prompt "A cat walks on the grass, realistic style." \
325
+ --flow-reverse \
326
+ --seed 42 \
327
+ --ulysses-degree 8 \
328
+ --ring-degree 1 \
329
+ --save-path ./results
330
+ ```
331
+
332
+ You can tune the parallel layout via `--ulysses-degree` and `--ring-degree`; the supported combinations are listed below, and a small helper for enumerating the factor pairs follows the table.
333
+
334
+ <details>
335
+ <summary>Supported Parallel Configurations (click to expand)</summary>
336
+
337
+ | --video-size | --video-length | --ulysses-degree x --ring-degree | --nproc_per_node |
338
+ |----------------------|----------------|----------------------------------|------------------|
339
+ | 1280 720 or 720 1280 | 129 | 8x1,4x2,2x4,1x8 | 8 |
340
+ | 1280 720 or 720 1280 | 129 | 1x5 | 5 |
341
+ | 1280 720 or 720 1280 | 129 | 4x1,2x2,1x4 | 4 |
342
+ | 1280 720 or 720 1280 | 129 | 3x1,1x3 | 3 |
343
+ | 1280 720 or 720 1280 | 129 | 2x1,1x2 | 2 |
344
+ | 1104 832 or 832 1104 | 129 | 4x1,2x2,1x4 | 4 |
345
+ | 1104 832 or 832 1104 | 129 | 3x1,1x3 | 3 |
346
+ | 1104 832 or 832 1104 | 129 | 2x1,1x2 | 2 |
347
+ | 960 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
348
+ | 960 960 | 129 | 4x1,2x2,1x4 | 4 |
349
+ | 960 960 | 129 | 3x1,1x3 | 3 |
350
+ | 960 960 | 129 | 1x2,2x1 | 2 |
351
+ | 960 544 or 544 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
352
+ | 960 544 or 544 960 | 129 | 4x1,2x2,1x4 | 4 |
353
+ | 960 544 or 544 960 | 129 | 3x1,1x3 | 3 |
354
+ | 960 544 or 544 960 | 129 | 1x2,2x1 | 2 |
355
+ | 832 624 or 624 832 | 129 | 4x1,2x2,1x4 | 4 |
356
+ | 832 624 or 624 832 | 129 | 3x1,1x3 | 3 |
357
+ | 832 624 or 624 832 | 129 | 2x1,1x2 | 2 |
358
+ | 720 720 | 129 | 1x5 | 5 |
359
+ | 720 720 | 129 | 3x1,1x3 | 3 |
360
+
361
+ </details>
362
+
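+ If you just want every launchable degree pair for a given GPU count, enumerate the factor pairs and intersect them with the table above:
+
+ ```python
+ def degree_pairs(nproc: int):
+     """All (ulysses_degree, ring_degree) pairs with ulysses * ring == nproc."""
+     return [(u, nproc // u) for u in range(1, nproc + 1) if nproc % u == 0]
+
+ print(degree_pairs(8))   # [(1, 8), (2, 4), (4, 2), (8, 1)]
+ print(degree_pairs(6))   # [(1, 6), (2, 3), (3, 2), (6, 1)]
+ ```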
363
+ <p align="center">
364
+ <table align="center">
365
+ <thead>
366
+ <tr>
367
+ <th colspan="4">在 8xGPU上生成1280x720 (129 帧 50 步)的时耗 (秒) </th>
368
+ </tr>
369
+ <tr>
370
+ <th>1</th>
371
+ <th>2</th>
372
+ <th>4</th>
373
+ <th>8</th>
374
+ </tr>
375
+ </thead>
376
+ <tbody>
377
+ <tr>
378
+ <th>1904.08</th>
379
+ <th>934.09 (2.04x)</th>
380
+ <th>514.08 (3.70x)</th>
381
+ <th>337.58 (5.64x)</th>
382
+ </tr>
383
+
384
+ </tbody>
385
+ </table>
386
+ </p>
387
+
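+ From these numbers, parallel efficiency (speedup divided by GPU count) stays high through 4 GPUs and then tapers at 8, which is typical for sequence parallelism as communication starts to dominate:
+
+ ```python
+ latency = {1: 1904.08, 2: 934.09, 4: 514.08, 8: 337.58}  # seconds, from the table above
+
+ for gpus, seconds in latency.items():
+     speedup = latency[1] / seconds
+     print(f"{gpus} GPU(s): {speedup:.2f}x speedup, {100 * speedup / gpus:.0f}% efficiency")
+ # 1.00x/100%, 2.04x/102%, 3.70x/93%, 5.64x/71%
+ ```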
388
+
389
+ ## 🔗 BibTeX
390
+ If you find [HunyuanVideo](https://arxiv.org/abs/2412.03603) useful for your research and applications, please cite it using this BibTeX:
391
+
392
+ ```BibTeX
393
+ @misc{kong2024hunyuanvideo,
394
+ title={HunyuanVideo: A Systematic Framework For Large Video Generative Models},
395
+ author={Weijie Kong, Qi Tian, Zijian Zhang, Rox Min, Zuozhuo Dai, Jin Zhou, Jiangfeng Xiong, Xin Li, Bo Wu, Jianwei Zhang, Kathrina Wu, Qin Lin, Aladdin Wang, Andong Wang, Changlin Li, Duojun Huang, Fang Yang, Hao Tan, Hongmei Wang, Jacob Song, Jiawang Bai, Jianbing Wu, Jinbao Xue, Joey Wang, Junkun Yuan, Kai Wang, Mengyang Liu, Pengyu Li, Shuai Li, Weiyan Wang, Wenqing Yu, Xinchi Deng, Yang Li, Yanxin Long, Yi Chen, Yutao Cui, Yuanbo Peng, Zhentao Yu, Zhiyu He, Zhiyong Xu, Zixiang Zhou, Zunnan Xu, Yangyu Tao, Qinglin Lu, Songtao Liu, Dax Zhou, Hongfa Wang, Yong Yang, Di Wang, Yuhong Liu, and Jie Jiang, along with Caesar Zhong},
396
+ year={2024},
397
+ archivePrefix={arXiv},
+ eprint={2412.03603},
398
+ primaryClass={cs.CV}
399
+ }
400
+ ```
401
+
402
+
403
+
404
+ ## 🧩 Projects that use HunyuanVideo
405
+
406
+ If you develop or use HunyuanVideo in your projects, we would love to hear from you.
407
+
408
+ - ComfyUI (supports FP8 inference and video-to-video generation): [ComfyUI-HunyuanVideoWrapper](https://github.com/kijai/ComfyUI-HunyuanVideoWrapper) by [Kijai](https://github.com/kijai)
409
+
410
+
411
+
412
+ ## Acknowledgements
413
+
414
+ The open-sourcing of HunyuanVideo builds on many earlier open-source works. We would especially like to thank [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) for their exploration and contributions. We also thank the Tencent Hunyuan Multimodal team for their support in adapting multiple text encoders for HunyuanVideo.
415
+
416
+
417
+ ## Star History
418
+
419
+ <a href="https://star-history.com/#Tencent/HunyuanVideo&Date">
420
+ <picture>
421
+ <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date&theme=dark" />
422
+ <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
423
+ <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Tencent/HunyuanVideo&type=Date" />
424
+ </picture>
425
+ </a>
assets/3dvae.png ADDED
assets/WECHAT.md ADDED
@@ -0,0 +1,7 @@
1
+ <div align="center">
2
+ <img src="wechat.jpg" width="60%"/>
3
+
4
+ <p> Scan the QR code to follow the Hunyuan series and join the "Hunyuan Video Discussion Group" </p>
5
+ <p> Scan the QR code to join the "Hunyuan Discussion Group" </p>
6
+ </div>
7
+
assets/backbone.png ADDED
Git LFS Details
  • SHA256: d30e8775add644bd4f484cc0ab5edf1b3c9ab90f7e2215dc46471d263f6792ba
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
assets/hunyuanvideo.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5aea4522a926775eec39db8bdb4a9ced63a1c0a703fb7254e7b9fdc1dbe7f98
3
+ size 44603215
assets/logo.png ADDED
assets/overall.png ADDED
Git LFS Details
  • SHA256: 1f1cb64d8d16cb76f7bb01d3b5e3224724f1bd0ee6d856150a7d0b210fd172e5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
assets/text_encoder.png ADDED
Git LFS Details
  • SHA256: 5b3cfe4a4acb8fef96d8b14ee279c16626f6b74a95762d37272ddddef7ae5cc4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
assets/wechat.jpg ADDED
docker/Dockerfile_xDiT ADDED
@@ -0,0 +1,41 @@
1
+ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+
5
+ # Install necessary packages
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ python3 \
8
+ python3-pip \
9
+ python3-dev \
10
+ git \
11
+ wget \
12
+ bzip2 \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Install Miniconda
16
+ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \
17
+ && bash /tmp/miniconda.sh -b -p /opt/conda \
18
+ && rm /tmp/miniconda.sh
19
+
20
+ # Set up environment variables for Conda
21
+ ENV PATH="/opt/conda/bin:$PATH"
22
+
23
+ # Create a new conda environment
24
+ RUN conda create -n myenv python=3.10 -y
25
+
26
+ # Activate the conda environment
27
+ SHELL ["/bin/bash", "--login", "-c"]
28
+ RUN echo "source activate myenv" > ~/.bashrc
29
+ ENV PATH /opt/conda/envs/myenv/bin:$PATH
30
+
31
+ # Install PyTorch and other dependencies using conda
32
+ RUN conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia
33
+
34
+ COPY ./requirements_xdit.txt /tmp/requirements_xdit.txt
35
+ RUN pip install --no-cache-dir -r /tmp/requirements_xdit.txt
36
+
37
+ # Set working directory
38
+ WORKDIR /workspace
39
+
40
+ # Default command
41
+ CMD ["/bin/bash"]
environment.yml ADDED
@@ -0,0 +1,8 @@
1
+ name: HunyuanVideo
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.10.9
7
+ - pytorch=2.5.1
8
+ - pip
gradio_server.py ADDED
@@ -0,0 +1,376 @@
1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+ from loguru import logger
5
+ from datetime import datetime
6
+ import gradio as gr
7
+ import random
8
+ import json
9
+ from hyvideo.utils.file_utils import save_videos_grid
10
+ from hyvideo.config import parse_args
11
+ from hyvideo.inference import HunyuanVideoSampler
12
+ from hyvideo.constants import NEGATIVE_PROMPT
13
+ from mmgp import offload, profile_type
14
+
15
+
16
+ args = parse_args()
17
+
18
+
19
+ force_profile_no = int(args.profile)
20
+ verbose_level = int(args.verbose)
21
+
22
+ transformer_choices=["ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan-video-t2v-720p/transformers/fast_hunyuan_video_720_quanto_int8.safetensors"]
23
+ text_encoder_choices = ["ckpts/text_encoder/llava-llama-3-8b-v1_1_fp16.safetensors", "ckpts/text_encoder/llava-llama-3-8b-v1_1_quanto_int8.safetensors"]
24
+ server_config_filename = "gradio_config.json"
25
+
26
+ if not Path(server_config_filename).is_file():
27
+ server_config = {"attention_mode" : "sdpa",
28
+ "transformer_filename": transformer_choices[1],
29
+ "text_encoder_filename" : text_encoder_choices[1],
30
+ "compile" : "",
31
+ "profile" : profile_type.LowRAM_LowVRAM }
32
+
33
+ with open(server_config_filename, "w", encoding="utf-8") as writer:
34
+ writer.write(json.dumps(server_config))
35
+ else:
36
+ with open(server_config_filename, "r", encoding="utf-8") as reader:
37
+ text = reader.read()
38
+ server_config = json.loads(text)
39
+
40
+ transformer_filename = server_config["transformer_filename"]
41
+ text_encoder_filename = server_config["text_encoder_filename"]
42
+ attention_mode = server_config["attention_mode"]
43
+ profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
44
+ compile = server_config.get("compile", "")
45
+
46
+ #transformer_filename = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors"
47
+ #transformer_filename = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_quanto_int8.safetensors"
48
+ #transformer_filename = "ckpts/hunyuan-video-t2v-720p/transformers/fast_hunyuan_video_720_quanto_int8.safetensors"
49
+
50
+ #text_encoder_filename = "ckpts/text_encoder/llava-llama-3-8b-v1_1_fp16.safetensors"
51
+ #text_encoder_filename = "ckpts/text_encoder/llava-llama-3-8b-v1_1_quanto_int8.safetensors"
52
+
53
+ #attention_mode="sage"
54
+ #attention_mode="flash"
55
+
56
+ def download_models(transformer_filename, text_encoder_filename):
57
+ def computeList(filename):
58
+ pos = filename.rfind("/")
59
+ filename = filename[pos+1:]
60
+ if not "quanto" in filename:
61
+ return [filename]
62
+ pos = filename.rfind(".")
63
+ return [filename, filename[:pos] +"_map.json"]
64
+
65
+ from huggingface_hub import hf_hub_download, snapshot_download
66
+ repoId = "DeepBeepMeep/HunyuanVideo"
67
+ sourceFolderList = ["text_encoder_2", "text_encoder", "hunyuan-video-t2v-720p/vae", "hunyuan-video-t2v-720p/transformers" ]
68
+ fileList = [ [], ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , [], computeList(transformer_filename) ]
69
+ targetRoot = "ckpts/"
70
+ for sourceFolder, files in zip(sourceFolderList,fileList ):
71
+ if len(files)==0:
72
+ if not Path(targetRoot + sourceFolder).exists():
73
+ snapshot_download(repo_id=repoId, allow_patterns=sourceFolder +"/*", local_dir= targetRoot)
74
+ else:
75
+ for onefile in files:
76
+ if not os.path.isfile(targetRoot + sourceFolder + "/" + onefile ):
77
+ hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot, subfolder=sourceFolder)
78
+
79
+
80
+ download_models(transformer_filename, text_encoder_filename)
81
+
82
+ # models_root_path = Path(args.model_base)
83
+ # if not models_root_path.exists():
84
+ # raise ValueError(f"`models_root` not exists: {models_root_path}")
85
+
86
+ offload.default_verboseLevel = verbose_level
87
+ with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "r", encoding="utf-8") as reader:
88
+ text = reader.read()
89
+ vae_config= json.loads(text)
90
+ # reduce the time window used by the VAE for temporal tiling (the default window is too large for 24 GB of VRAM)
91
+ if vae_config["sample_tsize"] == 64:
92
+ vae_config["sample_tsize"] = 32
93
+ with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "w", encoding="utf-8") as writer:
94
+ writer.write(json.dumps(vae_config))
95
+
96
+ args.flow_reverse = True
97
+ if profile == 5:
98
+ pinToMemory = False
99
+ partialPinning = False
100
+ else:
101
+ pinToMemory = True
102
+ import psutil
103
+ physical_memory= psutil.virtual_memory().total
104
+ partialPinning = physical_memory <= 2**30 * 32
105
+
106
+ hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(transformer_filename, text_encoder_filename, attention_mode = attention_mode, pinToMemory = pinToMemory, partialPinning = partialPinning, args=args, device="cpu")
107
+
108
+ pipe = hunyuan_video_sampler.pipeline
109
+
110
+ offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False)
111
+
112
+ def apply_changes(
113
+ transformer_choice,
114
+ text_encoder_choice,
115
+ attention_choice,
116
+ compile_choice,
117
+ profile_choice,
118
+ ):
119
+ server_config = {"attention_mode" : attention_choice,
120
+ "transformer_filename": transformer_choices[transformer_choice],
121
+ "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
122
+ "compile" : compile_choice,
123
+ "profile" : profile_choice }
124
+
125
+ with open(server_config_filename, "w", encoding="utf-8") as writer:
126
+ writer.write(json.dumps(server_config))
127
+
128
+ return "<h1>New Config file created. Please restart the Gradio Server</h1>"
129
+
130
+
131
+ from moviepy.editor import ImageSequenceClip
132
+ import numpy as np
133
+
134
+ def save_video(final_frames, output_path, fps=24):
135
+ assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames.shape} (need t h w c)"
136
+ if final_frames.dtype != np.uint8:
137
+ final_frames = (final_frames * 255).astype(np.uint8)
138
+ ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
139
+
140
+
141
+ def generate_video(
142
+ prompt,
143
+ resolution,
144
+ video_length,
145
+ seed,
146
+ num_inference_steps,
147
+ guidance_scale,
148
+ flow_shift,
149
+ embedded_guidance_scale,
150
+ tea_cache,
151
+ progress=gr.Progress(track_tqdm=True)
152
+
153
+ ):
154
+ seed = None if seed == -1 else seed
155
+ width, height = resolution.split("x")
156
+ width, height = int(width), int(height)
157
+ negative_prompt = "" # not applicable in the inference
158
+
159
+
160
+ # TeaCache
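+ # TeaCache skips full transformer passes when the timestep-modulated input barely
+ # changes between steps: it accumulates a relative L1 distance and, while it stays
+ # below rel_l1_thresh, reuses the previous step's cached residual instead of recomputing.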
161
+ trans = hunyuan_video_sampler.pipeline.transformer.__class__
162
+ trans.enable_teacache = tea_cache > 0
163
+ if trans.enable_teacache:
164
+ trans.num_steps = num_inference_steps
165
+ trans.cnt = 0
166
+ trans.rel_l1_thresh = 0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
167
+ trans.accumulated_rel_l1_distance = 0
168
+ trans.previous_modulated_input = None
169
+ trans.previous_residual = None
170
+
171
+
172
+ outputs = hunyuan_video_sampler.predict(
173
+ prompt=prompt,
174
+ height=height,
175
+ width=width,
176
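+ # round the requested length down to 4n+1 frames: the causal 3D VAE compresses
+ # time 4x, so only frame counts of the form 4n + 1 map cleanly to latent frames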
+ video_length=(video_length // 4)* 4 + 1 ,
177
+ seed=seed,
178
+ negative_prompt=negative_prompt,
179
+ infer_steps=num_inference_steps,
180
+ guidance_scale=guidance_scale,
181
+ num_videos_per_prompt=1,
182
+ flow_shift=flow_shift,
183
+ batch_size=1,
184
+ embedded_guidance_scale=embedded_guidance_scale
185
+ )
186
+
187
+
188
+ from einops import rearrange
189
+
190
+ samples = outputs['samples']
191
+ sample = samples[0]
192
+
193
+ video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
194
+
195
+ save_path = os.path.join(os.getcwd(), "gradio_outputs")
196
+ os.makedirs(save_path, exist_ok=True)
197
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
198
+ file_name = f"{time_flag}_seed{outputs['seeds'][0]}_{outputs['prompts'][0][:100].replace('/','')}.mp4".replace(':',' ').replace('\\',' ')
199
+ video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
200
+
201
+ save_video(video, video_path )
202
+ print(f"New video saved to Path: "+video_path)
203
+
204
+
205
+ return video_path
206
+
207
+ def create_demo(model_path, save_path):
208
+
209
+ with gr.Blocks() as demo:
210
+ gr.Markdown("<div align=center><H1>HunyuanVideo<SUP>GP</SUP> by Tencent</H3></div>")
211
+ gr.Markdown("*GPU Poor version by **DeepBeepMeep**. Now this great video generator can run smoothly on a 24 GB rig.*")
212
+ gr.Markdown("Please be aware of these limits with profiles 2 and 4 if you have 24 GB of VRAM (RTX 3090 / RTX 4090):")
213
+ gr.Markdown("- max 192 frames for 848 x 480 ")
214
+ gr.Markdown("- max 86 frames for 1280 x 720")
215
+ gr.Markdown("In the worst case, one step should not take more than 2 minutes. If it the case you may be running out of RAM / VRAM. Try to generate fewer images / lower res / a less demanding profile.")
216
+ gr.Markdown("If you have a Linux / WSL system you may turn on compilation (see below) and will be able to generate an extra 30°% frames")
217
+
218
+ with gr.Accordion("Video Engine Configuration", open = False):
219
+ gr.Markdown("For the changes to be effective you will need to restart the gradio_server")
220
+
221
+ with gr.Column():
222
+ index = transformer_choices.index(transformer_filename)
223
+ index = 0 if index ==0 else index
224
+
225
+ transformer_choice = gr.Dropdown(
226
+ choices=[
227
+ ("Hunyuan Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0),
228
+ ("Hunyuan Video quantized to 8 bits (recommended) - the default engine but quantized", 1),
229
+ ("Fast Hunyuan Video quantized to 8 bits - requires less than 10 steps but worse quality", 2),
230
+ ],
231
+ value= index,
232
+ label="Transformer"
233
+ )
234
+ index = text_encoder_choices.index(text_encoder_filename)
235
+ index = 0 if index ==0 else index
236
+
237
+ gr.Markdown("Note that even if you choose a 16 bits Llava model below, depending on the profile it may be automatically quantized to 8 bits on the fly")
238
+ text_encoder_choice = gr.Dropdown(
239
+ choices=[
240
+ ("Llava Llama 1.1 16 bits - unquantized text encoder, better quality uses more RAM", 0),
241
+ ("Llava Llama 1.1 quantized to 8 bits - quantized text encoder, worse quality but uses less RAM", 1),
242
+ ],
243
+ value= index,
244
+ label="Text Encoder"
245
+ )
246
+ attention_choice = gr.Dropdown(
247
+ choices=[
248
+ ("Scale Dot Product Attention: default", "sdpa"),
249
+ ("Flash: good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
250
+ ("Sage: 30% faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
251
+ ],
252
+ value= attention_mode,
253
+ label="Attention Type"
254
+ )
255
+ gr.Markdown("Beware: restarting the server or changing a resolution or video duration will trigger a recompilation that may last a few minutes")
256
+ compile_choice = gr.Dropdown(
257
+ choices=[
258
+ ("ON: works only on Linux / WSL", "transformer"),
259
+ ("OFF: no other choice if you have Windows without using WSL", "" ),
260
+ ],
261
+ value= compile,
262
+ label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)"
263
+ )
264
+ profile_choice = gr.Dropdown(
265
+ choices=[
266
+ ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for shorter videos a RTX 3090 / RTX 4090", 1),
267
+ ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
268
+ ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
269
+ ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
270
+ ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
271
+ ],
272
+ value= profile,
273
+ label="Profile"
274
+ )
275
+
276
+ msg = gr.Markdown()
277
+ apply_btn = gr.Button("Apply Changes")
278
+
279
+
280
+ apply_btn.click(
281
+ fn=apply_changes,
282
+ inputs=[
283
+ transformer_choice,
284
+ text_encoder_choice,
285
+ attention_choice,
286
+ compile_choice,
287
+ profile_choice,
288
+ ],
289
+ outputs=msg
290
+ )
291
+
292
+ with gr.Row():
293
+ with gr.Column():
294
+ prompt = gr.Textbox(label="Prompt", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.")
295
+ with gr.Row():
296
+ resolution = gr.Dropdown(
297
+ choices=[
298
+ # 720p
299
+ ("1280x720 (16:9, 720p)", "1280x720"),
300
+ ("720x1280 (9:16, 720p)", "720x1280"),
301
+ ("1104x832 (4:3, 720p)", "1104x832"),
302
+ ("832x1104 (3:4, 720p)", "832x1104"),
303
+ ("960x960 (1:1, 720p)", "960x960"),
304
+ # 540p
305
+ ("960x544 (16:9, 540p)", "960x544"),
306
+ ("848x480 (16:9, 540p)", "848x480"),
307
+ ("544x960 (9:16, 540p)", "544x960"),
308
+ ("832x624 (4:3, 540p)", "832x624"),
309
+ ("624x832 (3:4, 540p)", "624x832"),
310
+ ("720x720 (1:1, 540p)", "720x720"),
311
+ ],
312
+ value="848x480",
313
+ label="Resolution"
314
+ )
315
+
316
+ video_length = gr.Slider(5, 193, value=97, step=4, label="Number of frames (24 = 1s)")
317
+
318
+ # video_length = gr.Dropdown(
319
+ # label="Video Length",
320
+ # choices=[
321
+ # ("1.5s(41f)", 41),
322
+ # ("2s(65f)", 65),
323
+ # ("4s(97f)", 97),
324
+ # ("5s(129f)", 129),
325
+ # ],
326
+ # value=97,
327
+ # )
328
+ num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps")
329
+ show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
330
+ with gr.Row(visible=False) as advanced_row:
331
+ with gr.Column():
332
+ seed = gr.Number(value=-1, label="Seed (-1 for random)")
333
+ guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
334
+ flow_shift = gr.Slider(0.0, 25.0, value=7.0, step=0.1, label="Flow Shift")
335
+ embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")
336
+ with gr.Row():
337
+ tea_cache_setting = gr.Dropdown(
338
+ choices=[
339
+ ("Disabled", 0),
340
+ ("Fast (x1.6 speed up)", 0.1),
341
+ ("Faster (x2.1 speed up)", 0.15),
342
+ ],
343
+ value=0,
344
+ label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video)"
345
+ )
346
+
347
+ show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
348
+ generate_btn = gr.Button("Generate")
349
+
350
+ with gr.Column():
351
+ output = gr.Video(label="Generated Video")
352
+
353
+ generate_btn.click(
354
+ fn=generate_video,
355
+ inputs=[
356
+ prompt,
357
+ resolution,
358
+ video_length,
359
+ seed,
360
+ num_inference_steps,
361
+ guidance_scale,
362
+ flow_shift,
363
+ embedded_guidance_scale,
364
+ tea_cache_setting
365
+ ],
366
+ outputs=output
367
+ )
368
+
369
+ return demo
370
+
371
+ if __name__ == "__main__":
372
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
373
+ server_name = os.getenv("SERVER_NAME", "0.0.0.0")
374
+ server_port = int(os.getenv("SERVER_PORT", "7860"))
375
+ demo = create_demo(args.model_base, args.save_path)
376
+ demo.launch(server_name=server_name, server_port=server_port)
hyvideo/__init__.py ADDED
File without changes
hyvideo/config.py ADDED
@@ -0,0 +1,406 @@
1
+ import argparse
2
+ from .constants import *
3
+ import re
4
+ from .modules.models import HUNYUAN_VIDEO_CONFIG
5
+
6
+
7
+ def parse_args(namespace=None):
8
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
9
+
10
+ parser = add_network_args(parser)
11
+ parser = add_extra_models_args(parser)
12
+ parser = add_denoise_schedule_args(parser)
13
+ parser = add_inference_args(parser)
14
+ parser = add_parallel_args(parser)
15
+
16
+ args = parser.parse_args(namespace=namespace)
17
+ args = sanity_check_args(args)
18
+
19
+ return args
20
+
21
+
22
+ def add_network_args(parser: argparse.ArgumentParser):
23
+ group = parser.add_argument_group(title="HunyuanVideo network args")
24
+
25
+ group.add_argument(
26
+ "--profile",
27
+ type=int,
28
+ default=-1,
29
+ help="Profile No"
30
+ )
31
+
32
+ group.add_argument(
33
+ "--verbose",
34
+ type=int,
35
+ default=1,
36
+ help="Verbose level"
37
+ )
38
+
39
+ # Main model
40
+ group.add_argument(
41
+ "--model",
42
+ type=str,
43
+ choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
44
+ default="HYVideo-T/2-cfgdistill",
45
+ )
46
+ group.add_argument(
47
+ "--latent-channels",
48
+ type=int,
49
+ default=16,
50
+ help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
51
+ "it still needs to match the latent channels of the VAE model.",
52
+ )
53
+ group.add_argument(
54
+ "--precision",
55
+ type=str,
56
+ default="bf16",
57
+ choices=PRECISIONS,
58
+ help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
59
+ )
60
+
61
+ # RoPE
62
+ group.add_argument(
63
+ "--rope-theta", type=int, default=256, help="Theta used in RoPE."
64
+ )
65
+ return parser
66
+
67
+
68
+ def add_extra_models_args(parser: argparse.ArgumentParser):
69
+ group = parser.add_argument_group(
70
+ title="Extra models args, including vae, text encoders and tokenizers)"
71
+ )
72
+
73
+ # - VAE
74
+ group.add_argument(
75
+ "--vae",
76
+ type=str,
77
+ default="884-16c-hy",
78
+ choices=list(VAE_PATH),
79
+ help="Name of the VAE model.",
80
+ )
81
+ group.add_argument(
82
+ "--vae-precision",
83
+ type=str,
84
+ default="fp16",
85
+ choices=PRECISIONS,
86
+ help="Precision mode for the VAE model.",
87
+ )
88
+ group.add_argument(
89
+ "--vae-tiling",
90
+ action="store_true",
91
+ help="Enable tiling for the VAE model to save GPU memory.",
92
+ )
93
+ group.set_defaults(vae_tiling=True)
94
+
95
+ group.add_argument(
96
+ "--text-encoder",
97
+ type=str,
98
+ default="llm",
99
+ choices=list(TEXT_ENCODER_PATH),
100
+ help="Name of the text encoder model.",
101
+ )
102
+ group.add_argument(
103
+ "--text-encoder-precision",
104
+ type=str,
105
+ default="fp16",
106
+ choices=PRECISIONS,
107
+ help="Precision mode for the text encoder model.",
108
+ )
109
+ group.add_argument(
110
+ "--text-states-dim",
111
+ type=int,
112
+ default=4096,
113
+ help="Dimension of the text encoder hidden states.",
114
+ )
115
+ group.add_argument(
116
+ "--text-len", type=int, default=256, help="Maximum length of the text input."
117
+ )
118
+ group.add_argument(
119
+ "--tokenizer",
120
+ type=str,
121
+ default="llm",
122
+ choices=list(TOKENIZER_PATH),
123
+ help="Name of the tokenizer model.",
124
+ )
125
+ group.add_argument(
126
+ "--prompt-template",
127
+ type=str,
128
+ default="dit-llm-encode",
129
+ choices=PROMPT_TEMPLATE,
130
+ help="Image prompt template for the decoder-only text encoder model.",
131
+ )
132
+ group.add_argument(
133
+ "--prompt-template-video",
134
+ type=str,
135
+ default="dit-llm-encode-video",
136
+ choices=PROMPT_TEMPLATE,
137
+ help="Video prompt template for the decoder-only text encoder model.",
138
+ )
139
+ group.add_argument(
140
+ "--hidden-state-skip-layer",
141
+ type=int,
142
+ default=2,
143
+ help="Skip layer for hidden states.",
144
+ )
145
+ group.add_argument(
146
+ "--apply-final-norm",
147
+ action="store_true",
148
+ help="Apply final normalization to the used text encoder hidden states.",
149
+ )
150
+
151
+ # - CLIP
152
+ group.add_argument(
153
+ "--text-encoder-2",
154
+ type=str,
155
+ default="clipL",
156
+ choices=list(TEXT_ENCODER_PATH),
157
+ help="Name of the second text encoder model.",
158
+ )
159
+ group.add_argument(
160
+ "--text-encoder-precision-2",
161
+ type=str,
162
+ default="fp16",
163
+ choices=PRECISIONS,
164
+ help="Precision mode for the second text encoder model.",
165
+ )
166
+ group.add_argument(
167
+ "--text-states-dim-2",
168
+ type=int,
169
+ default=768,
170
+ help="Dimension of the second text encoder hidden states.",
171
+ )
172
+ group.add_argument(
173
+ "--tokenizer-2",
174
+ type=str,
175
+ default="clipL",
176
+ choices=list(TOKENIZER_PATH),
177
+ help="Name of the second tokenizer model.",
178
+ )
179
+ group.add_argument(
180
+ "--text-len-2",
181
+ type=int,
182
+ default=77,
183
+ help="Maximum length of the second text input.",
184
+ )
185
+
186
+ return parser
187
+
188
+
189
+ def add_denoise_schedule_args(parser: argparse.ArgumentParser):
190
+ group = parser.add_argument_group(title="Denoise schedule args")
191
+
192
+ group.add_argument(
193
+ "--denoise-type",
194
+ type=str,
195
+ default="flow",
196
+ help="Denoise type for noised inputs.",
197
+ )
198
+
199
+ # Flow Matching
200
+ group.add_argument(
201
+ "--flow-shift",
202
+ type=float,
203
+ default=7.0,
204
+ help="Shift factor for flow matching schedulers.",
205
+ )
206
+ group.add_argument(
207
+ "--flow-reverse",
208
+ action="store_true",
209
+ help="If reverse, learning/sampling from t=1 -> t=0.",
210
+ )
211
+ group.add_argument(
212
+ "--flow-solver",
213
+ type=str,
214
+ default="euler",
215
+ help="Solver for flow matching.",
216
+ )
217
+ group.add_argument(
218
+ "--use-linear-quadratic-schedule",
219
+ action="store_true",
220
+ help="Use linear quadratic schedule for flow matching."
221
+ "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
222
+ )
223
+ group.add_argument(
224
+ "--linear-schedule-end",
225
+ type=int,
226
+ default=25,
227
+ help="End step for linear quadratic schedule for flow matching.",
228
+ )
229
+
230
+ return parser
231
+
232
+
233
+ def add_inference_args(parser: argparse.ArgumentParser):
234
+ group = parser.add_argument_group(title="Inference args")
235
+
236
+ # ======================== Model loads ========================
237
+ group.add_argument(
238
+ "--model-base",
239
+ type=str,
240
+ default="ckpts",
241
+ help="Root path of all the models, including t2v models and extra models.",
242
+ )
243
+ group.add_argument(
244
+ "--dit-weight",
245
+ type=str,
246
+ default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
247
+ help="Path to the HunyuanVideo model. If None, search the model in the args.model_root."
248
+ "1. If it is a file, load the model directly."
249
+ "2. If it is a directory, search the model in the directory. Support two types of models: "
250
+ "1) named `pytorch_model_*.pt`"
251
+ "2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
252
+ )
253
+ group.add_argument(
254
+ "--model-resolution",
255
+ type=str,
256
+ default="540p",
257
+ choices=["540p", "720p"],
258
+ help="Root path of all the models, including t2v models and extra models.",
259
+ )
260
+ group.add_argument(
261
+ "--load-key",
262
+ type=str,
263
+ default="module",
264
+ help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
265
+ )
266
+ group.add_argument(
267
+ "--use-cpu-offload",
268
+ action="store_true",
269
+ help="Use CPU offload for the model load.",
270
+ )
271
+
272
+ # ======================== Inference general setting ========================
273
+ group.add_argument(
274
+ "--batch-size",
275
+ type=int,
276
+ default=1,
277
+ help="Batch size for inference and evaluation.",
278
+ )
279
+ group.add_argument(
280
+ "--infer-steps",
281
+ type=int,
282
+ default=50,
283
+ help="Number of denoising steps for inference.",
284
+ )
285
+ group.add_argument(
286
+ "--disable-autocast",
287
+ action="store_true",
288
+ help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
289
+ )
290
+ group.add_argument(
291
+ "--save-path",
292
+ type=str,
293
+ default="./results",
294
+ help="Path to save the generated samples.",
295
+ )
296
+ group.add_argument(
297
+ "--save-path-suffix",
298
+ type=str,
299
+ default="",
300
+ help="Suffix for the directory of saved samples.",
301
+ )
302
+ group.add_argument(
303
+ "--name-suffix",
304
+ type=str,
305
+ default="",
306
+ help="Suffix for the names of saved samples.",
307
+ )
308
+ group.add_argument(
309
+ "--num-videos",
310
+ type=int,
311
+ default=1,
312
+ help="Number of videos to generate for each prompt.",
313
+ )
314
+ # ---sample size---
315
+ group.add_argument(
316
+ "--video-size",
317
+ type=int,
318
+ nargs="+",
319
+ default=(720, 1280),
320
+ help="Video size for training. If a single value is provided, it will be used for both height "
321
+ "and width. If two values are provided, they will be used for height and width "
322
+ "respectively.",
323
+ )
324
+ group.add_argument(
325
+ "--video-length",
326
+ type=int,
327
+ default=129,
328
+ help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
329
+ )
330
+ # --- prompt ---
331
+ group.add_argument(
332
+ "--prompt",
333
+ type=str,
334
+ default=None,
335
+ help="Prompt for sampling during evaluation.",
336
+ )
337
+ group.add_argument(
338
+ "--seed-type",
339
+ type=str,
340
+ default="auto",
341
+ choices=["file", "random", "fixed", "auto"],
342
+ help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
343
+ "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
344
+ "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
345
+ "fixed `seed` value.",
346
+ )
347
+ group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
348
+
349
+ # Classifier-Free Guidance
350
+ group.add_argument(
351
+ "--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
352
+ )
353
+ group.add_argument(
354
+ "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
355
+ )
356
+ group.add_argument(
357
+ "--embedded-cfg-scale",
358
+ type=float,
359
+ default=6.0,
360
+ help="Embeded classifier free guidance scale.",
361
+ )
362
+
363
+ group.add_argument(
364
+ "--reproduce",
365
+ action="store_true",
366
+ help="Enable reproducibility by setting random seeds and deterministic algorithms.",
367
+ )
368
+
369
+ return parser
370
+
371
+
372
+ def add_parallel_args(parser: argparse.ArgumentParser):
373
+ group = parser.add_argument_group(title="Parallel args")
374
+
375
+ # ======================== Model loads ========================
376
+ group.add_argument(
377
+ "--ulysses-degree",
378
+ type=int,
379
+ default=1,
380
+ help="Ulysses degree.",
381
+ )
382
+ group.add_argument(
383
+ "--ring-degree",
384
+ type=int,
385
+ default=1,
386
+ help="Ulysses degree.",
387
+ )
388
+
389
+ return parser
390
+
391
+
392
+ def sanity_check_args(args):
393
+ # VAE channels
394
+ vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
395
+ if not re.match(vae_pattern, args.vae):
396
+ raise ValueError(
397
+ f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
398
+ )
399
+ vae_channels = int(args.vae.split("-")[1][:-1])
400
+ if args.latent_channels is None:
401
+ args.latent_channels = vae_channels
402
+ if vae_channels != args.latent_channels:
403
+ raise ValueError(
404
+ f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
405
+ )
406
+ return args
hyvideo/constants.py ADDED
@@ -0,0 +1,90 @@
1
+ import os
2
+ import torch
3
+
4
+ __all__ = [
5
+ "C_SCALE",
6
+ "PROMPT_TEMPLATE",
7
+ "MODEL_BASE",
8
+ "PRECISIONS",
9
+ "NORMALIZATION_TYPE",
10
+ "ACTIVATION_TYPE",
11
+ "VAE_PATH",
12
+ "TEXT_ENCODER_PATH",
13
+ "TOKENIZER_PATH",
14
+ "TEXT_PROJECTION",
15
+ "DATA_TYPE",
16
+ "NEGATIVE_PROMPT",
17
+ ]
18
+
19
+ PRECISION_TO_TYPE = {
20
+ 'fp32': torch.float32,
21
+ 'fp16': torch.float16,
22
+ 'bf16': torch.bfloat16,
23
+ }
24
+
25
+ # =================== Constant Values =====================
26
+ # Computation scale factor: 1 PFLOP = 1_000_000_000_000_000 FLOPs. Values are logged
27
+ # to TensorBoard in PetaFLOPS to avoid overflow errors when logging.
28
+ C_SCALE = 1_000_000_000_000_000
29
+
30
+ # When using decoder-only models, we must provide a prompt template to instruct the text encoder
31
+ # on how to generate the text.
32
+ # --------------------------------------------------------------------
33
+ PROMPT_TEMPLATE_ENCODE = (
34
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
35
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
36
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
37
+ )
38
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
39
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
40
+ "1. The main content and theme of the video."
41
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
42
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
43
+ "4. background environment, light, style and atmosphere."
44
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
45
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
46
+ )
47
+
48
+ NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
49
+
50
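+ # `crop_start` is the number of tokens the template prefix occupies; the text encoder
+ # crops that many leading hidden states so that only the user prompt's features are
+ # passed on. Do not edit a template without recomputing its `crop_start`.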
+ PROMPT_TEMPLATE = {
51
+ "dit-llm-encode": {
52
+ "template": PROMPT_TEMPLATE_ENCODE,
53
+ "crop_start": 36,
54
+ },
55
+ "dit-llm-encode-video": {
56
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
57
+ "crop_start": 95,
58
+ },
59
+ }
60
+
61
+ # ======================= Model ======================
62
+ PRECISIONS = {"fp32", "fp16", "bf16"}
63
+ NORMALIZATION_TYPE = {"layer", "rms"}
64
+ ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
65
+
66
+ # =================== Model Path =====================
67
+ MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
68
+
69
+ # =================== Data =======================
70
+ DATA_TYPE = {"image", "video", "image_video"}
71
+
72
+ # 3D VAE
73
+ VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
74
+
75
+ # Text Encoder
76
+ TEXT_ENCODER_PATH = {
77
+ "clipL": f"{MODEL_BASE}/text_encoder_2",
78
+ "llm": f"{MODEL_BASE}/text_encoder",
79
+ }
80
+
81
+ # Tokenizer
82
+ TOKENIZER_PATH = {
83
+ "clipL": f"{MODEL_BASE}/text_encoder_2",
84
+ "llm": f"{MODEL_BASE}/text_encoder",
85
+ }
86
+
87
+ TEXT_PROJECTION = {
88
+ "linear", # Default, an nn.Linear() layer
89
+ "single_refiner", # Single TokenRefiner. Refer to LI-DiT
90
+ }
hyvideo/diffusion/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .pipelines import HunyuanVideoPipeline
2
+ from .schedulers import FlowMatchDiscreteScheduler
hyvideo/diffusion/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .pipeline_hunyuan_video import HunyuanVideoPipeline
hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py ADDED
@@ -0,0 +1,1103 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import torch
22
+ import torch.distributed as dist
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+ from packaging import version
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.configuration_utils import FrozenDict
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.utils import BaseOutput
45
+
46
+ from ...constants import PRECISION_TO_TYPE
47
+ from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
48
+ from ...text_encoder import TextEncoder
49
+ from ...modules import HYVideoDiffusionTransformer
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """"""
54
+
55
+
56
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
57
+ """
58
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
59
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
60
+ """
61
+ std_text = noise_pred_text.std(
62
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
63
+ )
64
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
+ # rescale the results from guidance (fixes overexposure)
66
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
+ noise_cfg = (
69
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
70
+ )
71
+ return noise_cfg
72
+
73
+
74
+ def retrieve_timesteps(
75
+ scheduler,
76
+ num_inference_steps: Optional[int] = None,
77
+ device: Optional[Union[str, torch.device]] = None,
78
+ timesteps: Optional[List[int]] = None,
79
+ sigmas: Optional[List[float]] = None,
80
+ **kwargs,
81
+ ):
82
+ """
83
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
+
86
+ Args:
87
+ scheduler (`SchedulerMixin`):
88
+ The scheduler to get timesteps from.
89
+ num_inference_steps (`int`):
90
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
+ must be `None`.
92
+ device (`str` or `torch.device`, *optional*):
93
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94
+ timesteps (`List[int]`, *optional*):
95
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
+ `num_inference_steps` and `sigmas` must be `None`.
97
+ sigmas (`List[float]`, *optional*):
98
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
+ `num_inference_steps` and `timesteps` must be `None`.
100
+
101
+ Returns:
102
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
+ second element is the number of inference steps.
104
+ """
105
+ if timesteps is not None and sigmas is not None:
106
+ raise ValueError(
107
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
108
+ )
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(
111
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
112
+ )
113
+ if not accepts_timesteps:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" timestep schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ elif sigmas is not None:
122
+ accept_sigmas = "sigmas" in set(
123
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
124
+ )
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
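As a usage sketch, with the `FlowMatchDiscreteScheduler` added later in this commit (the shift value is illustrative):

    from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler

    scheduler = FlowMatchDiscreteScheduler(shift=7.0)
    timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
    assert num_steps == 50 and timesteps.shape[0] == 50
    # that scheduler's `set_timesteps` accepts neither `timesteps` nor `sigmas`,
    # so passing either here would raise the ValueError above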
138
+
139
+ @dataclass
140
+ class HunyuanVideoPipelineOutput(BaseOutput):
141
+ videos: Union[torch.Tensor, np.ndarray]
142
+
143
+
144
+ class HunyuanVideoPipeline(DiffusionPipeline):
145
+ r"""
146
+ Pipeline for text-to-video generation using HunyuanVideo.
147
+
148
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
149
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
150
+
151
+ Args:
152
+ vae ([`AutoencoderKL`]):
153
+ Variational Auto-Encoder (VAE) model to encode and decode videos to and from latent representations.
154
+ text_encoder ([`TextEncoder`]):
155
+ Frozen text-encoder.
156
+ text_encoder_2 ([`TextEncoder`]):
157
+ Frozen text-encoder_2.
158
+ transformer ([`HYVideoDiffusionTransformer`]):
159
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
160
+ scheduler ([`SchedulerMixin`]):
161
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
162
+ """
163
+
164
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
165
+ _optional_components = ["text_encoder_2"]
166
+ _exclude_from_cpu_offload = ["transformer"]
167
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
168
+
169
+ def __init__(
170
+ self,
171
+ vae: AutoencoderKL,
172
+ text_encoder: TextEncoder,
173
+ transformer: HYVideoDiffusionTransformer,
174
+ scheduler: KarrasDiffusionSchedulers,
175
+ text_encoder_2: Optional[TextEncoder] = None,
176
+ progress_bar_config: Dict[str, Any] = None,
177
+ args=None,
178
+ ):
179
+ super().__init__()
180
+
181
+ # ==========================================================================================
182
+ if progress_bar_config is None:
183
+ progress_bar_config = {}
184
+ if not hasattr(self, "_progress_bar_config"):
185
+ self._progress_bar_config = {}
186
+ self._progress_bar_config.update(progress_bar_config)
187
+
188
+ self.args = args
189
+ # ==========================================================================================
190
+
191
+ if (
192
+ hasattr(scheduler.config, "steps_offset")
193
+ and scheduler.config.steps_offset != 1
194
+ ):
195
+ deprecation_message = (
196
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
197
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
198
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
199
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
200
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
201
+ " file"
202
+ )
203
+ deprecate(
204
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
205
+ )
206
+ new_config = dict(scheduler.config)
207
+ new_config["steps_offset"] = 1
208
+ scheduler._internal_dict = FrozenDict(new_config)
209
+
210
+ if (
211
+ hasattr(scheduler.config, "clip_sample")
212
+ and scheduler.config.clip_sample is True
213
+ ):
214
+ deprecation_message = (
215
+ f"The configuration file of this scheduler: {scheduler} sets `clip_sample` to True."
216
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
217
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
218
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
219
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
220
+ )
221
+ deprecate(
222
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
223
+ )
224
+ new_config = dict(scheduler.config)
225
+ new_config["clip_sample"] = False
226
+ scheduler._internal_dict = FrozenDict(new_config)
227
+
228
+ self.register_modules(
229
+ vae=vae,
230
+ text_encoder=text_encoder,
231
+ transformer=transformer,
232
+ scheduler=scheduler,
233
+ text_encoder_2=text_encoder_2,
234
+ )
235
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+
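A quick check of the spatial scale factor, assuming a VAE config with four `block_out_channels` entries (the values below are placeholders); the resulting 8x downsampling is what `check_inputs` later enforces with its `% 8` constraint:

    block_out_channels = [128, 256, 512, 512]  # assumed config values
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    assert vae_scale_factor == 8
    assert 720 // vae_scale_factor == 90 and 1280 // vae_scale_factor == 160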
238
+ def encode_prompt(
239
+ self,
240
+ prompt,
241
+ device,
242
+ num_videos_per_prompt,
243
+ do_classifier_free_guidance,
244
+ negative_prompt=None,
245
+ prompt_embeds: Optional[torch.Tensor] = None,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
248
+ negative_attention_mask: Optional[torch.Tensor] = None,
249
+ lora_scale: Optional[float] = None,
250
+ clip_skip: Optional[int] = None,
251
+ text_encoder: Optional[TextEncoder] = None,
252
+ data_type: Optional[str] = "image",
253
+ ):
254
+ r"""
255
+ Encodes the prompt into text encoder hidden states.
256
+
257
+ Args:
258
+ prompt (`str` or `List[str]`, *optional*):
259
+ prompt to be encoded
260
+ device: (`torch.device`):
261
+ torch device
262
+ num_videos_per_prompt (`int`):
263
+ number of videos that should be generated per prompt
264
+ do_classifier_free_guidance (`bool`):
265
+ whether to use classifier free guidance or not
266
+ negative_prompt (`str` or `List[str]`, *optional*):
267
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
268
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
269
+ less than `1`).
270
+ prompt_embeds (`torch.Tensor`, *optional*):
271
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
272
+ provided, text embeddings will be generated from `prompt` input argument.
273
+ attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for `prompt_embeds`. If not provided, it is taken from the text encoder output.
274
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
275
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
276
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
277
+ argument.
278
+ negative_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for `negative_prompt_embeds`.
279
+ lora_scale (`float`, *optional*):
280
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
281
+ clip_skip (`int`, *optional*):
282
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
283
+ the output of the pre-final layer will be used for computing the prompt embeddings.
284
+ text_encoder (TextEncoder, *optional*):
+ Text encoder to use. Defaults to `self.text_encoder`.
+ data_type (`str`, *optional*):
+ Either "image" or "video"; selects the prompt template used during tokenization.
286
+ """
287
+ if text_encoder is None:
288
+ text_encoder = self.text_encoder
289
+
290
+ # set lora scale so that monkey patched LoRA
291
+ # function of text encoder can correctly access it
292
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
293
+ self._lora_scale = lora_scale
294
+
295
+ # dynamically adjust the LoRA scale
296
+ if not USE_PEFT_BACKEND:
297
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
298
+ else:
299
+ scale_lora_layers(text_encoder.model, lora_scale)
300
+
301
+ if prompt is not None and isinstance(prompt, str):
302
+ batch_size = 1
303
+ elif prompt is not None and isinstance(prompt, list):
304
+ batch_size = len(prompt)
305
+ else:
306
+ batch_size = prompt_embeds.shape[0]
307
+
308
+ if prompt_embeds is None:
309
+ # textual inversion: process multi-vector tokens if necessary
310
+ if isinstance(self, TextualInversionLoaderMixin):
311
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
312
+
313
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
314
+
315
+ if clip_skip is None:
316
+ prompt_outputs = text_encoder.encode(
317
+ text_inputs, data_type=data_type, device=device
318
+ )
319
+ prompt_embeds = prompt_outputs.hidden_state
320
+ else:
321
+ prompt_outputs = text_encoder.encode(
322
+ text_inputs,
323
+ output_hidden_states=True,
324
+ data_type=data_type,
325
+ device=device,
326
+ )
327
+ # Access the `hidden_states` first, that contains a tuple of
328
+ # all the hidden states from the encoder layers. Then index into
329
+ # the tuple to access the hidden states from the desired layer.
330
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
331
+ # We also need to apply the final LayerNorm here to not mess with the
332
+ # representations. The `last_hidden_states` that we typically use for
333
+ # obtaining the final prompt representations passes through the LayerNorm
334
+ # layer.
335
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
336
+ prompt_embeds
337
+ )
338
+
339
+ attention_mask = prompt_outputs.attention_mask
340
+ if attention_mask is not None:
341
+ attention_mask = attention_mask.to(device)
342
+ bs_embed, seq_len = attention_mask.shape
343
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
344
+ attention_mask = attention_mask.view(
345
+ bs_embed * num_videos_per_prompt, seq_len
346
+ )
347
+
348
+ if text_encoder is not None:
349
+ prompt_embeds_dtype = text_encoder.dtype
350
+ elif self.transformer is not None:
351
+ prompt_embeds_dtype = self.transformer.dtype
352
+ else:
353
+ prompt_embeds_dtype = prompt_embeds.dtype
354
+
355
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
356
+
357
+ if prompt_embeds.ndim == 2:
358
+ bs_embed, _ = prompt_embeds.shape
359
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
360
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
361
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
362
+ else:
363
+ bs_embed, seq_len, _ = prompt_embeds.shape
364
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
365
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
366
+ prompt_embeds = prompt_embeds.view(
367
+ bs_embed * num_videos_per_prompt, seq_len, -1
368
+ )
369
+
370
+ # get unconditional embeddings for classifier free guidance
371
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
372
+ uncond_tokens: List[str]
373
+ if negative_prompt is None:
374
+ uncond_tokens = [""] * batch_size
375
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
376
+ raise TypeError(
377
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
378
+ f" {type(prompt)}."
379
+ )
380
+ elif isinstance(negative_prompt, str):
381
+ uncond_tokens = [negative_prompt]
382
+ elif batch_size != len(negative_prompt):
383
+ raise ValueError(
384
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
385
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
386
+ " the batch size of `prompt`."
387
+ )
388
+ else:
389
+ uncond_tokens = negative_prompt
390
+
391
+ # textual inversion: process multi-vector tokens if necessary
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ uncond_tokens = self.maybe_convert_prompt(
394
+ uncond_tokens, text_encoder.tokenizer
395
+ )
396
+
397
+ # max_length = prompt_embeds.shape[1]
398
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
399
+
400
+ negative_prompt_outputs = text_encoder.encode(
401
+ uncond_input, data_type=data_type, device=device
402
+ )
403
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
404
+
405
+ negative_attention_mask = negative_prompt_outputs.attention_mask
406
+ if negative_attention_mask is not None:
407
+ negative_attention_mask = negative_attention_mask.to(device)
408
+ _, seq_len = negative_attention_mask.shape
409
+ negative_attention_mask = negative_attention_mask.repeat(
410
+ 1, num_videos_per_prompt
411
+ )
412
+ negative_attention_mask = negative_attention_mask.view(
413
+ batch_size * num_videos_per_prompt, seq_len
414
+ )
415
+
416
+ if do_classifier_free_guidance:
417
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
418
+ seq_len = negative_prompt_embeds.shape[1]
419
+
420
+ negative_prompt_embeds = negative_prompt_embeds.to(
421
+ dtype=prompt_embeds_dtype, device=device
422
+ )
423
+
424
+ if negative_prompt_embeds.ndim == 2:
425
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
426
+ 1, num_videos_per_prompt
427
+ )
428
+ negative_prompt_embeds = negative_prompt_embeds.view(
429
+ batch_size * num_videos_per_prompt, -1
430
+ )
431
+ else:
432
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
433
+ 1, num_videos_per_prompt, 1
434
+ )
435
+ negative_prompt_embeds = negative_prompt_embeds.view(
436
+ batch_size * num_videos_per_prompt, seq_len, -1
437
+ )
438
+
439
+ if text_encoder is not None:
440
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
441
+ # Retrieve the original scale by scaling back the LoRA layers
442
+ unscale_lora_layers(text_encoder.model, lora_scale)
443
+
444
+ return (
445
+ prompt_embeds,
446
+ negative_prompt_embeds,
447
+ attention_mask,
448
+ negative_attention_mask,
449
+ )
450
+
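A shape-level sketch of the embeddings this method returns under classifier-free guidance, assuming one prompt, `num_videos_per_prompt=2`, and the 256-token / 4096-dim layout annotated in the denoising loop below:

    import torch

    hidden = 4096  # hidden size of the primary text encoder (from the shape comments below)
    prompt_embeds = torch.randn(2, 256, hidden)           # 1 prompt x 2 videos
    negative_prompt_embeds = torch.randn(2, 256, hidden)
    # __call__ concatenates [negative, positive] along dim 0 for a single CFG pass
    cfg_batch = torch.cat([negative_prompt_embeds, prompt_embeds])
    assert cfg_batch.shape == (4, 256, hidden)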
451
+ def decode_latents(self, latents, enable_tiling=True):
452
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
453
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
454
+
455
+ latents = 1 / self.vae.config.scaling_factor * latents
456
+ if enable_tiling:
+ self.vae.enable_tiling()
+ image = self.vae.decode(latents, return_dict=False)[0]
461
+ image = (image / 2 + 0.5).clamp(0, 1)
462
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463
+ if image.ndim == 4:
464
+ image = image.cpu().permute(0, 2, 3, 1).float()
465
+ else:
466
+ image = image.cpu().float()
467
+ return image
468
+
469
+ def prepare_extra_func_kwargs(self, func, kwargs):
470
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
471
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
472
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
473
+ # and should be between [0, 1]
474
+ extra_step_kwargs = {}
475
+
476
+ for k, v in kwargs.items():
477
+ accepts = k in set(inspect.signature(func).parameters.keys())
478
+ if accepts:
479
+ extra_step_kwargs[k] = v
480
+ return extra_step_kwargs
481
+
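A small demonstration of the signature filtering, using a stand-in function (`toy_step` is hypothetical; the method never touches `self`, so it can be exercised unbound):

    def toy_step(sample, generator=None):
        return sample

    extra = HunyuanVideoPipeline.prepare_extra_func_kwargs(
        None, toy_step, {"generator": "g", "eta": 0.0}
    )
    assert extra == {"generator": "g"}  # `eta` is dropped: toy_step has no such parameter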
482
+ def check_inputs(
483
+ self,
484
+ prompt,
485
+ height,
486
+ width,
487
+ video_length,
488
+ callback_steps,
489
+ negative_prompt=None,
490
+ prompt_embeds=None,
491
+ negative_prompt_embeds=None,
492
+ callback_on_step_end_tensor_inputs=None,
493
+ vae_ver="88-4c-sd",
494
+ ):
495
+ if height % 8 != 0 or width % 8 != 0:
496
+ raise ValueError(
497
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
498
+ )
499
+
500
+ if video_length is not None:
501
+ if "884" in vae_ver:
502
+ if video_length != 1 and (video_length - 1) % 4 != 0:
503
+ raise ValueError(
504
+ f"`video_length` has to be 1 or of the form 4k+1 but is {video_length}."
505
+ )
506
+ elif "888" in vae_ver:
507
+ if video_length != 1 and (video_length - 1) % 8 != 0:
508
+ raise ValueError(
509
+ f"`video_length` has to be 1 or of the form 8k+1 but is {video_length}."
510
+ )
511
+
512
+ if callback_steps is not None and (
513
+ not isinstance(callback_steps, int) or callback_steps <= 0
514
+ ):
515
+ raise ValueError(
516
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
517
+ f" {type(callback_steps)}."
518
+ )
519
+ if callback_on_step_end_tensor_inputs is not None and not all(
520
+ k in self._callback_tensor_inputs
521
+ for k in callback_on_step_end_tensor_inputs
522
+ ):
523
+ raise ValueError(
524
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
525
+ )
526
+
527
+ if prompt is not None and prompt_embeds is not None:
528
+ raise ValueError(
529
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
530
+ " only forward one of the two."
531
+ )
532
+ elif prompt is None and prompt_embeds is None:
533
+ raise ValueError(
534
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
535
+ )
536
+ elif prompt is not None and (
537
+ not isinstance(prompt, str) and not isinstance(prompt, list)
538
+ ):
539
+ raise ValueError(
540
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
541
+ )
542
+
543
+ if negative_prompt is not None and negative_prompt_embeds is not None:
544
+ raise ValueError(
545
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
546
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
547
+ )
548
+
549
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
550
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
551
+ raise ValueError(
552
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
553
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
554
+ f" {negative_prompt_embeds.shape}."
555
+ )
556
+
557
+
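The "884" branch encodes that VAE's 4x temporal compression: valid frame counts are 1 or of the form 4k + 1. A worked example using the figures annotated in the denoising loop below:

    video_length = 129                  # 129 = 4 * 32 + 1, so the check passes
    assert video_length == 1 or (video_length - 1) % 4 == 0
    latent_frames = (video_length - 1) // 4 + 1
    assert latent_frames == 33          # matches the [2, 16, 33, 24, 42] latent shape comment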
558
+ def prepare_latents(
559
+ self,
560
+ batch_size,
561
+ num_channels_latents,
562
+ height,
563
+ width,
564
+ video_length,
565
+ dtype,
566
+ device,
567
+ generator,
568
+ latents=None,
569
+ ):
570
+ shape = (
571
+ batch_size,
572
+ num_channels_latents,
573
+ video_length,
574
+ int(height) // self.vae_scale_factor,
575
+ int(width) // self.vae_scale_factor,
576
+ )
577
+ if isinstance(generator, list) and len(generator) != batch_size:
578
+ raise ValueError(
579
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
580
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
581
+ )
582
+
583
+ if latents is None:
584
+ latents = randn_tensor(
585
+ shape, generator=generator, device=device, dtype=dtype
586
+ )
587
+ else:
588
+ latents = latents.to(device)
589
+
590
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
591
+ if hasattr(self.scheduler, "init_noise_sigma"):
592
+ # scale the initial noise by the standard deviation required by the scheduler
593
+ latents = latents * self.scheduler.init_noise_sigma
594
+ return latents
595
+
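Putting the pieces together, the latent shape for the example annotated in the denoising loop (192x336 pixels, 33 latent frames, 16 latent channels, batch doubled to 2 by CFG):

    batch, channels, scale = 2, 16, 8
    height, width, latent_frames = 192, 336, 33
    shape = (batch, channels, latent_frames, height // scale, width // scale)
    assert shape == (2, 16, 33, 24, 42)  # the `latent_model_input` shape commented below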
596
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
597
+ def get_guidance_scale_embedding(
598
+ self,
599
+ w: torch.Tensor,
600
+ embedding_dim: int = 512,
601
+ dtype: torch.dtype = torch.float32,
602
+ ) -> torch.Tensor:
603
+ """
604
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
605
+
606
+ Args:
607
+ w (`torch.Tensor`):
608
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
609
+ embedding_dim (`int`, *optional*, defaults to 512):
610
+ Dimension of the embeddings to generate.
611
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
612
+ Data type of the generated embeddings.
613
+
614
+ Returns:
615
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
616
+ """
617
+ assert len(w.shape) == 1
618
+ w = w * 1000.0
619
+
620
+ half_dim = embedding_dim // 2
621
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
622
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
623
+ emb = w.to(dtype)[:, None] * emb[None, :]
624
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
625
+ if embedding_dim % 2 == 1: # zero pad
626
+ emb = torch.nn.functional.pad(emb, (0, 1))
627
+ assert emb.shape == (w.shape[0], embedding_dim)
628
+ return emb
629
+
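A shape check for the sinusoidal embedding, with a toy batch of two guidance weights (the method does not use `self`, so it can be called unbound):

    import torch

    w = torch.tensor([5.0, 7.5])
    emb = HunyuanVideoPipeline.get_guidance_scale_embedding(None, w, embedding_dim=512)
    assert emb.shape == (2, 512)  # first half sin terms, second half cos terms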
630
+ @property
631
+ def guidance_scale(self):
632
+ return self._guidance_scale
633
+
634
+ @property
635
+ def guidance_rescale(self):
636
+ return self._guidance_rescale
637
+
638
+ @property
639
+ def clip_skip(self):
640
+ return self._clip_skip
641
+
642
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
643
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
644
+ # corresponds to doing no classifier free guidance.
645
+ @property
646
+ def do_classifier_free_guidance(self):
647
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
648
+ return self._guidance_scale > 1
649
+
650
+ @property
651
+ def cross_attention_kwargs(self):
652
+ return self._cross_attention_kwargs
653
+
654
+ @property
655
+ def num_timesteps(self):
656
+ return self._num_timesteps
657
+
658
+ @property
659
+ def interrupt(self):
660
+ return self._interrupt
661
+
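For reference, the combination applied in the denoising loop below, with toy tensors and the eq. (2) weight `w = guidance_scale`:

    import torch

    w = 7.5
    uncond, text = torch.randn(1, 4), torch.randn(1, 4)
    guided = uncond + w * (text - uncond)  # w = 1 reduces this to the conditional prediction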
662
+ @torch.no_grad()
663
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
664
+ def __call__(
665
+ self,
666
+ prompt: Union[str, List[str]],
667
+ height: int,
668
+ width: int,
669
+ video_length: int,
670
+ data_type: str = "video",
671
+ num_inference_steps: int = 50,
672
+ timesteps: List[int] = None,
673
+ sigmas: List[float] = None,
674
+ guidance_scale: float = 7.5,
675
+ negative_prompt: Optional[Union[str, List[str]]] = None,
676
+ num_videos_per_prompt: Optional[int] = 1,
677
+ eta: float = 0.0,
678
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
679
+ latents: Optional[torch.Tensor] = None,
680
+ prompt_embeds: Optional[torch.Tensor] = None,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
683
+ negative_attention_mask: Optional[torch.Tensor] = None,
684
+ output_type: Optional[str] = "pil",
685
+ return_dict: bool = True,
686
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
687
+ guidance_rescale: float = 0.0,
688
+ clip_skip: Optional[int] = None,
689
+ callback_on_step_end: Optional[
690
+ Union[
691
+ Callable[[int, int, Dict], None],
692
+ PipelineCallback,
693
+ MultiPipelineCallbacks,
694
+ ]
695
+ ] = None,
696
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
698
+ vae_ver: str = "88-4c-sd",
699
+ enable_tiling: bool = False,
700
+ n_tokens: Optional[int] = None,
701
+ embedded_guidance_scale: Optional[float] = None,
702
+ **kwargs,
703
+ ):
704
+ r"""
705
+ The call function to the pipeline for generation.
706
+
707
+ Args:
708
+ prompt (`str` or `List[str]`):
709
+ The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`):
+ The height in pixels of the generated video.
+ width (`int`):
+ The width in pixels of the generated video.
714
+ video_length (`int`):
715
+ The number of frames in the generated video.
716
+ num_inference_steps (`int`, *optional*, defaults to 50):
717
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
+ expense of slower inference.
719
+ timesteps (`List[int]`, *optional*):
720
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
721
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
722
+ passed will be used. Must be in descending order.
723
+ sigmas (`List[float]`, *optional*):
724
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
725
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
726
+ will be used.
727
+ guidance_scale (`float`, *optional*, defaults to 7.5):
728
+ A higher guidance scale value encourages the model to generate images closely linked to the text
729
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
730
+ negative_prompt (`str` or `List[str]`, *optional*):
731
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
732
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
733
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
+ The number of videos to generate per prompt.
735
+ eta (`float`, *optional*, defaults to 0.0):
736
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
737
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
738
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
739
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
740
+ generation deterministic.
741
+ latents (`torch.Tensor`, *optional*):
742
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
743
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
744
+ tensor is generated by sampling using the supplied random `generator`.
745
+ prompt_embeds (`torch.Tensor`, *optional*):
746
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
747
+ provided, text embeddings are generated from the `prompt` input argument.
748
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
749
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
750
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
751
+
752
+ output_type (`str`, *optional*, defaults to `"pil"`):
753
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
754
+ return_dict (`bool`, *optional*, defaults to `True`):
755
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
756
+ plain tuple.
757
+ cross_attention_kwargs (`dict`, *optional*):
758
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
759
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
760
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
761
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
762
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
763
+ using zero terminal SNR.
764
+ clip_skip (`int`, *optional*):
765
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
766
+ the output of the pre-final layer will be used for computing the prompt embeddings.
767
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
770
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
773
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
+ `._callback_tensor_inputs` attribute of your pipeline class.
776
+
777
+ Examples:
778
+
779
+ Returns:
780
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
782
+ otherwise the generated video tensor is returned directly.
785
+ """
786
+ callback = kwargs.pop("callback", None)
787
+ callback_steps = kwargs.pop("callback_steps", None)
788
+
789
+ if callback is not None:
790
+ deprecate(
791
+ "callback",
792
+ "1.0.0",
793
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
794
+ )
795
+ if callback_steps is not None:
796
+ deprecate(
797
+ "callback_steps",
798
+ "1.0.0",
799
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
800
+ )
801
+
802
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
803
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
804
+
805
+ # 0. Default height and width to unet
806
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
807
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
808
+ # to deal with lora scaling and other possible forward hooks
809
+
810
+ # 1. Check inputs. Raise error if not correct
811
+ self.check_inputs(
812
+ prompt,
813
+ height,
814
+ width,
815
+ video_length,
816
+ callback_steps,
817
+ negative_prompt,
818
+ prompt_embeds,
819
+ negative_prompt_embeds,
820
+ callback_on_step_end_tensor_inputs,
821
+ vae_ver=vae_ver,
822
+ )
823
+
824
+ self._guidance_scale = guidance_scale
825
+ self._guidance_rescale = guidance_rescale
826
+ self._clip_skip = clip_skip
827
+ self._cross_attention_kwargs = cross_attention_kwargs
828
+ self._interrupt = False
829
+
830
+ # 2. Define call parameters
831
+ if prompt is not None and isinstance(prompt, str):
832
+ batch_size = 1
833
+ elif prompt is not None and isinstance(prompt, list):
834
+ batch_size = len(prompt)
835
+ else:
836
+ batch_size = prompt_embeds.shape[0]
837
+
838
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
839
+
840
+ # 3. Encode input prompt
841
+ lora_scale = (
842
+ self.cross_attention_kwargs.get("scale", None)
843
+ if self.cross_attention_kwargs is not None
844
+ else None
845
+ )
846
+
847
+ (
848
+ prompt_embeds,
849
+ negative_prompt_embeds,
850
+ prompt_mask,
851
+ negative_prompt_mask,
852
+ ) = self.encode_prompt(
853
+ prompt,
854
+ device,
855
+ num_videos_per_prompt,
856
+ self.do_classifier_free_guidance,
857
+ negative_prompt,
858
+ prompt_embeds=prompt_embeds,
859
+ attention_mask=attention_mask,
860
+ negative_prompt_embeds=negative_prompt_embeds,
861
+ negative_attention_mask=negative_attention_mask,
862
+ lora_scale=lora_scale,
863
+ clip_skip=self.clip_skip,
864
+ data_type=data_type,
865
+ )
866
+ if self.text_encoder_2 is not None:
867
+ (
868
+ prompt_embeds_2,
869
+ negative_prompt_embeds_2,
870
+ prompt_mask_2,
871
+ negative_prompt_mask_2,
872
+ ) = self.encode_prompt(
873
+ prompt,
874
+ device,
875
+ num_videos_per_prompt,
876
+ self.do_classifier_free_guidance,
877
+ negative_prompt,
878
+ prompt_embeds=None,
879
+ attention_mask=None,
880
+ negative_prompt_embeds=None,
881
+ negative_attention_mask=None,
882
+ lora_scale=lora_scale,
883
+ clip_skip=self.clip_skip,
884
+ text_encoder=self.text_encoder_2,
885
+ data_type=data_type,
886
+ )
887
+ else:
888
+ prompt_embeds_2 = None
889
+ negative_prompt_embeds_2 = None
890
+ prompt_mask_2 = None
891
+ negative_prompt_mask_2 = None
892
+
893
+ # For classifier free guidance, we need to do two forward passes.
894
+ # Here we concatenate the unconditional and text embeddings into a single batch
895
+ # to avoid doing two forward passes
896
+ if self.do_classifier_free_guidance:
897
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
898
+ if prompt_mask is not None:
899
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
900
+ if prompt_embeds_2 is not None:
901
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
902
+ if prompt_mask_2 is not None:
903
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
904
+
905
+
906
+ # 4. Prepare timesteps
907
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
908
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
909
+ )
910
+ timesteps, num_inference_steps = retrieve_timesteps(
911
+ self.scheduler,
912
+ num_inference_steps,
913
+ device,
914
+ timesteps,
915
+ sigmas,
916
+ **extra_set_timesteps_kwargs,
917
+ )
918
+
919
+ if "884" in vae_ver:
920
+ video_length = (video_length - 1) // 4 + 1
921
+ elif "888" in vae_ver:
922
+ video_length = (video_length - 1) // 8 + 1
923
+ # other VAE versions perform no temporal compression, so the frame count is unchanged
925
+
926
+ # 5. Prepare latent variables
927
+ num_channels_latents = self.transformer.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_videos_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ video_length,
934
+ prompt_embeds.dtype,
935
+ device,
936
+ generator,
937
+ latents,
938
+ )
939
+
940
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
941
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
942
+ self.scheduler.step,
943
+ {"generator": generator, "eta": eta},
944
+ )
945
+
946
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
947
+ autocast_enabled = (
948
+ target_dtype != torch.float32
949
+ ) and not self.args.disable_autocast
950
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
951
+ vae_autocast_enabled = (
952
+ vae_dtype != torch.float32
953
+ ) and not self.args.disable_autocast
954
+
955
+ # 7. Denoising loop
956
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
+ self._num_timesteps = len(timesteps)
958
+
959
+ # if is_progress_bar:
960
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
961
+ for i, t in enumerate(timesteps):
962
+ if self.interrupt:
963
+ continue
964
+ # Allow an external abort signal: if a file named "abort" exists in the
+ # working directory, skip the remaining denoising steps.
+ import os
+ if os.path.isfile("abort"):
+ continue
967
+
968
+ # expand the latents if we are doing classifier free guidance
969
+ latent_model_input = (
970
+ torch.cat([latents] * 2)
971
+ if self.do_classifier_free_guidance
972
+ else latents
973
+ )
974
+ latent_model_input = self.scheduler.scale_model_input(
975
+ latent_model_input, t
976
+ )
977
+
978
+ t_expand = t.repeat(latent_model_input.shape[0])
979
+ guidance_expand = (
980
+ torch.tensor(
981
+ [embedded_guidance_scale] * latent_model_input.shape[0],
982
+ dtype=torch.float32,
983
+ device=device,
984
+ ).to(target_dtype)
985
+ * 1000.0
986
+ if embedded_guidance_scale is not None
987
+ else None
988
+ )
989
+
990
+ # predict the noise residual
991
+ with torch.autocast(
992
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
993
+ ):
994
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
995
+ latent_model_input, # [2, 16, 33, 24, 42]
996
+ t_expand, # [2]
997
+ text_states=prompt_embeds, # [2, 256, 4096]
998
+ text_mask=prompt_mask, # [2, 256]
999
+ text_states_2=prompt_embeds_2, # [2, 768]
1000
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
1001
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
1002
+ guidance=guidance_expand,
1003
+ return_dict=True,
1004
+ )[
1005
+ "x"
1006
+ ]
1007
+
1008
+ # perform guidance
1009
+ if self.do_classifier_free_guidance:
1010
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1011
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1012
+ noise_pred_text - noise_pred_uncond
1013
+ )
1014
+
1015
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1016
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1017
+ noise_pred = rescale_noise_cfg(
1018
+ noise_pred,
1019
+ noise_pred_text,
1020
+ guidance_rescale=self.guidance_rescale,
1021
+ )
1022
+
1023
+ # compute the previous noisy sample x_t -> x_t-1
1024
+ latents = self.scheduler.step(
1025
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1026
+ )[0]
1027
+
1028
+ if callback_on_step_end is not None:
1029
+ callback_kwargs = {}
1030
+ for k in callback_on_step_end_tensor_inputs:
1031
+ callback_kwargs[k] = locals()[k]
1032
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1033
+
1034
+ latents = callback_outputs.pop("latents", latents)
1035
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1036
+ negative_prompt_embeds = callback_outputs.pop(
1037
+ "negative_prompt_embeds", negative_prompt_embeds
1038
+ )
1039
+
1040
+ # call the callback, if provided
1041
+ if i == len(timesteps) - 1 or (
1042
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1043
+ ):
1044
+ if progress_bar is not None:
1045
+ progress_bar.update()
1046
+ if callback is not None and i % callback_steps == 0:
1047
+ step_idx = i // getattr(self.scheduler, "order", 1)
1048
+ callback(step_idx, t, latents)
1049
+
1050
+ if output_type != "latent":
1051
+ expand_temporal_dim = False
1052
+ if len(latents.shape) == 4:
1053
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1054
+ latents = latents.unsqueeze(2)
1055
+ expand_temporal_dim = True
1056
+ elif len(latents.shape) == 5:
1057
+ pass
1058
+ else:
1059
+ raise ValueError(
1060
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1061
+ )
1062
+
1063
+ if (
1064
+ hasattr(self.vae.config, "shift_factor")
1065
+ and self.vae.config.shift_factor
1066
+ ):
1067
+ latents = (
1068
+ latents / self.vae.config.scaling_factor
1069
+ + self.vae.config.shift_factor
1070
+ )
1071
+ else:
1072
+ latents = latents / self.vae.config.scaling_factor
1073
+
1074
+ with torch.autocast(
1075
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1076
+ ):
1077
+ if enable_tiling:
+ self.vae.enable_tiling()
+ image = self.vae.decode(
+ latents, return_dict=False, generator=generator
+ )[0]
1086
+
1087
+ if expand_temporal_dim or image.shape[2] == 1:
1088
+ image = image.squeeze(2)
1089
+
1090
+ else:
1091
+ image = latents
1092
+
1093
+ image = (image / 2 + 0.5).clamp(0, 1)
1094
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1095
+ image = image.cpu().float()
1096
+
1097
+ # Offload all models
1098
+ self.maybe_free_model_hooks()
1099
+
1100
+ if not return_dict:
1101
+ return image
1102
+
1103
+ return HunyuanVideoPipelineOutput(videos=image)
hyvideo/diffusion/schedulers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py ADDED
@@ -0,0 +1,257 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images or `(batch_size, num_channels, num_frames, height, width)` for videos):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+
47
+
48
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
+ """
50
+ Discrete flow-matching scheduler with a first-order (Euler) solver.
51
+
52
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
+ methods the library implements for all schedulers such as loading and saving.
54
+
55
+ Args:
56
+ num_train_timesteps (`int`, defaults to 1000):
57
+ The number of diffusion steps to train the model.
58
+ timestep_spacing (`str`, defaults to `"linspace"`):
59
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61
+ shift (`float`, defaults to 1.0):
62
+ The shift value for the timestep schedule.
63
+ reverse (`bool`, defaults to `True`):
64
+ Whether to reverse the timestep schedule.
65
+ """
66
+
67
+ _compatibles = []
68
+ order = 1
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ num_train_timesteps: int = 1000,
74
+ shift: float = 1.0,
75
+ reverse: bool = True,
76
+ solver: str = "euler",
77
+ n_tokens: Optional[int] = None,
78
+ ):
79
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80
+
81
+ if not reverse:
82
+ sigmas = sigmas.flip(0)
83
+
84
+ self.sigmas = sigmas
85
+ # the value fed to model
86
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87
+
88
+ self._step_index = None
89
+ self._begin_index = None
90
+
91
+ self.supported_solver = ["euler"]
92
+ if solver not in self.supported_solver:
93
+ raise ValueError(
94
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95
+ )
96
+
97
+ @property
98
+ def step_index(self):
99
+ """
100
+ The index counter for current timestep. It will increase 1 after each scheduler step.
101
+ """
102
+ return self._step_index
103
+
104
+ @property
105
+ def begin_index(self):
106
+ """
107
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108
+ """
109
+ return self._begin_index
110
+
111
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112
+ def set_begin_index(self, begin_index: int = 0):
113
+ """
114
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115
+
116
+ Args:
117
+ begin_index (`int`):
118
+ The begin index for the scheduler.
119
+ """
120
+ self._begin_index = begin_index
121
+
122
+ def _sigma_to_t(self, sigma):
123
+ return sigma * self.config.num_train_timesteps
124
+
125
+ def set_timesteps(
126
+ self,
127
+ num_inference_steps: int,
128
+ device: Union[str, torch.device] = None,
129
+ n_tokens: int = None,
130
+ ):
131
+ """
132
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133
+
134
+ Args:
135
+ num_inference_steps (`int`):
136
+ The number of diffusion steps used when generating samples with a pre-trained model.
137
+ device (`str` or `torch.device`, *optional*):
138
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139
+ n_tokens (`int`, *optional*):
140
+ Number of tokens in the input sequence (accepted for API compatibility; unused by this scheduler).
141
+ """
142
+ self.num_inference_steps = num_inference_steps
143
+
144
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145
+ sigmas = self.sd3_time_shift(sigmas)
146
+
147
+ if not self.config.reverse:
148
+ sigmas = 1 - sigmas
149
+
150
+ self.sigmas = sigmas
151
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152
+ dtype=torch.float32, device=device
153
+ )
154
+
155
+ # Reset step index
156
+ self._step_index = None
157
+
158
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
159
+ if schedule_timesteps is None:
160
+ schedule_timesteps = self.timesteps
161
+
162
+ indices = (schedule_timesteps == timestep).nonzero()
163
+
164
+ # The sigma index that is taken for the **very** first `step`
165
+ # is always the second index (or the last index if there is only 1)
166
+ # This way we can ensure we don't accidentally skip a sigma in
167
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168
+ pos = 1 if len(indices) > 1 else 0
169
+
170
+ return indices[pos].item()
171
+
172
+ def _init_step_index(self, timestep):
173
+ if self.begin_index is None:
174
+ if isinstance(timestep, torch.Tensor):
175
+ timestep = timestep.to(self.timesteps.device)
176
+ self._step_index = self.index_for_timestep(timestep)
177
+ else:
178
+ self._step_index = self._begin_index
179
+
180
+ def scale_model_input(
181
+ self, sample: torch.Tensor, timestep: Optional[int] = None
182
+ ) -> torch.Tensor:
183
+ return sample
184
+
185
+ def sd3_time_shift(self, t: torch.Tensor):
186
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187
+
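The shift is an SD3-style warp of the schedule: with `shift > 1`, more of the trajectory is spent at high noise levels. A numeric sketch (the shift value is illustrative):

    import torch

    shift = 7.0
    t = torch.tensor([0.25, 0.50, 0.75])
    shifted = (shift * t) / (1 + (shift - 1) * t)
    # tensor([0.7000, 0.8750, 0.9545]) -- interior points are pushed toward sigma = 1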
188
+ def step(
189
+ self,
190
+ model_output: torch.FloatTensor,
191
+ timestep: Union[float, torch.FloatTensor],
192
+ sample: torch.FloatTensor,
193
+ return_dict: bool = True,
194
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195
+ """
196
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197
+ process from the learned model outputs (most often the predicted noise).
198
+
199
+ Args:
200
+ model_output (`torch.FloatTensor`):
201
+ The direct output from learned diffusion model.
202
+ timestep (`float`):
203
+ The current discrete timestep in the diffusion chain.
204
+ sample (`torch.FloatTensor`):
205
+ A current instance of a sample created by the diffusion process.
210
+ return_dict (`bool`):
211
+ Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
218
+ """
219
+
220
+ if (
221
+ isinstance(timestep, int)
222
+ or isinstance(timestep, torch.IntTensor)
223
+ or isinstance(timestep, torch.LongTensor)
224
+ ):
225
+ raise ValueError(
226
+ (
227
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228
+ " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
229
+ " one of the `scheduler.timesteps` as a timestep."
230
+ ),
231
+ )
232
+
233
+ if self.step_index is None:
234
+ self._init_step_index(timestep)
235
+
236
+ # Upcast to avoid precision issues when computing prev_sample
237
+ sample = sample.to(torch.float32)
238
+
239
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240
+
241
+ if self.config.solver == "euler":
242
+ prev_sample = sample + model_output.to(torch.float32) * dt
243
+ else:
244
+ raise ValueError(
245
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246
+ )
247
+
248
+ # upon completion increase step index by one
249
+ self._step_index += 1
250
+
251
+ if not return_dict:
252
+ return (prev_sample,)
253
+
254
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255
+
256
+ def __len__(self):
257
+ return self.config.num_train_timesteps
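End to end, the scheduler is a plain Euler integration of the learned flow from sigma = 1 (noise) toward sigma = 0 (data). A self-contained sketch with a dummy velocity model (all sizes illustrative):

    import torch

    scheduler = FlowMatchDiscreteScheduler(shift=7.0)
    scheduler.set_timesteps(10, device="cpu")
    sample = torch.randn(1, 16, 33, 24, 42)
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # a real model predicts the flow velocity
        sample = scheduler.step(model_output, t, sample).prev_sample
    # with a zero velocity field the sample is unchanged; a trained model's output
    # transports it along the flow toward the data distribution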
hyvideo/inference.py ADDED
@@ -0,0 +1,682 @@
1
+ import os
2
+ import time
3
+ import random
4
+ import functools
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ from pathlib import Path
8
+ from loguru import logger
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
13
+ from hyvideo.vae import load_vae
14
+ from hyvideo.modules import load_model
15
+ from hyvideo.text_encoder import TextEncoder
16
+ from hyvideo.utils.data_utils import align_to
17
+ from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
18
+ from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
19
+ from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
20
+
21
+ try:
22
+ import xfuser
23
+ from xfuser.core.distributed import (
24
+ get_sequence_parallel_world_size,
25
+ get_sequence_parallel_rank,
26
+ get_sp_group,
27
+ initialize_model_parallel,
28
+ init_distributed_environment
29
+ )
30
+ except ImportError:
31
+ xfuser = None
32
+ get_sequence_parallel_world_size = None
33
+ get_sequence_parallel_rank = None
34
+ get_sp_group = None
35
+ initialize_model_parallel = None
36
+ init_distributed_environment = None
37
+
38
+
39
+ def parallelize_transformer(pipe):
40
+ transformer = pipe.transformer
41
+ original_forward = transformer.forward
42
+
43
+ @functools.wraps(transformer.__class__.forward)
44
+ def new_forward(
45
+ self,
46
+ x: torch.Tensor,
47
+ t: torch.Tensor, # Should be in range(0, 1000).
48
+ text_states: torch.Tensor = None,
49
+ text_mask: torch.Tensor = None, # Currently unused.
50
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
51
+ freqs_cos: Optional[torch.Tensor] = None,
52
+ freqs_sin: Optional[torch.Tensor] = None,
53
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
54
+ return_dict: bool = True,
55
+ ):
56
+ if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
57
+ # try to split x by height
58
+ split_dim = -2
59
+ elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
60
+ # try to split x by width
61
+ split_dim = -1
62
+ else:
63
+ raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({get_sequence_parallel_world_size()}) parts evenly")
64
+
65
+ # patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
66
+ temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2
67
+
68
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=split_dim)[get_sequence_parallel_rank()]
69
+
70
+ dim_thw = freqs_cos.shape[-1]
71
+ freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
72
+ freqs_cos = torch.chunk(freqs_cos, get_sequence_parallel_world_size(), dim=split_dim - 1)[get_sequence_parallel_rank()]
73
+ freqs_cos = freqs_cos.reshape(-1, dim_thw)
74
+ dim_thw = freqs_sin.shape[-1]
75
+ freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
76
+ freqs_sin = torch.chunk(freqs_sin, get_sequence_parallel_world_size(), dim=split_dim - 1)[get_sequence_parallel_rank()]
77
+ freqs_sin = freqs_sin.reshape(-1, dim_thw)
78
+
79
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
80
+
81
+ for block in transformer.double_blocks + transformer.single_blocks:
82
+ block.hybrid_seq_parallel_attn = xFuserLongContextAttention()
83
+
84
+ output = original_forward(
85
+ x,
86
+ t,
87
+ text_states,
88
+ text_mask,
89
+ text_states_2,
90
+ freqs_cos,
91
+ freqs_sin,
92
+ guidance,
93
+ return_dict,
94
+ )
95
+
96
+ return_dict = not isinstance(output, tuple)
97
+ sample = output["x"]
98
+ sample = get_sp_group().all_gather(sample, dim=split_dim)
99
+ output["x"] = sample
100
+ return output
101
+
102
+ new_forward = new_forward.__get__(transformer)
103
+ transformer.forward = new_forward
104
+
105
+
106
+ class Inference(object):
107
+ def __init__(
108
+ self,
109
+ args,
110
+ vae,
111
+ vae_kwargs,
112
+ text_encoder,
113
+ model,
114
+ text_encoder_2=None,
115
+ pipeline=None,
116
+ use_cpu_offload=False,
117
+ device=None,
118
+ logger=None,
119
+ parallel_args=None,
120
+ ):
121
+ self.vae = vae
122
+ self.vae_kwargs = vae_kwargs
123
+
124
+ self.text_encoder = text_encoder
125
+ self.text_encoder_2 = text_encoder_2
126
+
127
+ self.model = model
128
+ self.pipeline = pipeline
129
+ self.use_cpu_offload = use_cpu_offload
130
+
131
+ self.args = args
132
+ self.device = (
133
+ device
134
+ if device is not None
135
+ else "cuda"
136
+ if torch.cuda.is_available()
137
+ else "cpu"
138
+ )
139
+ self.logger = logger
140
+ self.parallel_args = parallel_args
141
+
142
+ @classmethod
143
+ def from_pretrained(cls, pretrained_model_path, text_encoder_path, args, device=None, **kwargs):
144
+ """
145
+ Initialize the Inference pipeline.
146
+
147
+ Args:
148
+ pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
149
+ args (argparse.Namespace): The arguments for the pipeline.
150
+ device (int): The device for inference. Default is 0.
151
+ """
152
+ # ========================================================================
153
+ logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
154
+
155
+ # ==================== Initialize Distributed Environment ================
156
+ if args.ulysses_degree > 1 or args.ring_degree > 1:
157
+ assert xfuser is not None, \
158
+ "Ulysses Attention and Ring Attention requires xfuser package."
159
+
160
+ assert args.use_cpu_offload is False, \
161
+ "Cannot enable use_cpu_offload in the distributed environment."
162
+
163
+ dist.init_process_group("nccl")
164
+
165
+ assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
166
+ "number of GPUs should be equal to ring_degree * ulysses_degree."
167
+
168
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
169
+
170
+ initialize_model_parallel(
171
+ sequence_parallel_degree=dist.get_world_size(),
172
+ ring_degree=args.ring_degree,
173
+ ulysses_degree=args.ulysses_degree,
174
+ )
175
+ device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
176
+ else:
177
+ if device is None:
178
+ device = "cuda" if torch.cuda.is_available() else "cpu"
179
+
180
+ parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
181
+
182
+ # ======================== Get the args path =============================
183
+
184
+ # Disable gradient
185
+ torch.set_grad_enabled(False)
186
+
187
+ # =========================== Build main model ===========================
188
+ logger.info("Building model...")
189
+ pinToMemory = kwargs.pop("pinToMemory")
190
+ partialPinning = kwargs.pop("partialPinning")
191
+ # factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
192
+ factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[args.precision]}
193
+ in_channels = args.latent_channels
194
+ out_channels = args.latent_channels
195
+
196
+ model = load_model(
197
+ args,
198
+ in_channels=in_channels,
199
+ out_channels=out_channels,
200
+ factor_kwargs=factor_kwargs,
201
+ )
202
+ # model = model.to(device)
203
+ # model = Inference.load_state_dict(args, model, pretrained_model_path)
204
+
205
+ logger.info(f"Loading torch model {pretrained_model_path}...")
206
+
207
+ from mmgp import offload
208
+ offload.load_model_data(model, pretrained_model_path, pinToMemory = pinToMemory, partialPinning = partialPinning)
209
+
210
+ model.eval()
211
+
212
+ # ============================= Build extra models ========================
213
+ # VAE
214
+ vae, _, s_ratio, t_ratio = load_vae(
215
+ args.vae,
216
+ args.vae_precision,
217
+ logger=logger,
218
+ device=device if not args.use_cpu_offload else "cpu",
219
+ )
220
+ vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
221
+
222
+ # Text encoder
223
+ if args.prompt_template_video is not None:
224
+ crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
225
+ "crop_start", 0
226
+ )
227
+ elif args.prompt_template is not None:
228
+ crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
229
+ else:
230
+ crop_start = 0
231
+ max_length = args.text_len + crop_start
232
+
233
+ # prompt_template
234
+ prompt_template = (
235
+ PROMPT_TEMPLATE[args.prompt_template]
236
+ if args.prompt_template is not None
237
+ else None
238
+ )
239
+
240
+ # prompt_template_video
241
+ prompt_template_video = (
242
+ PROMPT_TEMPLATE[args.prompt_template_video]
243
+ if args.prompt_template_video is not None
244
+ else None
245
+ )
246
+
247
+ text_encoder = TextEncoder(
248
+ text_encoder_type=args.text_encoder,
249
+ max_length=max_length,
250
+ text_encoder_precision=args.text_encoder_precision,
251
+ tokenizer_type=args.tokenizer,
252
+ prompt_template=prompt_template,
253
+ prompt_template_video=prompt_template_video,
254
+ hidden_state_skip_layer=args.hidden_state_skip_layer,
255
+ apply_final_norm=args.apply_final_norm,
256
+ reproduce=args.reproduce,
257
+ logger=logger,
258
+ device=device if not args.use_cpu_offload else "cpu",
259
+ text_encoder_path = text_encoder_path
260
+ )
261
+ text_encoder_2 = None
262
+ if args.text_encoder_2 is not None:
263
+ text_encoder_2 = TextEncoder(
264
+ text_encoder_type=args.text_encoder_2,
265
+ max_length=args.text_len_2,
266
+ text_encoder_precision=args.text_encoder_precision_2,
267
+ tokenizer_type=args.tokenizer_2,
268
+ reproduce=args.reproduce,
269
+ logger=logger,
270
+ device=device if not args.use_cpu_offload else "cpu",
271
+ )
272
+
273
+ return cls(
274
+ args=args,
275
+ vae=vae,
276
+ vae_kwargs=vae_kwargs,
277
+ text_encoder=text_encoder,
278
+ text_encoder_2=text_encoder_2,
279
+ model=model,
280
+ use_cpu_offload=args.use_cpu_offload,
281
+ device=device,
282
+ logger=logger,
283
+ parallel_args=parallel_args
284
+ )
285
+
286
+ @staticmethod
287
+ def load_state_dict(args, model, pretrained_model_path):
288
+ load_key = args.load_key
289
+ dit_weight = Path(args.dit_weight) if args.dit_weight is not None else None
290
+
291
+ if dit_weight is None:
292
+ model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
293
+ files = list(model_dir.glob("*.pt"))
294
+ if len(files) == 0:
295
+ raise ValueError(f"No model weights found in {model_dir}")
296
+ if files[0].name.startswith("pytorch_model_"):
297
+ model_path = model_dir / f"pytorch_model_{load_key}.pt"
298
+ bare_model = True
299
+ elif any(str(f).endswith("_model_states.pt") for f in files):
300
+ files = [f for f in files if str(f).endswith("_model_states.pt")]
301
+ model_path = files[0]
302
+ if len(files) > 1:
303
+ logger.warning(
304
+ f"Multiple model weights found in {dit_weight}, using {model_path}"
305
+ )
306
+ bare_model = False
307
+ else:
308
+ raise ValueError(
309
+ f"Invalid model path: {dit_weight} with unrecognized weight format: "
310
+ f"{list(map(str, files))}. When given a directory as --dit-weight, only "
311
+ f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
312
+ f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
313
+ f"specific weight file, please provide the full path to the file."
314
+ )
315
+ else:
316
+ if dit_weight.is_dir():
317
+ files = list(dit_weight.glob("*.pt"))
318
+ if len(files) == 0:
319
+ raise ValueError(f"No model weights found in {dit_weight}")
320
+ if files[0].name.startswith("pytorch_model_"):
321
+ model_path = dit_weight / f"pytorch_model_{load_key}.pt"
322
+ bare_model = True
323
+ elif any(str(f).endswith("_model_states.pt") for f in files):
324
+ files = [f for f in files if str(f).endswith("_model_states.pt")]
325
+ model_path = files[0]
326
+ if len(files) > 1:
327
+ logger.warning(
328
+ f"Multiple model weights found in {dit_weight}, using {model_path}"
329
+ )
330
+ bare_model = False
331
+ else:
332
+ raise ValueError(
333
+ f"Invalid model path: {dit_weight} with unrecognized weight format: "
334
+ f"{list(map(str, files))}. When given a directory as --dit-weight, only "
335
+ f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
336
+ f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
337
+ f"specific weight file, please provide the full path to the file."
338
+ )
339
+ elif dit_weight.is_file():
340
+ model_path = dit_weight
341
+ bare_model = "unknown"
342
+ else:
343
+ raise ValueError(f"Invalid model path: {dit_weight}")
344
+
345
+ if not model_path.exists():
346
+ raise ValueError(f"model_path not exists: {model_path}")
347
+ logger.info(f"Loading torch model {model_path}...")
348
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
349
+
350
+ if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
351
+ bare_model = False
352
+ if bare_model is False:
353
+ if load_key in state_dict:
354
+ state_dict = state_dict[load_key]
355
+ else:
356
+ raise KeyError(
357
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
358
+ f"are: {list(state_dict.keys())}."
359
+ )
360
+ model.load_state_dict(state_dict, strict=True, assign=True)
361
+ return model
362
+
363
+ @staticmethod
364
+ def parse_size(size):
365
+ if isinstance(size, int):
366
+ size = [size]
367
+ if not isinstance(size, (list, tuple)):
368
+ raise ValueError(f"Size must be an integer or (height, width), got {size}.")
369
+ if len(size) == 1:
370
+ size = [size[0], size[0]]
371
+ if len(size) != 2:
372
+ raise ValueError(f"Size must be an integer or (height, width), got {size}.")
373
+ return size
374
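For reference, the branches of parse_size above behave as follows (illustrative values):

# parse_size(512)          -> [512, 512]
# parse_size([720])        -> [720, 720]
# parse_size((720, 1280))  -> (720, 1280)
# parse_size("720x1280")   -> ValueError (strings are not parsed)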
+
375
+
376
+ class HunyuanVideoSampler(Inference):
377
+ def __init__(
378
+ self,
379
+ args,
380
+ vae,
381
+ vae_kwargs,
382
+ text_encoder,
383
+ model,
384
+ text_encoder_2=None,
385
+ pipeline=None,
386
+ use_cpu_offload=False,
387
+ device=0,
388
+ logger=None,
389
+ parallel_args=None
390
+ ):
391
+ super().__init__(
392
+ args,
393
+ vae,
394
+ vae_kwargs,
395
+ text_encoder,
396
+ model,
397
+ text_encoder_2=text_encoder_2,
398
+ pipeline=pipeline,
399
+ use_cpu_offload=use_cpu_offload,
400
+ device=device,
401
+ logger=logger,
402
+ parallel_args=parallel_args
403
+ )
404
+
405
+ self.pipeline = self.load_diffusion_pipeline(
406
+ args=args,
407
+ vae=self.vae,
408
+ text_encoder=self.text_encoder,
409
+ text_encoder_2=self.text_encoder_2,
410
+ model=self.model,
411
+ device=self.device,
412
+ )
413
+
414
+ self.default_negative_prompt = NEGATIVE_PROMPT
415
+
416
+ def load_diffusion_pipeline(
417
+ self,
418
+ args,
419
+ vae,
420
+ text_encoder,
421
+ text_encoder_2,
422
+ model,
423
+ scheduler=None,
424
+ device=None,
425
+ progress_bar_config=None,
426
+ data_type="video",
427
+ ):
428
+ """Load the denoising scheduler for inference."""
429
+ if scheduler is None:
430
+ if args.denoise_type == "flow":
431
+ scheduler = FlowMatchDiscreteScheduler(
432
+ shift=args.flow_shift,
433
+ reverse=args.flow_reverse,
434
+ solver=args.flow_solver,
435
+ )
436
+ else:
437
+ raise ValueError(f"Invalid denoise type {args.denoise_type}")
438
+
439
+ pipeline = HunyuanVideoPipeline(
440
+ vae=vae,
441
+ text_encoder=text_encoder,
442
+ text_encoder_2=text_encoder_2,
443
+ transformer=model,
444
+ scheduler=scheduler,
445
+ progress_bar_config=progress_bar_config,
446
+ args=args,
447
+ )
448
+ # if self.use_cpu_offload:
449
+ # pipeline.enable_sequential_cpu_offload()
450
+ # else:
451
+ # pipeline = pipeline.to(device)
452
+
453
+ return pipeline
454
+
455
+ def get_rotary_pos_embed(self, video_length, height, width):
456
+ target_ndim = 3
457
+ ndim = 5 - 2 # (B, C, T, H, W) minus the batch and channel dims
458
+ # 884
459
+ if "884" in self.args.vae:
460
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
461
+ elif "888" in self.args.vae:
462
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
463
+ else:
464
+ latents_size = [video_length, height // 8, width // 8]
465
+
466
+ if isinstance(self.model.patch_size, int):
467
+ assert all(s % self.model.patch_size == 0 for s in latents_size), (
468
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
469
+ f"but got {latents_size}."
470
+ )
471
+ rope_sizes = [s // self.model.patch_size for s in latents_size]
472
+ elif isinstance(self.model.patch_size, list):
473
+ assert all(
474
+ s % self.model.patch_size[idx] == 0
475
+ for idx, s in enumerate(latents_size)
476
+ ), (
477
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
478
+ f"but got {latents_size}."
479
+ )
480
+ rope_sizes = [
481
+ s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
482
+ ]
483
+
484
+ if len(rope_sizes) != target_ndim:
485
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
486
+ head_dim = self.model.hidden_size // self.model.heads_num
487
+ rope_dim_list = self.model.rope_dim_list
488
+ if rope_dim_list is None:
489
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
490
+ assert (
491
+ sum(rope_dim_list) == head_dim
492
+ ), "sum(rope_dim_list) should equal to head_dim of attention layer"
493
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
494
+ rope_dim_list,
495
+ rope_sizes,
496
+ theta=self.args.rope_theta,
497
+ use_real=True,
498
+ theta_rescale_factor=1,
499
+ )
500
+ return freqs_cos, freqs_sin
501
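As a worked example of the sizing above (constants follow the defaults in this repo; the concrete resolution is illustrative):

# "884" VAE: 4x temporal and 8x spatial compression, patch_size = [1, 2, 2]
video_length, height, width = 129, 720, 1280
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]  # [33, 90, 160]
rope_sizes = [s // p for s, p in zip(latents_size, [1, 2, 2])]          # [33, 45, 80]
head_dim = 3072 // 24                                                   # 128
assert sum([16, 56, 56]) == head_dim  # rope_dim_list covers t, h, w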
+
502
+ @torch.no_grad()
503
+ def predict(
504
+ self,
505
+ prompt,
506
+ height=192,
507
+ width=336,
508
+ video_length=129,
509
+ seed=None,
510
+ negative_prompt=None,
511
+ infer_steps=50,
512
+ guidance_scale=6,
513
+ flow_shift=5.0,
514
+ embedded_guidance_scale=None,
515
+ batch_size=1,
516
+ num_videos_per_prompt=1,
517
+ **kwargs,
518
+ ):
519
+ """
520
+ Predict the image/video from the given text.
521
+
522
+ Args:
523
+ prompt (str): The input text.
524
+ kwargs:
525
+ height (int): The height of the output video. Default is 192.
526
+ width (int): The width of the output video. Default is 336.
527
+ video_length (int): The frame number of the output video. Default is 129.
528
+ seed (int or List[int]): The random seed for the generation. Default is a random integer.
529
+ negative_prompt (str): The negative text prompt. Defaults to the model's built-in negative prompt.
530
+ guidance_scale (float): The guidance scale for the generation. Default is 6.0.
531
+ num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
532
+ infer_steps (int): The number of inference steps. Default is 50.
533
+ """
534
+ if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
535
+ assert seed is not None, \
536
+ "You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
537
+
538
+ parallelize_transformer(self.pipeline)
539
+
540
+ out_dict = dict()
541
+
542
+ # ========================================================================
543
+ # Arguments: seed
544
+ # ========================================================================
545
+ if isinstance(seed, torch.Tensor):
546
+ seed = seed.tolist()
547
+ if seed is None:
548
+ seeds = [
549
+ random.randint(0, 1_000_000)
550
+ for _ in range(batch_size * num_videos_per_prompt)
551
+ ]
552
+ elif isinstance(seed, int):
553
+ seeds = [
554
+ seed + i
555
+ for _ in range(batch_size)
556
+ for i in range(num_videos_per_prompt)
557
+ ]
558
+ elif isinstance(seed, (list, tuple)):
559
+ if len(seed) == batch_size:
560
+ seeds = [
561
+ int(seed[i]) + j
562
+ for i in range(batch_size)
563
+ for j in range(num_videos_per_prompt)
564
+ ]
565
+ elif len(seed) == batch_size * num_videos_per_prompt:
566
+ seeds = [int(s) for s in seed]
567
+ else:
568
+ raise ValueError(
569
+ f"Length of seed must be equal to number of prompt(batch_size) or "
570
+ f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
571
+ )
572
+ else:
573
+ raise ValueError(
574
+ f"Seed must be an integer, a list of integers, or None, got {seed}."
575
+ )
576
+ generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
577
+ out_dict["seeds"] = seeds
578
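A quick trace of the integer branch above (illustrative values):

seed, batch_size, num_videos_per_prompt = 42, 2, 2
seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
# -> [42, 43, 42, 43]: every prompt in the batch reuses the same seed offsets,
# so videos differ within a prompt but repeat across prompts.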
+
579
+ # ========================================================================
580
+ # Arguments: target_width, target_height, target_video_length
581
+ # ========================================================================
582
+ if width <= 0 or height <= 0 or video_length <= 0:
583
+ raise ValueError(
584
+ f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
585
+ )
586
+ if (video_length - 1) % 4 != 0:
587
+ raise ValueError(
588
+ f"`video_length-1` must be a multiple of 4, got {video_length}"
589
+ )
590
+
591
+ logger.info(
592
+ f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
593
+ )
594
+
595
+ target_height = align_to(height, 16)
596
+ target_width = align_to(width, 16)
597
+ target_video_length = video_length
598
+
599
+ out_dict["size"] = (target_height, target_width, target_video_length)
600
+
601
+ # ========================================================================
602
+ # Arguments: prompt, new_prompt, negative_prompt
603
+ # ========================================================================
604
+ if not isinstance(prompt, str):
605
+ raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
606
+ prompt = [prompt.strip()]
607
+
608
+ # negative prompt
609
+ if negative_prompt is None or negative_prompt == "":
610
+ negative_prompt = self.default_negative_prompt
611
+ if not isinstance(negative_prompt, str):
612
+ raise TypeError(
613
+ f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
614
+ )
615
+ negative_prompt = [negative_prompt.strip()]
616
+
617
+ # ========================================================================
618
+ # Scheduler
619
+ # ========================================================================
620
+ scheduler = FlowMatchDiscreteScheduler(
621
+ shift=flow_shift,
622
+ reverse=self.args.flow_reverse,
623
+ solver=self.args.flow_solver
624
+ )
625
+ self.pipeline.scheduler = scheduler
626
+
627
+ # ========================================================================
628
+ # Build Rope freqs
629
+ # ========================================================================
630
+ freqs_cos, freqs_sin = self.get_rotary_pos_embed(
631
+ target_video_length, target_height, target_width
632
+ )
633
+ n_tokens = freqs_cos.shape[0]
634
+
635
+ # ========================================================================
636
+ # Print infer args
637
+ # ========================================================================
638
+ debug_str = f"""
639
+ height: {target_height}
640
+ width: {target_width}
641
+ video_length: {target_video_length}
642
+ prompt: {prompt}
643
+ neg_prompt: {negative_prompt}
644
+ seed: {seed}
645
+ infer_steps: {infer_steps}
646
+ num_videos_per_prompt: {num_videos_per_prompt}
647
+ guidance_scale: {guidance_scale}
648
+ n_tokens: {n_tokens}
649
+ flow_shift: {flow_shift}
650
+ embedded_guidance_scale: {embedded_guidance_scale}"""
651
+ logger.debug(debug_str)
652
+
653
+ # ========================================================================
654
+ # Pipeline inference
655
+ # ========================================================================
656
+ start_time = time.time()
657
+ samples = self.pipeline(
658
+ prompt=prompt,
659
+ height=target_height,
660
+ width=target_width,
661
+ video_length=target_video_length,
662
+ num_inference_steps=infer_steps,
663
+ guidance_scale=guidance_scale,
664
+ negative_prompt=negative_prompt,
665
+ num_videos_per_prompt=num_videos_per_prompt,
666
+ generator=generator,
667
+ output_type="pil",
668
+ freqs_cis=(freqs_cos, freqs_sin),
669
+ n_tokens=n_tokens,
670
+ embedded_guidance_scale=embedded_guidance_scale,
671
+ data_type="video" if target_video_length > 1 else "image",
672
+ is_progress_bar=True,
673
+ vae_ver=self.args.vae,
674
+ enable_tiling=self.args.vae_tiling,
675
+ )[0]
676
+ out_dict["samples"] = samples
677
+ out_dict["prompts"] = prompt
678
+
679
+ gen_time = time.time() - start_time
680
+ logger.info(f"Success, time: {gen_time}")
681
+
682
+ return out_dict
hyvideo/modules/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2
+
3
+
4
+ def load_model(args, in_channels, out_channels, factor_kwargs):
5
+ """load hunyuan video model
6
+
7
+ Args:
8
+ args (argparse.Namespace): model args
9
+ in_channels (int): input channels number
10
+ out_channels (int): output channels number
11
+ factor_kwargs (dict): factor kwargs
12
+
13
+ Returns:
14
+ model (nn.Module): The hunyuan video model
15
+ """
16
+ if args.model in HUNYUAN_VIDEO_CONFIG.keys():
17
+ model = HYVideoDiffusionTransformer(
18
+ args,
19
+ in_channels=in_channels,
20
+ out_channels=out_channels,
21
+ **HUNYUAN_VIDEO_CONFIG[args.model],
22
+ **factor_kwargs,
23
+ )
24
+ return model
25
+ else:
26
+ raise NotImplementedError(f"Unsupported model: {args.model}")
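A hypothetical call, for orientation (argument values are illustrative, not repo defaults; `args.model` must be a key of HUNYUAN_VIDEO_CONFIG):

import torch

model = load_model(
    args,  # argparse.Namespace with args.model set (assumed available)
    in_channels=16,
    out_channels=16,
    factor_kwargs={"device": "meta", "dtype": torch.bfloat16},
)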
hyvideo/modules/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ Callable[[], nn.Module]: a factory that instantiates the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
hyvideo/modules/attenion.py ADDED
@@ -0,0 +1,257 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except ImportError:
13
+ flash_attn = None
14
+ flash_attn_varlen_func = None
15
+ _flash_attn_forward = None
16
+
17
+ try:
18
+ from sageattention import sageattn_varlen
19
+ def sageattn_varlen_wrapper(
20
+ q,
21
+ k,
22
+ v,
23
+ cu_seqlens_q,
24
+ cu_seqlens_kv,
25
+ max_seqlen_q,
26
+ max_seqlen_kv,
27
+ ):
28
+ return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
29
+ except ImportError:
30
+ sageattn_varlen_wrapper = None
31
+
32
+ MEMORY_LAYOUT = {
33
+ "sdpa": (
34
+ lambda x: x.transpose(1, 2),
35
+ lambda x: x.transpose(1, 2),
36
+ ),
37
+ "sage": (
38
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
39
+ lambda x: x,
40
+ ),
41
+ "flash": (
42
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
43
+ lambda x: x,
44
+ ),
45
+ "torch": (
46
+ lambda x: x.transpose(1, 2),
47
+ lambda x: x.transpose(1, 2),
48
+ ),
49
+ "vanilla": (
50
+ lambda x: x.transpose(1, 2),
51
+ lambda x: x.transpose(1, 2),
52
+ ),
53
+ }
54
+
55
+
56
+ def get_cu_seqlens(text_mask, img_len):
57
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
58
+
59
+ Args:
60
+ text_mask (torch.Tensor): the mask of text
61
+ img_len (int): the length of image
62
+
63
+ Returns:
64
+ torch.Tensor: the calculated cu_seqlens for flash attention
65
+ """
66
+ batch_size = text_mask.shape[0]
67
+ text_len = text_mask.sum(dim=1)
68
+ max_len = text_mask.shape[1] + img_len
69
+
70
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
71
+
72
+ for i in range(batch_size):
73
+ s = text_len[i] + img_len
74
+ s1 = i * max_len + s
75
+ s2 = (i + 1) * max_len
76
+ cu_seqlens[2 * i + 1] = s1
77
+ cu_seqlens[2 * i + 2] = s2
78
+
79
+ return cu_seqlens
80
+
81
+
82
+ def attention(
83
+ q,
84
+ k,
85
+ v,
86
+ mode="flash",
87
+ drop_rate=0,
88
+ attn_mask=None,
89
+ causal=False,
90
+ cu_seqlens_q=None,
91
+ cu_seqlens_kv=None,
92
+ max_seqlen_q=None,
93
+ max_seqlen_kv=None,
94
+ batch_size=1,
95
+ ):
96
+ """
97
+ Perform QKV self attention.
98
+
99
+ Args:
100
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
101
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
102
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
103
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
104
+ drop_rate (float): Dropout rate in attention map. (default: 0)
105
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
106
+ (default: None)
107
+ causal (bool): Whether to use causal attention. (default: False)
108
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
109
+ used to index into q.
110
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
111
+ used to index into kv.
112
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
113
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
114
+
115
+ Returns:
116
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
117
+ """
118
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
119
+ q = pre_attn_layout(q)
120
+ k = pre_attn_layout(k)
121
+ v = pre_attn_layout(v)
122
+
123
+ if mode == "torch":
124
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
125
+ attn_mask = attn_mask.to(q.dtype)
126
+ x = F.scaled_dot_product_attention(
127
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
128
+ )
129
+
130
+ elif mode == "sdpa":
131
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
132
+ attn_mask = attn_mask.to(q.dtype)
133
+ x = F.scaled_dot_product_attention(
134
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
135
+ )
136
+
137
+ elif mode == "sage":
138
+ x = sageattn_varlen_wrapper(
139
+ q,
140
+ k,
141
+ v,
142
+ cu_seqlens_q,
143
+ cu_seqlens_kv,
144
+ max_seqlen_q,
145
+ max_seqlen_kv,
146
+ )
147
+ # x with shape [(bxs), a, d]
148
+ x = x.view(
149
+ batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
150
+ ) # reshape x to [b, s, a, d]
151
+
152
+ elif mode == "flash":
153
+ x = flash_attn_varlen_func(
154
+ q,
155
+ k,
156
+ v,
157
+ cu_seqlens_q,
158
+ cu_seqlens_kv,
159
+ max_seqlen_q,
160
+ max_seqlen_kv,
161
+ )
162
+ # x with shape [(bxs), a, d]
163
+ x = x.view(
164
+ batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
165
+ ) # reshape x to [b, s, a, d]
166
+ elif mode == "vanilla":
167
+ scale_factor = 1 / math.sqrt(q.size(-1))
168
+
169
+ b, a, s, _ = q.shape
170
+ s1 = k.size(2)
171
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
172
+ if causal:
173
+ # Only applied to self attention
174
+ assert (
175
+ attn_mask is None
176
+ ), "Causal mask and attn_mask cannot be used together"
177
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
178
+ diagonal=0
179
+ )
180
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
181
+ attn_bias = attn_bias.to(q.dtype)
182
+
183
+ if attn_mask is not None:
184
+ if attn_mask.dtype == torch.bool:
185
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
186
+ else:
187
+ attn_bias += attn_mask
188
+
189
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
190
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
191
+ attn += attn_bias
192
+ attn = attn.softmax(dim=-1)
193
+ attn = torch.dropout(attn, p=drop_rate, train=True)
194
+ x = attn @ v
195
+ else:
196
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
197
+
198
+ x = post_attn_layout(x)
199
+ b, s, a, d = x.shape
200
+ out = x.reshape(b, s, -1)
201
+ return out
202
+
203
+
204
+ def parallel_attention(
205
+ hybrid_seq_parallel_attn,
206
+ q,
207
+ k,
208
+ v,
209
+ img_q_len,
210
+ img_kv_len,
211
+ cu_seqlens_q,
212
+ cu_seqlens_kv
213
+ ):
214
+ attn1 = hybrid_seq_parallel_attn(
215
+ None,
216
+ q[:, :img_q_len, :, :],
217
+ k[:, :img_kv_len, :, :],
218
+ v[:, :img_kv_len, :, :],
219
+ dropout_p=0.0,
220
+ causal=False,
221
+ joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
222
+ joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
223
+ joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
224
+ joint_strategy="rear",
225
+ )
226
+ if tuple(int(v) for v in flash_attn.__version__.split(".")[:2]) >= (2, 7): # numeric compare; a plain string compare breaks on e.g. "2.10"
227
+ attn2, *_ = _flash_attn_forward(
228
+ q[:,cu_seqlens_q[1]:],
229
+ k[:,cu_seqlens_kv[1]:],
230
+ v[:,cu_seqlens_kv[1]:],
231
+ dropout_p=0.0,
232
+ softmax_scale=q.shape[-1] ** (-0.5),
233
+ causal=False,
234
+ window_size_left=-1,
235
+ window_size_right=-1,
236
+ softcap=0.0,
237
+ alibi_slopes=None,
238
+ return_softmax=False,
239
+ )
240
+ else:
241
+ attn2, *_ = _flash_attn_forward(
242
+ q[:,cu_seqlens_q[1]:],
243
+ k[:,cu_seqlens_kv[1]:],
244
+ v[:,cu_seqlens_kv[1]:],
245
+ dropout_p=0.0,
246
+ softmax_scale=q.shape[-1] ** (-0.5),
247
+ causal=False,
248
+ window_size=(-1, -1),
249
+ softcap=0.0,
250
+ alibi_slopes=None,
251
+ return_softmax=False,
252
+ )
253
+ attn = torch.cat([attn1, attn2], dim=1)
254
+ b, s, a, d = attn.shape
255
+ attn = attn.reshape(b, s, -1)
256
+
257
+ return attn
hyvideo/modules/embed_layers.py ADDED
@@ -0,0 +1,157 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange, repeat
5
+
6
+ from ..utils.helpers import to_2tuple
7
+
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image to Patch Embedding using Conv2d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(
41
+ in_chans,
42
+ embed_dim,
43
+ kernel_size=patch_size,
44
+ stride=patch_size,
45
+ bias=bias,
46
+ **factory_kwargs
47
+ )
48
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
49
+ if bias:
50
+ nn.init.zeros_(self.proj.bias)
51
+
52
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
53
+
54
+ def forward(self, x):
55
+ x = self.proj(x)
56
+ if self.flatten:
57
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
58
+ x = self.norm(x)
59
+ return x
60
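Shape sketch for the video case (sizes illustrative; patch_size = (1, 2, 2) splits each latent frame into 2x2 spatial patches, and to_2tuple is assumed to pass tuples through, as in timm):

import torch

pe = PatchEmbed(patch_size=(1, 2, 2), in_chans=16, embed_dim=3072)
x = torch.randn(1, 16, 33, 90, 160)   # [B, C, T, H, W] latents
tokens = pe(x)                        # [1, 33 * 45 * 80, 3072] = [1, 118800, 3072]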
+
61
+
62
+ class TextProjection(nn.Module):
63
+ """
64
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
65
+
66
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
67
+ """
68
+
69
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
70
+ factory_kwargs = {"dtype": dtype, "device": device}
71
+ super().__init__()
72
+ self.linear_1 = nn.Linear(
73
+ in_features=in_channels,
74
+ out_features=hidden_size,
75
+ bias=True,
76
+ **factory_kwargs
77
+ )
78
+ self.act_1 = act_layer()
79
+ self.linear_2 = nn.Linear(
80
+ in_features=hidden_size,
81
+ out_features=hidden_size,
82
+ bias=True,
83
+ **factory_kwargs
84
+ )
85
+
86
+ def forward(self, caption):
87
+ hidden_states = self.linear_1(caption)
88
+ hidden_states = self.act_1(hidden_states)
89
+ hidden_states = self.linear_2(hidden_states)
90
+ return hidden_states
91
+
92
+
93
+ def timestep_embedding(t, dim, max_period=10000):
94
+ """
95
+ Create sinusoidal timestep embeddings.
96
+
97
+ Args:
98
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
99
+ dim (int): the dimension of the output.
100
+ max_period (int): controls the minimum frequency of the embeddings.
101
+
102
+ Returns:
103
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
104
+
105
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
106
+ """
107
+ half = dim // 2
108
+ freqs = torch.exp(
109
+ -math.log(max_period)
110
+ * torch.arange(start=0, end=half, dtype=torch.float32)
111
+ / half
112
+ ).to(device=t.device)
113
+ args = t[:, None].float() * freqs[None]
114
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
115
+ if dim % 2:
116
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
117
+ return embedding
118
+
119
+
120
+ class TimestepEmbedder(nn.Module):
121
+ """
122
+ Embeds scalar timesteps into vector representations.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ hidden_size,
128
+ act_layer,
129
+ frequency_embedding_size=256,
130
+ max_period=10000,
131
+ out_size=None,
132
+ dtype=None,
133
+ device=None,
134
+ ):
135
+ factory_kwargs = {"dtype": dtype, "device": device}
136
+ super().__init__()
137
+ self.frequency_embedding_size = frequency_embedding_size
138
+ self.max_period = max_period
139
+ if out_size is None:
140
+ out_size = hidden_size
141
+
142
+ self.mlp = nn.Sequential(
143
+ nn.Linear(
144
+ frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
145
+ ),
146
+ act_layer(),
147
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
148
+ )
149
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
150
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
151
+
152
+ def forward(self, t):
153
+ t_freq = timestep_embedding(
154
+ t, self.frequency_embedding_size, self.max_period
155
+ ).type(self.mlp[0].weight.dtype)
156
+ t_emb = self.mlp(t_freq)
157
+ return t_emb
hyvideo/modules/mlp_layers.py ADDED
@@ -0,0 +1,118 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from ..utils.helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(
38
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
+ )
40
+ self.act = act_layer()
41
+ self.drop1 = nn.Dropout(drop_probs[0])
42
+ self.norm = (
43
+ norm_layer(hidden_channels, **factory_kwargs)
44
+ if norm_layer is not None
45
+ else nn.Identity()
46
+ )
47
+ self.fc2 = linear_layer(
48
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
+ )
50
+ self.drop2 = nn.Dropout(drop_probs[1])
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop1(x)
56
+ x = self.norm(x)
57
+ x = self.fc2(x)
58
+ x = self.drop2(x)
59
+ return x
60
+
61
+
62
+
63
+ class MLPEmbedder(nn.Module):
64
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66
+ factory_kwargs = {"device": device, "dtype": dtype}
67
+ super().__init__()
68
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69
+ self.silu = nn.SiLU()
70
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.out_layer(self.silu(self.in_layer(x)))
74
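Usage sketch (the 768-dim input is illustrative, e.g. a CLIP pooled embedding):

import torch

emb = MLPEmbedder(in_dim=768, hidden_dim=3072)
pooled = torch.randn(2, 768)
vec = emb(pooled)   # -> [2, 3072], consumed as a modulation vector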
+
75
+
76
+ class FinalLayer(nn.Module):
77
+ """The final layer of DiT."""
78
+
79
+ def __init__(
80
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81
+ ):
82
+ factory_kwargs = {"device": device, "dtype": dtype}
83
+ super().__init__()
84
+
85
+ # Just use LayerNorm for the final layer
86
+ self.norm_final = nn.LayerNorm(
87
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88
+ )
89
+ if isinstance(patch_size, int):
90
+ self.linear = nn.Linear(
91
+ hidden_size,
92
+ patch_size * patch_size * out_channels,
93
+ bias=True,
94
+ **factory_kwargs
95
+ )
96
+ else:
97
+ self.linear = nn.Linear(
98
+ hidden_size,
99
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100
+ bias=True,
101
+ **factory_kwargs,
+ )
102
+ nn.init.zeros_(self.linear.weight)
103
+ nn.init.zeros_(self.linear.bias)
104
+
105
+ # Here we don't distinguish between the modulate types. Just use the simple one.
106
+ self.adaLN_modulation = nn.Sequential(
107
+ act_layer(),
108
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109
+ )
110
+ # Zero-initialize the modulation
111
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
112
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
113
+
114
+ def forward(self, x, c):
115
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
117
+ x = self.linear(x)
118
+ return x
hyvideo/modules/models.py ADDED
@@ -0,0 +1,870 @@
1
+ from typing import Any, List, Tuple, Optional, Union, Dict
2
+ from einops import rearrange
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.models import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+
11
+ from .activation_layers import get_activation_layer
12
+ from .norm_layers import get_norm_layer
13
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
14
+ from .attenion import attention, parallel_attention, get_cu_seqlens
15
+ from .posemb_layers import apply_rotary_emb
16
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
17
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
18
+ from .token_refiner import SingleTokenRefiner
19
+ import numpy as np
20
+
21
+ class MMDoubleStreamBlock(nn.Module):
22
+ """
23
+ A multimodal DiT block with separate modulation for
24
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
25
+ (Flux.1): https://github.com/black-forest-labs/flux
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ hidden_size: int,
31
+ heads_num: int,
32
+ mlp_width_ratio: float,
33
+ mlp_act_type: str = "gelu_tanh",
34
+ qk_norm: bool = True,
35
+ qk_norm_type: str = "rms",
36
+ qkv_bias: bool = False,
37
+ dtype: Optional[torch.dtype] = None,
38
+ device: Optional[torch.device] = None,
39
+ attention_mode: str = "sdpa",
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+
44
+ self.attention_mode = attention_mode
45
+ self.deterministic = False
46
+ self.heads_num = heads_num
47
+ head_dim = hidden_size // heads_num
48
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
49
+
50
+ self.img_mod = ModulateDiT(
51
+ hidden_size,
52
+ factor=6,
53
+ act_layer=get_activation_layer("silu"),
54
+ **factory_kwargs,
55
+ )
56
+ self.img_norm1 = nn.LayerNorm(
57
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
58
+ )
59
+
60
+ self.img_attn_qkv = nn.Linear(
61
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
62
+ )
63
+ qk_norm_layer = get_norm_layer(qk_norm_type)
64
+ self.img_attn_q_norm = (
65
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
66
+ if qk_norm
67
+ else nn.Identity()
68
+ )
69
+ self.img_attn_k_norm = (
70
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
71
+ if qk_norm
72
+ else nn.Identity()
73
+ )
74
+ self.img_attn_proj = nn.Linear(
75
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
76
+ )
77
+
78
+ self.img_norm2 = nn.LayerNorm(
79
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
80
+ )
81
+ self.img_mlp = MLP(
82
+ hidden_size,
83
+ mlp_hidden_dim,
84
+ act_layer=get_activation_layer(mlp_act_type),
85
+ bias=True,
86
+ **factory_kwargs,
87
+ )
88
+
89
+ self.txt_mod = ModulateDiT(
90
+ hidden_size,
91
+ factor=6,
92
+ act_layer=get_activation_layer("silu"),
93
+ **factory_kwargs,
94
+ )
95
+ self.txt_norm1 = nn.LayerNorm(
96
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
97
+ )
98
+
99
+ self.txt_attn_qkv = nn.Linear(
100
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
101
+ )
102
+ self.txt_attn_q_norm = (
103
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
104
+ if qk_norm
105
+ else nn.Identity()
106
+ )
107
+ self.txt_attn_k_norm = (
108
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
109
+ if qk_norm
110
+ else nn.Identity()
111
+ )
112
+ self.txt_attn_proj = nn.Linear(
113
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
114
+ )
115
+
116
+ self.txt_norm2 = nn.LayerNorm(
117
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
118
+ )
119
+ self.txt_mlp = MLP(
120
+ hidden_size,
121
+ mlp_hidden_dim,
122
+ act_layer=get_activation_layer(mlp_act_type),
123
+ bias=True,
124
+ **factory_kwargs,
125
+ )
126
+ self.hybrid_seq_parallel_attn = None
127
+
128
+ def enable_deterministic(self):
129
+ self.deterministic = True
130
+
131
+ def disable_deterministic(self):
132
+ self.deterministic = False
133
+
134
+ def forward(
135
+ self,
136
+ img: torch.Tensor,
137
+ txt: torch.Tensor,
138
+ vec: torch.Tensor,
139
+ attn_mask = None,
140
+ cu_seqlens_q: Optional[torch.Tensor] = None,
141
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
142
+ max_seqlen_q: Optional[int] = None,
143
+ max_seqlen_kv: Optional[int] = None,
144
+ freqs_cis: tuple = None,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ (
147
+ img_mod1_shift,
148
+ img_mod1_scale,
149
+ img_mod1_gate,
150
+ img_mod2_shift,
151
+ img_mod2_scale,
152
+ img_mod2_gate,
153
+ ) = self.img_mod(vec).chunk(6, dim=-1)
154
+ (
155
+ txt_mod1_shift,
156
+ txt_mod1_scale,
157
+ txt_mod1_gate,
158
+ txt_mod2_shift,
159
+ txt_mod2_scale,
160
+ txt_mod2_gate,
161
+ ) = self.txt_mod(vec).chunk(6, dim=-1)
162
+
163
+ # Prepare image for attention.
164
+ img_modulated = self.img_norm1(img)
165
+ img_modulated = modulate(
166
+ img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
167
+ )
168
+ img_qkv = self.img_attn_qkv(img_modulated)
169
+ img_q, img_k, img_v = rearrange(
170
+ img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
171
+ )
172
+ # Apply QK-Norm if needed
173
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
174
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
175
+
176
+ # Apply RoPE if needed.
177
+ if freqs_cis is not None:
178
+ img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
179
+ assert (
180
+ img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
181
+ ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
182
+ img_q, img_k = img_qq, img_kk
183
+
184
+ # Prepare txt for attention.
185
+ txt_modulated = self.txt_norm1(txt)
186
+ txt_modulated = modulate(
187
+ txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
188
+ )
189
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
190
+ txt_q, txt_k, txt_v = rearrange(
191
+ txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
192
+ )
193
+ # Apply QK-Norm if needed.
194
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
195
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
196
+
197
+ # Run actual attention.
198
+ q = torch.cat((img_q, txt_q), dim=1)
199
+ k = torch.cat((img_k, txt_k), dim=1)
200
+ v = torch.cat((img_v, txt_v), dim=1)
201
+ # assert (
202
+ # cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
203
+ # ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
204
+
205
+ # attention computation start
206
+ if not self.hybrid_seq_parallel_attn:
207
+ attn = attention(
208
+ q,
209
+ k,
210
+ v,
211
+ mode=self.attention_mode,
212
+ attn_mask=attn_mask,
213
+ cu_seqlens_q=cu_seqlens_q,
214
+ cu_seqlens_kv=cu_seqlens_kv,
215
+ max_seqlen_q=max_seqlen_q,
216
+ max_seqlen_kv=max_seqlen_kv,
217
+ batch_size=img_k.shape[0],
218
+ )
219
+ else:
220
+ attn = parallel_attention(
221
+ self.hybrid_seq_parallel_attn,
222
+ q,
223
+ k,
224
+ v,
225
+ img_q_len=img_q.shape[1],
226
+ img_kv_len=img_k.shape[1],
227
+ cu_seqlens_q=cu_seqlens_q,
228
+ cu_seqlens_kv=cu_seqlens_kv
229
+ )
230
+
231
+ # attention computation end
232
+
233
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
234
+
235
+ # Calculate the img bloks.
236
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
237
+ img = img + apply_gate(
238
+ self.img_mlp(
239
+ modulate(
240
+ self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
241
+ )
242
+ ),
243
+ gate=img_mod2_gate,
244
+ )
245
+
246
+ # Calculate the txt bloks.
247
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
248
+ txt = txt + apply_gate(
249
+ self.txt_mlp(
250
+ modulate(
251
+ self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
252
+ )
253
+ ),
254
+ gate=txt_mod2_gate,
255
+ )
256
+
257
+ return img, txt
258
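The residual updates above follow the DiT/Flux adaLN pattern. A minimal sketch, assuming the usual definitions modulate(x, shift, scale) = x * (1 + scale) + shift and apply_gate(x, gate) = gate * x (the actual helpers live in modulate_layers.py):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, D = 2, 8, 16
x = torch.randn(B, L, D)
shift, scale, gate = (torch.randn(B, D) for _ in range(3))
norm = nn.LayerNorm(D, elementwise_affine=False)

h = norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # modulate
y = x + gate.unsqueeze(1) * F.gelu(h)                        # gated residual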
+
259
+
260
+ class MMSingleStreamBlock(nn.Module):
261
+ """
262
+ A DiT block with parallel linear layers as described in
263
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
264
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
265
+ (Flux.1): https://github.com/black-forest-labs/flux
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ hidden_size: int,
271
+ heads_num: int,
272
+ mlp_width_ratio: float = 4.0,
273
+ mlp_act_type: str = "gelu_tanh",
274
+ qk_norm: bool = True,
275
+ qk_norm_type: str = "rms",
276
+ qk_scale: float = None,
277
+ dtype: Optional[torch.dtype] = None,
278
+ device: Optional[torch.device] = None,
279
+ attention_mode: str = "sdpa",
280
+ ):
281
+ factory_kwargs = {"device": device, "dtype": dtype}
282
+ super().__init__()
283
+ self.attention_mode = attention_mode
284
+ self.deterministic = False
285
+ self.hidden_size = hidden_size
286
+ self.heads_num = heads_num
287
+ head_dim = hidden_size // heads_num
288
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
289
+ self.mlp_hidden_dim = mlp_hidden_dim
290
+ self.scale = qk_scale or head_dim ** -0.5
291
+
292
+ # qkv and mlp_in
293
+ self.linear1 = nn.Linear(
294
+ hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
295
+ )
296
+ # proj and mlp_out
297
+ self.linear2 = nn.Linear(
298
+ hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
299
+ )
300
+
301
+ qk_norm_layer = get_norm_layer(qk_norm_type)
302
+ self.q_norm = (
303
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
304
+ if qk_norm
305
+ else nn.Identity()
306
+ )
307
+ self.k_norm = (
308
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
309
+ if qk_norm
310
+ else nn.Identity()
311
+ )
312
+
313
+ self.pre_norm = nn.LayerNorm(
314
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
315
+ )
316
+
317
+ self.mlp_act = get_activation_layer(mlp_act_type)()
318
+ self.modulation = ModulateDiT(
319
+ hidden_size,
320
+ factor=3,
321
+ act_layer=get_activation_layer("silu"),
322
+ **factory_kwargs,
323
+ )
324
+ self.hybrid_seq_parallel_attn = None
325
+
326
+ def enable_deterministic(self):
327
+ self.deterministic = True
328
+
329
+ def disable_deterministic(self):
330
+ self.deterministic = False
331
+
332
+ def forward(
333
+ self,
334
+ x: torch.Tensor,
335
+ vec: torch.Tensor,
336
+ txt_len: int,
337
+ attn_mask= None,
338
+ cu_seqlens_q: Optional[torch.Tensor] = None,
339
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
340
+ max_seqlen_q: Optional[int] = None,
341
+ max_seqlen_kv: Optional[int] = None,
342
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
343
+
344
+ ) -> torch.Tensor:
345
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
346
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
347
+ qkv, mlp = torch.split(
348
+ self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
349
+ )
350
+
351
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
352
+
353
+ # Apply QK-Norm if needed.
354
+ q = self.q_norm(q).to(v)
355
+ k = self.k_norm(k).to(v)
356
+
357
+ # Apply RoPE if needed.
358
+ if freqs_cis is not None:
359
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
360
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
361
+ img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
362
+ assert (
363
+ img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
364
+ ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
365
+ img_q, img_k = img_qq, img_kk
366
+ q = torch.cat((img_q, txt_q), dim=1)
367
+ k = torch.cat((img_k, txt_k), dim=1)
368
+
369
+ # Compute attention.
370
+ # assert (
371
+ # cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
372
+ # ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
373
+
374
+ # attention computation start
375
+ if not self.hybrid_seq_parallel_attn:
376
+ attn = attention(
377
+ q,
378
+ k,
379
+ v,
380
+ mode=self.attention_mode,
381
+ attn_mask=attn_mask,
382
+ cu_seqlens_q=cu_seqlens_q,
383
+ cu_seqlens_kv=cu_seqlens_kv,
384
+ max_seqlen_q=max_seqlen_q,
385
+ max_seqlen_kv=max_seqlen_kv,
386
+ batch_size=x.shape[0],
387
+ )
388
+ else:
389
+ attn = parallel_attention(
390
+ self.hybrid_seq_parallel_attn,
391
+ q,
392
+ k,
393
+ v,
394
+ img_q_len=img_q.shape[1],
395
+ img_kv_len=img_k.shape[1],
396
+ cu_seqlens_q=cu_seqlens_q,
397
+ cu_seqlens_kv=cu_seqlens_kv
398
+ )
399
+ # attention computation end
400
+
401
+ # Compute activation in mlp stream, cat again and run second linear layer.
402
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
403
+ return x + apply_gate(output, gate=mod_gate)
404
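The single-stream block fuses attention and MLP into two matmuls; with the defaults in this file the widths work out as follows:

# hidden_size = 3072, mlp_width_ratio = 4  ->  mlp_hidden_dim = 12288
# linear1: 3072 -> 3 * 3072 + 12288 = 21504   (qkv and the mlp input, fused)
# linear2: 3072 + 12288 = 15360 -> 3072       (attn output and mlp output, fused)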
+
405
+
406
+ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
407
+ """
408
+ HunyuanVideo Transformer backbone
409
+
410
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
411
+
412
+ Reference:
413
+ [1] Flux.1: https://github.com/black-forest-labs/flux
414
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
415
+
416
+ Parameters
417
+ ----------
418
+ args: argparse.Namespace
419
+ The arguments parsed by argparse.
420
+ patch_size: list
421
+ The size of the patch.
422
+ in_channels: int
423
+ The number of input channels.
424
+ out_channels: int
425
+ The number of output channels.
426
+ hidden_size: int
427
+ The hidden size of the transformer backbone.
428
+ heads_num: int
429
+ The number of attention heads.
430
+ mlp_width_ratio: float
431
+ The ratio of the hidden size of the MLP in the transformer block.
432
+ mlp_act_type: str
433
+ The activation function of the MLP in the transformer block.
434
+ depth_double_blocks: int
435
+ The number of transformer blocks in the double blocks.
436
+ depth_single_blocks: int
437
+ The number of transformer blocks in the single blocks.
438
+ rope_dim_list: list
439
+ The dimension of the rotary embedding for t, h, w.
440
+ qkv_bias: bool
441
+ Whether to use bias in the qkv linear layer.
442
+ qk_norm: bool
443
+ Whether to use qk norm.
444
+ qk_norm_type: str
445
+ The type of qk norm.
446
+ guidance_embed: bool
447
+ Whether to use guidance embedding for distillation.
448
+ text_projection: str
449
+ The type of the text projection, default is single_refiner.
450
+ use_attention_mask: bool
451
+ Whether to use attention mask for text encoder.
452
+ dtype: torch.dtype
453
+ The dtype of the model.
454
+ device: torch.device
455
+ The device of the model.
456
+ """
457
+
458
+ @register_to_config
459
+ def __init__(
460
+ self,
461
+ args: Any,
462
+ patch_size: list = [1, 2, 2],
463
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
464
+ out_channels: int = None,
465
+ hidden_size: int = 3072,
466
+ heads_num: int = 24,
467
+ mlp_width_ratio: float = 4.0,
468
+ mlp_act_type: str = "gelu_tanh",
469
+ mm_double_blocks_depth: int = 20,
470
+ mm_single_blocks_depth: int = 40,
471
+ rope_dim_list: List[int] = [16, 56, 56],
472
+ qkv_bias: bool = True,
473
+ qk_norm: bool = True,
474
+ qk_norm_type: str = "rms",
475
+ guidance_embed: bool = False, # For modulation.
476
+ text_projection: str = "single_refiner",
477
+ use_attention_mask: bool = True,
478
+ dtype: Optional[torch.dtype] = None,
479
+ device: Optional[torch.device] = None,
480
+ attention_mode: Optional[str] = "sdpa"
481
+ ):
482
+ factory_kwargs = {"device": device, "dtype": dtype}
483
+ super().__init__()
484
+
485
+ self.patch_size = patch_size
486
+ self.in_channels = in_channels
487
+ self.out_channels = in_channels if out_channels is None else out_channels
488
+ self.unpatchify_channels = self.out_channels
489
+ self.guidance_embed = guidance_embed
490
+ self.rope_dim_list = rope_dim_list
491
+ self.attention_mode = attention_mode
492
+
493
+ # Text projection. Default to linear projection.
494
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
495
+ self.use_attention_mask = use_attention_mask
496
+ self.text_projection = text_projection
497
+
498
+ self.text_states_dim = args.text_states_dim
499
+ self.text_states_dim_2 = args.text_states_dim_2
500
+
501
+ if hidden_size % heads_num != 0:
502
+ raise ValueError(
503
+ f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
504
+ )
505
+ pe_dim = hidden_size // heads_num
506
+ if sum(rope_dim_list) != pe_dim:
507
+ raise ValueError(
508
+ f"Got {rope_dim_list} but expected positional dim {pe_dim}"
509
+ )
510
+ self.hidden_size = hidden_size
511
+ self.heads_num = heads_num
512
+
513
+ # image projection
514
+ self.img_in = PatchEmbed(
515
+ self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
516
+ )
517
+
518
+ # text projection
519
+ if self.text_projection == "linear":
520
+ self.txt_in = TextProjection(
521
+ self.text_states_dim,
522
+ self.hidden_size,
523
+ get_activation_layer("silu"),
524
+ **factory_kwargs,
525
+ )
526
+ elif self.text_projection == "single_refiner":
527
+ self.txt_in = SingleTokenRefiner(
528
+ self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
529
+ )
530
+ else:
531
+ raise NotImplementedError(
532
+ f"Unsupported text_projection: {self.text_projection}"
533
+ )
534
+
535
+ # time modulation
536
+ self.time_in = TimestepEmbedder(
537
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
538
+ )
539
+
540
+ # text modulation
541
+ self.vector_in = MLPEmbedder(
542
+ self.text_states_dim_2, self.hidden_size, **factory_kwargs
543
+ )
544
+
545
+ # guidance modulation
546
+ self.guidance_in = (
547
+ TimestepEmbedder(
548
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
549
+ )
550
+ if guidance_embed
551
+ else None
552
+ )
553
+
554
+ # double blocks
555
+ self.double_blocks = nn.ModuleList(
556
+ [
557
+ MMDoubleStreamBlock(
558
+ self.hidden_size,
559
+ self.heads_num,
560
+ mlp_width_ratio=mlp_width_ratio,
561
+ mlp_act_type=mlp_act_type,
562
+ qk_norm=qk_norm,
563
+ qk_norm_type=qk_norm_type,
564
+ qkv_bias=qkv_bias,
565
+ attention_mode = attention_mode,
566
+ **factory_kwargs,
567
+ )
568
+ for _ in range(mm_double_blocks_depth)
569
+ ]
570
+ )
571
+
572
+ # single blocks
573
+ self.single_blocks = nn.ModuleList(
574
+ [
575
+ MMSingleStreamBlock(
576
+ self.hidden_size,
577
+ self.heads_num,
578
+ mlp_width_ratio=mlp_width_ratio,
579
+ mlp_act_type=mlp_act_type,
580
+ qk_norm=qk_norm,
581
+ qk_norm_type=qk_norm_type,
582
+ attention_mode=attention_mode,
583
+ **factory_kwargs,
584
+ )
585
+ for _ in range(mm_single_blocks_depth)
586
+ ]
587
+ )
588
+
589
+ self.final_layer = FinalLayer(
590
+ self.hidden_size,
591
+ self.patch_size,
592
+ self.out_channels,
593
+ get_activation_layer("silu"),
594
+ **factory_kwargs,
595
+ )
596
+
597
+ def enable_deterministic(self):
598
+ for block in self.double_blocks:
599
+ block.enable_deterministic()
600
+ for block in self.single_blocks:
601
+ block.enable_deterministic()
602
+
603
+ def disable_deterministic(self):
604
+ for block in self.double_blocks:
605
+ block.disable_deterministic()
606
+ for block in self.single_blocks:
607
+ block.disable_deterministic()
608
+
609
+ def forward(
610
+ self,
611
+ x: torch.Tensor,
612
+ t: torch.Tensor, # Should be in range(0, 1000).
613
+ text_states: torch.Tensor = None,
614
+ text_mask: torch.Tensor = None, # Used by the token refiner and to build the sdpa attention mask.
615
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
616
+ freqs_cos: Optional[torch.Tensor] = None,
617
+ freqs_sin: Optional[torch.Tensor] = None,
618
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
619
+ return_dict: bool = True,
620
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
621
+ out = {}
622
+ img = x
623
+ txt = text_states
624
+ _, _, ot, oh, ow = x.shape
625
+ tt, th, tw = (
626
+ ot // self.patch_size[0],
627
+ oh // self.patch_size[1],
628
+ ow // self.patch_size[2],
629
+ )
630
+
631
+ # Prepare modulation vectors.
632
+ vec = self.time_in(t)
633
+
634
+ # text modulation
635
+ vec = vec + self.vector_in(text_states_2)
636
+
637
+ # guidance modulation
638
+ if self.guidance_embed:
639
+ if guidance is None:
640
+ raise ValueError(
641
+ "Didn't get guidance strength for guidance distilled model."
642
+ )
643
+
644
+ # The sinusoidal embedding of the guidance value is computed inside guidance_in (a TimestepEmbedder).
645
+ vec = vec + self.guidance_in(guidance)
646
+
647
+ # Embed image and text.
648
+ img = self.img_in(img)
649
+ if self.text_projection == "linear":
650
+ txt = self.txt_in(txt)
651
+ elif self.text_projection == "single_refiner":
652
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
653
+ else:
654
+ raise NotImplementedError(
655
+ f"Unsupported text_projection: {self.text_projection}"
656
+ )
657
+
658
+ txt_seq_len = txt.shape[1]
659
+ img_seq_len = img.shape[1]
660
+
661
+ # Compute cu_seqlens and max_seqlen for flash attention
662
+ # cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
663
+ # cu_seqlens_kv = cu_seqlens_q
664
+ max_seqlen_q = img_seq_len + txt_seq_len
665
+ max_seqlen_kv = max_seqlen_q
666
+
667
+ # Thanks to kijai (https://github.com/kijai/ComfyUI-HunyuanVideoWrapper/) for the code supporting sdpa.
668
+ if self.attention_mode == "sdpa":
669
+ cu_seqlens_q, cu_seqlens_kv = None, None
670
+ # Create a square boolean mask filled with False
671
+ attn_mask = torch.zeros((1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
672
+
673
+ # Calculate the valid attention regions
674
+ text_len = text_mask[0].sum().item()
675
+ total_len = text_len + img_seq_len
676
+
677
+ # Allow attention to all tokens up to total_len
678
+ attn_mask[0, :total_len, :total_len] = True
679
+ else:
680
+ attn_mask = None
681
+ # Compute cu_seqlens for flash and sage attention
682
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
683
+ cu_seqlens_kv = cu_seqlens_q
684
+
685
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
686
+
687
+ if self.enable_teacache:
688
+ inp = img.clone()
689
+ vec_ = vec.clone()
690
+ txt_ = txt.clone()
691
+ (
692
+ img_mod1_shift,
693
+ img_mod1_scale,
694
+ img_mod1_gate,
695
+ img_mod2_shift,
696
+ img_mod2_scale,
697
+ img_mod2_gate,
698
+ ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
699
+ normed_inp = self.double_blocks[0].img_norm1(inp)
700
+ modulated_inp = modulate(
701
+ normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
702
+ )
703
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
704
+ should_calc = True
705
+ self.accumulated_rel_l1_distance = 0
706
+ else:
707
+ coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
708
+ rescale_func = np.poly1d(coefficients)
709
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
710
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
711
+ should_calc = False
712
+ else:
713
+ should_calc = True
714
+ self.accumulated_rel_l1_distance = 0
715
+ self.previous_modulated_input = modulated_inp
716
+ self.cnt += 1
717
+ if self.cnt == self.num_steps:
718
+ self.cnt = 0
719
+
720
+ if self.enable_teacache:
721
+ if not should_calc:
722
+ img += self.previous_residual
723
+ else:
724
+ ori_img = img.clone()
725
+ # --------------------- Pass through DiT blocks ------------------------
726
+ for _, block in enumerate(self.double_blocks):
727
+ double_block_args = [
728
+ img,
729
+ txt,
730
+ vec,
731
+ attn_mask,
732
+ cu_seqlens_q,
733
+ cu_seqlens_kv,
734
+ max_seqlen_q,
735
+ max_seqlen_kv,
736
+ freqs_cis,
737
+ ]
738
+
739
+ img, txt = block(*double_block_args)
740
+
741
+ # Merge txt and img to pass through single stream blocks.
742
+ x = torch.cat((img, txt), 1)
743
+ if len(self.single_blocks) > 0:
744
+ for _, block in enumerate(self.single_blocks):
745
+ single_block_args = [
746
+ x,
747
+ vec,
748
+ txt_seq_len,
749
+ attn_mask,
750
+ cu_seqlens_q,
751
+ cu_seqlens_kv,
752
+ max_seqlen_q,
753
+ max_seqlen_kv,
754
+ (freqs_cos, freqs_sin),
755
+ ]
756
+
757
+ x = block(*single_block_args)
758
+
759
+ img = x[:, :img_seq_len, ...]
760
+ self.previous_residual = img - ori_img
761
+ else:
762
+ # --------------------- Pass through DiT blocks ------------------------
763
+ for _, block in enumerate(self.double_blocks):
764
+ double_block_args = [
765
+ img,
766
+ txt,
767
+ vec,
768
+ attn_mask,
769
+ cu_seqlens_q,
770
+ cu_seqlens_kv,
771
+ max_seqlen_q,
772
+ max_seqlen_kv,
773
+ freqs_cis,
774
+ ]
775
+
776
+ img, txt = block(*double_block_args)
777
+
778
+ # Merge txt and img to pass through single stream blocks.
779
+ x = torch.cat((img, txt), 1)
780
+ if len(self.single_blocks) > 0:
781
+ for _, block in enumerate(self.single_blocks):
782
+ single_block_args = [
783
+ x,
784
+ vec,
785
+ txt_seq_len,
786
+ attn_mask,
787
+ cu_seqlens_q,
788
+ cu_seqlens_kv,
789
+ max_seqlen_q,
790
+ max_seqlen_kv,
791
+ (freqs_cos, freqs_sin),
792
+ ]
793
+
794
+ x = block(*single_block_args)
795
+
796
+ img = x[:, :img_seq_len, ...]
797
+
798
+ # ---------------------------- Final layer ------------------------------
799
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
800
+
801
+ img = self.unpatchify(img, tt, th, tw)
802
+ if return_dict:
803
+ out["x"] = img
804
+ return out
805
+ return img
806
+
807
+ def unpatchify(self, x, t, h, w):
808
+ """
809
+ x: (N, T, patch_size**2 * C)
810
+ imgs: (N, H, W, C)
811
+ """
812
+ c = self.unpatchify_channels
813
+ pt, ph, pw = self.patch_size
814
+ assert t * h * w == x.shape[1]
815
+
816
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
817
+ x = torch.einsum("nthwcopq->nctohpwq", x)
818
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
819
+
820
+ return imgs
821
+
822
+ def params_count(self):
823
+ counts = {
824
+ "double": sum(
825
+ [
826
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
827
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
828
+ + sum(p.numel() for p in block.img_mlp.parameters())
829
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
830
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
831
+ + sum(p.numel() for p in block.txt_mlp.parameters())
832
+ for block in self.double_blocks
833
+ ]
834
+ ),
835
+ "single": sum(
836
+ [
837
+ sum(p.numel() for p in block.linear1.parameters())
838
+ + sum(p.numel() for p in block.linear2.parameters())
839
+ for block in self.single_blocks
840
+ ]
841
+ ),
842
+ "total": sum(p.numel() for p in self.parameters()),
843
+ }
844
+ counts["attn+mlp"] = counts["double"] + counts["single"]
845
+ return counts
846
+
847
+
848
+ #################################################################################
849
+ # HunyuanVideo Configs #
850
+ #################################################################################
851
+
852
+ HUNYUAN_VIDEO_CONFIG = {
853
+ "HYVideo-T/2": {
854
+ "mm_double_blocks_depth": 20,
855
+ "mm_single_blocks_depth": 40,
856
+ "rope_dim_list": [16, 56, 56],
857
+ "hidden_size": 3072,
858
+ "heads_num": 24,
859
+ "mlp_width_ratio": 4,
860
+ },
861
+ "HYVideo-T/2-cfgdistill": {
862
+ "mm_double_blocks_depth": 20,
863
+ "mm_single_blocks_depth": 40,
864
+ "rope_dim_list": [16, 56, 56],
865
+ "hidden_size": 3072,
866
+ "heads_num": 24,
867
+ "mlp_width_ratio": 4,
868
+ "guidance_embed": True,
869
+ },
870
+ }
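A minimal standalone sketch of the TeaCache skip decision used in the forward pass above: the relative L1 change of the modulated input is rescaled by the polynomial fit and accumulated until it crosses a threshold, at which point the transformer blocks are recomputed. The tensor shape and `rel_l1_thresh` value below are illustrative assumptions, not values from this repo, and the first/last-step special cases are omitted.

import numpy as np
import torch

# Polynomial fit copied from the forward pass above; it rescales the raw
# relative-L1 distance into an estimate of how much the output will change.
rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                          -3.14987800e+00, 9.61237896e-02])

rel_l1_thresh = 0.15                    # assumed threshold; the real value comes from config
accumulated = 0.0
previous = torch.randn(1, 32, 3072)     # stand-in for the previous modulated input

for step in range(10):
    current = previous + 0.01 * torch.randn_like(previous)  # fake denoising step
    rel_l1 = ((current - previous).abs().mean() / previous.abs().mean()).item()
    accumulated += rescale_func(rel_l1)
    should_calc = accumulated >= rel_l1_thresh
    if should_calc:
        accumulated = 0.0               # recompute the blocks and reset the accumulator
    previous = current
    print(step, "recompute" if should_calc else "reuse cached residual")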
hyvideo/modules/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ModulateDiT(nn.Module):
8
+ """Modulation layer for DiT."""
9
+ def __init__(
10
+ self,
11
+ hidden_size: int,
12
+ factor: int,
13
+ act_layer: Callable,
14
+ dtype=None,
15
+ device=None,
16
+ ):
17
+ factory_kwargs = {"dtype": dtype, "device": device}
18
+ super().__init__()
19
+ self.act = act_layer()
20
+ self.linear = nn.Linear(
21
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
+ )
23
+ # Zero-initialize the modulation
24
+ nn.init.zeros_(self.linear.weight)
25
+ nn.init.zeros_(self.linear.bias)
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ return self.linear(self.act(x))
29
+
30
+
31
+ def modulate(x, shift=None, scale=None):
32
+ """modulate by shift and scale
33
+
34
+ Args:
35
+ x (torch.Tensor): input tensor.
36
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
37
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
38
+
39
+ Returns:
40
+ torch.Tensor: the modulated output tensor.
41
+ """
42
+ if scale is None and shift is None:
43
+ return x
44
+ elif shift is None:
45
+ return x * (1 + scale.unsqueeze(1))
46
+ elif scale is None:
47
+ return x + shift.unsqueeze(1)
48
+ else:
49
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
+
51
+
52
+ def apply_gate(x, gate=None, tanh=False):
53
+ """AI is creating summary for apply_gate
54
+
55
+ Args:
56
+ x (torch.Tensor): input tensor.
57
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
58
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
59
+
60
+ Returns:
61
+ torch.Tensor: the gated output tensor.
62
+ """
63
+ if gate is None:
64
+ return x
65
+ if tanh:
66
+ return x * gate.unsqueeze(1).tanh()
67
+ else:
68
+ return x * gate.unsqueeze(1)
69
+
70
+
71
+ def ckpt_wrapper(module):
72
+ def ckpt_forward(*inputs):
73
+ outputs = module(*inputs)
74
+ return outputs
75
+
76
+ return ckpt_forward
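A quick sanity check of `modulate` and `apply_gate` on dummy tensors (shapes chosen arbitrarily for illustration, assuming the repo root is on PYTHONPATH): `shift`, `scale`, and `gate` are per-sample vectors that broadcast over the sequence dimension via `unsqueeze(1)`.

import torch
from hyvideo.modules.modulate_layers import modulate, apply_gate

x = torch.randn(2, 16, 64)   # (batch, seq, hidden)
shift = torch.randn(2, 64)   # per-sample shift
scale = torch.randn(2, 64)   # per-sample scale
gate = torch.randn(2, 64)    # per-sample gate

y = modulate(x, shift=shift, scale=scale)
assert torch.allclose(y, x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1))

g = apply_gate(x, gate=gate, tanh=True)
assert g.shape == x.shape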
hyvideo/modules/norm_layers.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ elementwise_affine=True,
10
+ eps: float = 1e-6,
11
+ device=None,
12
+ dtype=None,
13
+ ):
14
+ """
15
+ Initialize the RMSNorm normalization layer.
16
+
17
+ Args:
18
+ dim (int): The dimension of the input tensor.
19
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20
+
21
+ Attributes:
22
+ eps (float): A small value added to the denominator for numerical stability.
23
+ weight (nn.Parameter): Learnable scaling parameter.
24
+
25
+ """
26
+ factory_kwargs = {"device": device, "dtype": dtype}
27
+ super().__init__()
28
+ self.eps = eps
29
+ if elementwise_affine:
30
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31
+
32
+ def _norm(self, x):
33
+ """
34
+ Apply the RMSNorm normalization to the input tensor.
35
+
36
+ Args:
37
+ x (torch.Tensor): The input tensor.
38
+
39
+ Returns:
40
+ torch.Tensor: The normalized tensor.
41
+
42
+ """
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ """
47
+ Forward pass through the RMSNorm layer.
48
+
49
+ Args:
50
+ x (torch.Tensor): The input tensor.
51
+
52
+ Returns:
53
+ torch.Tensor: The output tensor after applying RMSNorm.
54
+
55
+ """
56
+ output = self._norm(x.float()).type_as(x)
57
+ if hasattr(self, "weight"):
58
+ output = output * self.weight
59
+ return output
60
+
61
+
62
+ def get_norm_layer(norm_layer):
63
+ """
64
+ Get the normalization layer.
65
+
66
+ Args:
67
+ norm_layer (str): The type of normalization layer.
68
+
69
+ Returns:
70
+ norm_layer (nn.Module): The normalization layer.
71
+ """
72
+ if norm_layer == "layer":
73
+ return nn.LayerNorm
74
+ elif norm_layer == "rms":
75
+ return RMSNorm
76
+ else:
77
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
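A small check that `RMSNorm` matches the textbook formula x / sqrt(mean(x^2) + eps), scaled by the learned weight (import path assumes the repo root is on PYTHONPATH):

import torch
from hyvideo.modules.norm_layers import RMSNorm

norm = RMSNorm(dim=8, eps=1e-6)  # weight is initialized to ones
x = torch.randn(4, 8)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), expected, atol=1e-6)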
hyvideo/modules/posemb_layers.py ADDED
@@ -0,0 +1,310 @@
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (torch.Tensor): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
+ def reshape_for_broadcast(
66
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67
+ x: torch.Tensor,
68
+ head_first=False,
69
+ ):
70
+ """
71
+ Reshape frequency tensor for broadcasting it with another tensor.
72
+
73
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
75
+
76
+ Notes:
77
+ When using FlashMHAModified, head_first should be False.
78
+ When using Attention, head_first should be True.
79
+
80
+ Args:
81
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
83
+ head_first (bool): head dimension first (except batch dim) or not.
84
+
85
+ Returns:
86
+ torch.Tensor: Reshaped frequency tensor.
87
+
88
+ Raises:
89
+ AssertionError: If the frequency tensor doesn't match the expected shape.
90
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91
+ """
92
+ ndim = x.ndim
93
+ assert 0 <= 1 < ndim
94
+
95
+ if isinstance(freqs_cis, tuple):
96
+ # freqs_cis: (cos, sin) in real space
97
+ if head_first:
98
+ assert freqs_cis[0].shape == (
99
+ x.shape[-2],
100
+ x.shape[-1],
101
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102
+ shape = [
103
+ d if i == ndim - 2 or i == ndim - 1 else 1
104
+ for i, d in enumerate(x.shape)
105
+ ]
106
+ else:
107
+ assert freqs_cis[0].shape == (
108
+ x.shape[1],
109
+ x.shape[-1],
110
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113
+ else:
114
+ # freqs_cis: values in complex space
115
+ if head_first:
116
+ assert freqs_cis.shape == (
117
+ x.shape[-2],
118
+ x.shape[-1],
119
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120
+ shape = [
121
+ d if i == ndim - 2 or i == ndim - 1 else 1
122
+ for i, d in enumerate(x.shape)
123
+ ]
124
+ else:
125
+ assert freqs_cis.shape == (
126
+ x.shape[1],
127
+ x.shape[-1],
128
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
+ return freqs_cis.view(*shape)
131
+
132
+
133
+ def rotate_half(x):
134
+ x_real, x_imag = (
135
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136
+ ) # [B, S, H, D//2]
137
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138
+
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144
+ head_first: bool = False,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Apply rotary embeddings to input tensors using the given frequency tensor.
148
+
149
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152
+ returned as real tensors.
153
+
154
+ Args:
155
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158
+ head_first (bool): head dimension first (except batch dim) or not.
159
+
160
+ Returns:
161
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162
+
163
+ """
164
+ xk_out = None
165
+ if isinstance(freqs_cis, tuple):
166
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
168
+ # real * cos - imag * sin
169
+ # imag * cos + real * sin
170
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172
+ else:
173
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174
+ xq_ = torch.view_as_complex(
175
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
176
+ ) # [B, S, H, D//2]
177
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178
+ xq.device
179
+ ) # [S, D//2] --> [1, S, 1, D//2]
180
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183
+ xk_ = torch.view_as_complex(
184
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
185
+ ) # [B, S, H, D//2]
186
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187
+
188
+ return xq_out, xk_out
189
+
190
+
191
+ def get_nd_rotary_pos_embed(
192
+ rope_dim_list,
193
+ start,
194
+ *args,
195
+ theta=10000.0,
196
+ use_real=False,
197
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
198
+ interpolation_factor: Union[float, List[float]] = 1.0,
199
+ ):
200
+ """
201
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202
+
203
+ Args:
204
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
205
+ sum(rope_dim_list) should equal to head_dim of attention layer.
206
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208
+ *args: See above.
209
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211
+ Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
212
+ part and an imaginary part separately.
213
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
214
+
215
+ Returns:
216
+ pos_embed (torch.Tensor): [HW, D/2]
217
+ """
218
+
219
+ grid = get_meshgrid_nd(
220
+ start, *args, dim=len(rope_dim_list)
221
+ ) # [3, W, H, D] / [2, W, H]
222
+
223
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227
+ assert len(theta_rescale_factor) == len(
228
+ rope_dim_list
229
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230
+
231
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235
+ assert len(interpolation_factor) == len(
236
+ rope_dim_list
237
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238
+
239
+ # use 1/ndim of dimensions to encode grid_axis
240
+ embs = []
241
+ for i in range(len(rope_dim_list)):
242
+ emb = get_1d_rotary_pos_embed(
243
+ rope_dim_list[i],
244
+ grid[i].reshape(-1),
245
+ theta,
246
+ use_real=use_real,
247
+ theta_rescale_factor=theta_rescale_factor[i],
248
+ interpolation_factor=interpolation_factor[i],
249
+ ) # 2 x [WHD, rope_dim_list[i]]
250
+ embs.append(emb)
251
+
252
+ if use_real:
253
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
254
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
255
+ return cos, sin
256
+ else:
257
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
258
+ return emb
259
+
260
+
261
+ def get_1d_rotary_pos_embed(
262
+ dim: int,
263
+ pos: Union[torch.FloatTensor, int],
264
+ theta: float = 10000.0,
265
+ use_real: bool = False,
266
+ theta_rescale_factor: float = 1.0,
267
+ interpolation_factor: float = 1.0,
268
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269
+ """
270
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272
+
273
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
274
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
275
+ The returned tensor contains complex values in complex64 data type.
276
+
277
+ Args:
278
+ dim (int): Dimension of the frequency tensor.
279
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
281
+ use_real (bool, optional): If True, return real part and imaginary part separately.
282
+ Otherwise, return complex numbers.
283
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
284
+
285
+ Returns:
286
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288
+ """
289
+ if isinstance(pos, int):
290
+ pos = torch.arange(pos).float()
291
+
292
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
293
+ # has some connection to NTK literature
294
+ if theta_rescale_factor != 1.0:
295
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
296
+
297
+ freqs = 1.0 / (
298
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299
+ ) # [D/2]
300
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302
+ if use_real:
303
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305
+ return freqs_cos, freqs_sin
306
+ else:
307
+ freqs_cis = torch.polar(
308
+ torch.ones_like(freqs), freqs
309
+ ) # complex64 # [S, D/2]
310
+ return freqs_cis
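A usage sketch for the real-valued RoPE path above (import path is an assumption): cos/sin tables of shape [S, D] are built once per position grid and applied to [B, S, H, D] query/key tensors. Since RoPE is a pure rotation in 2D subspaces, per-token norms are preserved.

import torch
from hyvideo.modules.posemb_layers import get_1d_rotary_pos_embed, apply_rotary_emb

B, S, H, D = 1, 8, 2, 16
freqs_cos, freqs_sin = get_1d_rotary_pos_embed(D, S, use_real=True)  # each [S, D]

xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, (freqs_cos, freqs_sin), head_first=False)

# Rotation preserves per-token norms (up to numerical error).
assert torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-4)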
hyvideo/modules/token_refiner.py ADDED
@@ -0,0 +1,236 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .activation_layers import get_activation_layer
8
+ from .attenion import attention
9
+ from .norm_layers import get_norm_layer
10
+ from .embed_layers import TimestepEmbedder, TextProjection
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import modulate, apply_gate
14
+
15
+
16
+ class IndividualTokenRefinerBlock(nn.Module):
17
+ def __init__(
18
+ self,
19
+ hidden_size,
20
+ heads_num,
21
+ mlp_width_ratio: float = 4.0,
22
+ mlp_drop_rate: float = 0.0,
23
+ act_type: str = "silu",
24
+ qk_norm: bool = False,
25
+ qk_norm_type: str = "layer",
26
+ qkv_bias: bool = True,
27
+ dtype: Optional[torch.dtype] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ factory_kwargs = {"device": device, "dtype": dtype}
31
+ super().__init__()
32
+ self.heads_num = heads_num
33
+ head_dim = hidden_size // heads_num
34
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35
+
36
+ self.norm1 = nn.LayerNorm(
37
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
38
+ )
39
+ self.self_attn_qkv = nn.Linear(
40
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
41
+ )
42
+ qk_norm_layer = get_norm_layer(qk_norm_type)
43
+ self.self_attn_q_norm = (
44
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
45
+ if qk_norm
46
+ else nn.Identity()
47
+ )
48
+ self.self_attn_k_norm = (
49
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
50
+ if qk_norm
51
+ else nn.Identity()
52
+ )
53
+ self.self_attn_proj = nn.Linear(
54
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
55
+ )
56
+
57
+ self.norm2 = nn.LayerNorm(
58
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
59
+ )
60
+ act_layer = get_activation_layer(act_type)
61
+ self.mlp = MLP(
62
+ in_channels=hidden_size,
63
+ hidden_channels=mlp_hidden_dim,
64
+ act_layer=act_layer,
65
+ drop=mlp_drop_rate,
66
+ **factory_kwargs,
67
+ )
68
+
69
+ self.adaLN_modulation = nn.Sequential(
70
+ act_layer(),
71
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
72
+ )
73
+ # Zero-initialize the modulation
74
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
75
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
76
+
77
+ def forward(
78
+ self,
79
+ x: torch.Tensor,
80
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
81
+ attn_mask: torch.Tensor = None,
82
+ ):
83
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
84
+
85
+ norm_x = self.norm1(x)
86
+ qkv = self.self_attn_qkv(norm_x)
87
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
88
+ # Apply QK-Norm if needed
89
+ q = self.self_attn_q_norm(q).to(v)
90
+ k = self.self_attn_k_norm(k).to(v)
91
+
92
+ # Self-Attention
93
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
94
+
95
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
96
+
97
+ # FFN Layer
98
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
99
+
100
+ return x
101
+
102
+
103
+ class IndividualTokenRefiner(nn.Module):
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ heads_num,
108
+ depth,
109
+ mlp_width_ratio: float = 4.0,
110
+ mlp_drop_rate: float = 0.0,
111
+ act_type: str = "silu",
112
+ qk_norm: bool = False,
113
+ qk_norm_type: str = "layer",
114
+ qkv_bias: bool = True,
115
+ dtype: Optional[torch.dtype] = None,
116
+ device: Optional[torch.device] = None,
117
+ ):
118
+ factory_kwargs = {"device": device, "dtype": dtype}
119
+ super().__init__()
120
+ self.blocks = nn.ModuleList(
121
+ [
122
+ IndividualTokenRefinerBlock(
123
+ hidden_size=hidden_size,
124
+ heads_num=heads_num,
125
+ mlp_width_ratio=mlp_width_ratio,
126
+ mlp_drop_rate=mlp_drop_rate,
127
+ act_type=act_type,
128
+ qk_norm=qk_norm,
129
+ qk_norm_type=qk_norm_type,
130
+ qkv_bias=qkv_bias,
131
+ **factory_kwargs,
132
+ )
133
+ for _ in range(depth)
134
+ ]
135
+ )
136
+
137
+ def forward(
138
+ self,
139
+ x: torch.Tensor,
140
+ c: torch.LongTensor,
141
+ mask: Optional[torch.Tensor] = None,
142
+ ):
143
+ self_attn_mask = None
144
+ if mask is not None:
145
+ batch_size = mask.shape[0]
146
+ seq_len = mask.shape[1]
147
+ mask = mask.to(x.device)
148
+ # batch_size x 1 x seq_len x seq_len
149
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
150
+ 1, 1, seq_len, 1
151
+ )
152
+ # batch_size x 1 x seq_len x seq_len
153
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
154
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
155
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
156
+ # avoids self-attention weight being NaN for padding tokens
157
+ self_attn_mask[:, :, :, 0] = True
158
+
159
+ for block in self.blocks:
160
+ x = block(x, c, self_attn_mask)
161
+ return x
162
+
163
+
164
+ class SingleTokenRefiner(nn.Module):
165
+ """
166
+ A token refiner for refining LLM text embeddings.
167
+ """
168
+ def __init__(
169
+ self,
170
+ in_channels,
171
+ hidden_size,
172
+ heads_num,
173
+ depth,
174
+ mlp_width_ratio: float = 4.0,
175
+ mlp_drop_rate: float = 0.0,
176
+ act_type: str = "silu",
177
+ qk_norm: bool = False,
178
+ qk_norm_type: str = "layer",
179
+ qkv_bias: bool = True,
180
+ attn_mode: str = "torch",
181
+ dtype: Optional[torch.dtype] = None,
182
+ device: Optional[torch.device] = None,
183
+ ):
184
+ factory_kwargs = {"device": device, "dtype": dtype}
185
+ super().__init__()
186
+ self.attn_mode = attn_mode
187
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
188
+
189
+ self.input_embedder = nn.Linear(
190
+ in_channels, hidden_size, bias=True, **factory_kwargs
191
+ )
192
+
193
+ act_layer = get_activation_layer(act_type)
194
+ # Build timestep embedding layer
195
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
196
+ # Build context embedding layer
197
+ self.c_embedder = TextProjection(
198
+ in_channels, hidden_size, act_layer, **factory_kwargs
199
+ )
200
+
201
+ self.individual_token_refiner = IndividualTokenRefiner(
202
+ hidden_size=hidden_size,
203
+ heads_num=heads_num,
204
+ depth=depth,
205
+ mlp_width_ratio=mlp_width_ratio,
206
+ mlp_drop_rate=mlp_drop_rate,
207
+ act_type=act_type,
208
+ qk_norm=qk_norm,
209
+ qk_norm_type=qk_norm_type,
210
+ qkv_bias=qkv_bias,
211
+ **factory_kwargs,
212
+ )
213
+
214
+ def forward(
215
+ self,
216
+ x: torch.Tensor,
217
+ t: torch.LongTensor,
218
+ mask: Optional[torch.LongTensor] = None,
219
+ ):
220
+ timestep_aware_representations = self.t_embedder(t)
221
+
222
+ if mask is None:
223
+ context_aware_representations = x.mean(dim=1)
224
+ else:
225
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
226
+ context_aware_representations = (x * mask_float).sum(
227
+ dim=1
228
+ ) / mask_float.sum(dim=1)
229
+ context_aware_representations = self.c_embedder(context_aware_representations)
230
+ c = timestep_aware_representations + context_aware_representations
231
+
232
+ x = self.input_embedder(x)
233
+
234
+ x = self.individual_token_refiner(x, c, mask)
235
+
236
+ return x
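A minimal illustration of the padding-aware self-attention mask built in `IndividualTokenRefiner.forward`: a token may attend only to positions where both query and key are valid, and column 0 is forced open so fully-padded rows don't produce NaN attention weights.

import torch

mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)  # 3 valid tokens, 2 padding
b, s = mask.shape
m1 = mask.view(b, 1, 1, s).repeat(1, 1, s, 1)  # keys that are valid, per row
m2 = m1.transpose(2, 3)                        # queries that are valid, per column
self_attn_mask = m1 & m2
self_attn_mask[:, :, :, 0] = True              # keep padding rows from being all-False
print(self_attn_mask[0, 0].int())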
hyvideo/prompt_rewrite.py ADDED
@@ -0,0 +1,51 @@
1
+ normal_mode_prompt = """Normal mode - Video Recaption Task:
2
+
3
+ You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
4
+
5
+ 0. Preserve ALL information, including style words and technical terms.
6
+
7
+ 1. If the input is in Chinese, translate the entire description to English.
8
+
9
+ 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
10
+
11
+ 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
12
+
13
+ 4. Output ALL must be in English.
14
+
15
+ Given Input:
16
+ input: "{input}"
17
+ """
18
+
19
+
20
+ master_mode_prompt = """Master mode - Video Recaption Task:
21
+
22
+ You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
23
+
24
+ 0. Preserve ALL information, including style words and technical terms.
25
+
26
+ 1. If the input is in Chinese, translate the entire description to English.
27
+
28
+ 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
29
+
30
+ 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
31
+
32
+ 4. Output ALL must be in English.
33
+
34
+ Given Input:
35
+ input: "{input}"
36
+ """
37
+
38
+ def get_rewrite_prompt(ori_prompt, mode="Normal"):
39
+ if mode == "Normal":
40
+ prompt = normal_mode_prompt.format(input=ori_prompt)
41
+ elif mode == "Master":
42
+ prompt = master_mode_prompt.format(input=ori_prompt)
43
+ else:
44
+ raise Exception("Only supports Normal and Normal", mode)
45
+ return prompt
46
+
47
+ ori_prompt = "一只小狗在草地上奔跑。"
48
+ normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
49
+ master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")
50
+
51
+ # Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt.
hyvideo/text_encoder/__init__.py ADDED
@@ -0,0 +1,366 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
8
+ from transformers.utils import ModelOutput
9
+
10
+ from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
11
+ from ..constants import PRECISION_TO_TYPE
12
+
13
+
14
+ def use_default(value, default):
15
+ return value if value is not None else default
16
+
17
+
18
+ def load_text_encoder(
19
+ text_encoder_type,
20
+ text_encoder_precision=None,
21
+ text_encoder_path=None,
22
+ logger=None,
23
+ device=None,
24
+ ):
25
+ if text_encoder_path is None:
26
+ text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
27
+ if logger is not None:
28
+ logger.info(
29
+ f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}"
30
+ )
31
+
32
+ if text_encoder_type == "clipL":
33
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
34
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
35
+ elif text_encoder_type == "llm":
36
+ text_encoder = AutoModel.from_pretrained(
37
+ text_encoder_path, low_cpu_mem_usage=True
38
+ )
39
+ text_encoder.final_layer_norm = text_encoder.norm
40
+ else:
41
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
42
+ # from_pretrained will ensure that the model is in eval mode.
43
+
44
+ if text_encoder_precision is not None:
45
+ text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
46
+
47
+ text_encoder.requires_grad_(False)
48
+
49
+ if logger is not None:
50
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
51
+
52
+ if device is not None:
53
+ text_encoder = text_encoder.to(device)
54
+
55
+ return text_encoder, text_encoder_path
56
+
57
+
58
+ def load_tokenizer(
59
+ tokenizer_type, tokenizer_path=None, padding_side="right", logger=None
60
+ ):
61
+ if tokenizer_path is None:
62
+ tokenizer_path = TOKENIZER_PATH[tokenizer_type]
63
+ if logger is not None:
64
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
65
+
66
+ if tokenizer_type == "clipL":
67
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
68
+ elif tokenizer_type == "llm":
69
+ tokenizer = AutoTokenizer.from_pretrained(
70
+ tokenizer_path, padding_side=padding_side
71
+ )
72
+ else:
73
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
74
+
75
+ return tokenizer, tokenizer_path
76
+
77
+
78
+ @dataclass
79
+ class TextEncoderModelOutput(ModelOutput):
80
+ """
81
+ Base class for model's outputs that also contains a pooling of the last hidden states.
82
+
83
+ Args:
84
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
85
+ Sequence of hidden-states at the output of the last layer of the model.
86
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
87
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
88
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
89
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
90
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
91
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
93
+ List of decoded texts.
94
+ """
95
+
96
+ hidden_state: torch.FloatTensor = None
97
+ attention_mask: Optional[torch.LongTensor] = None
98
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
99
+ text_outputs: Optional[list] = None
100
+
101
+
102
+ class TextEncoder(nn.Module):
103
+ def __init__(
104
+ self,
105
+ text_encoder_type: str,
106
+ max_length: int,
107
+ text_encoder_precision: Optional[str] = None,
108
+ text_encoder_path: Optional[str] = None,
109
+ tokenizer_type: Optional[str] = None,
110
+ tokenizer_path: Optional[str] = None,
111
+ output_key: Optional[str] = None,
112
+ use_attention_mask: bool = True,
113
+ input_max_length: Optional[int] = None,
114
+ prompt_template: Optional[dict] = None,
115
+ prompt_template_video: Optional[dict] = None,
116
+ hidden_state_skip_layer: Optional[int] = None,
117
+ apply_final_norm: bool = False,
118
+ reproduce: bool = False,
119
+ logger=None,
120
+ device=None,
121
+ ):
122
+ super().__init__()
123
+ self.text_encoder_type = text_encoder_type
124
+ self.max_length = max_length
125
+ self.precision = text_encoder_precision
126
+ self.model_path = text_encoder_path
127
+ self.tokenizer_type = (
128
+ tokenizer_type if tokenizer_type is not None else text_encoder_type
129
+ )
130
+ self.tokenizer_path = (
131
+ tokenizer_path if tokenizer_path is not None else None # text_encoder_path
132
+ )
133
+ self.use_attention_mask = use_attention_mask
134
+ if prompt_template_video is not None:
135
+ assert (
136
+ use_attention_mask is True
137
+ ), "Attention mask is True required when training videos."
138
+ self.input_max_length = (
139
+ input_max_length if input_max_length is not None else max_length
140
+ )
141
+ self.prompt_template = prompt_template
142
+ self.prompt_template_video = prompt_template_video
143
+ self.hidden_state_skip_layer = hidden_state_skip_layer
144
+ self.apply_final_norm = apply_final_norm
145
+ self.reproduce = reproduce
146
+ self.logger = logger
147
+
148
+ self.use_template = self.prompt_template is not None
149
+ if self.use_template:
150
+ assert (
151
+ isinstance(self.prompt_template, dict)
152
+ and "template" in self.prompt_template
153
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
154
+ assert "{}" in str(self.prompt_template["template"]), (
155
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
156
+ f"got {self.prompt_template['template']}"
157
+ )
158
+
159
+ self.use_video_template = self.prompt_template_video is not None
160
+ if self.use_video_template:
161
+ if self.prompt_template_video is not None:
162
+ assert (
163
+ isinstance(self.prompt_template_video, dict)
164
+ and "template" in self.prompt_template_video
165
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
166
+ assert "{}" in str(self.prompt_template_video["template"]), (
167
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
168
+ f"got {self.prompt_template_video['template']}"
169
+ )
170
+
171
+ if "t5" in text_encoder_type:
172
+ self.output_key = output_key or "last_hidden_state"
173
+ elif "clip" in text_encoder_type:
174
+ self.output_key = output_key or "pooler_output"
175
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
176
+ self.output_key = output_key or "last_hidden_state"
177
+ else:
178
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
179
+
180
+ if "llm" in text_encoder_type:
181
+
182
+ from mmgp import offload
183
+
184
+ self.model= offload.fast_load_transformers_model(self.model_path) #, pinInMemory = True, partialPinning = True
185
+ self.model.final_layer_norm = self.model.norm
186
+
187
+ else:
188
+ self.model, self.model_path = load_text_encoder(
189
+ text_encoder_type=self.text_encoder_type,
190
+ text_encoder_precision=self.precision,
191
+ text_encoder_path=self.model_path,
192
+ logger=self.logger,
193
+ device=device,
194
+ )
195
+
196
+ self.dtype = self.model.dtype
197
+ self.device = self.model.device
198
+
199
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
200
+ tokenizer_type=self.tokenizer_type,
201
+ tokenizer_path=self.tokenizer_path,
202
+ padding_side="right",
203
+ logger=self.logger,
204
+ )
205
+
206
+ def __repr__(self):
207
+ return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
208
+
209
+ @staticmethod
210
+ def apply_text_to_template(text, template, prevent_empty_text=True):
211
+ """
212
+ Apply text to template.
213
+
214
+ Args:
215
+ text (str): Input text.
216
+ template (str or list): Template string or list of chat conversation.
217
+ prevent_empty_text (bool): If True, we will prevent the user text from being empty
218
+ by adding a space. Defaults to True.
219
+ """
220
+ if isinstance(template, str):
221
+ # Will send string to tokenizer. Used for llm
222
+ return template.format(text)
223
+ else:
224
+ raise TypeError(f"Unsupported template type: {type(template)}")
225
+
226
+ def text2tokens(self, text, data_type="image"):
227
+ """
228
+ Tokenize the input text.
229
+
230
+ Args:
231
+ text (str or list): Input text.
232
+ """
233
+ tokenize_input_type = "str"
234
+ if self.use_template:
235
+ if data_type == "image":
236
+ prompt_template = self.prompt_template["template"]
237
+ elif data_type == "video":
238
+ prompt_template = self.prompt_template_video["template"]
239
+ else:
240
+ raise ValueError(f"Unsupported data type: {data_type}")
241
+ if isinstance(text, (list, tuple)):
242
+ text = [
243
+ self.apply_text_to_template(one_text, prompt_template)
244
+ for one_text in text
245
+ ]
246
+ if isinstance(text[0], list):
247
+ tokenize_input_type = "list"
248
+ elif isinstance(text, str):
249
+ text = self.apply_text_to_template(text, prompt_template)
250
+ if isinstance(text, list):
251
+ tokenize_input_type = "list"
252
+ else:
253
+ raise TypeError(f"Unsupported text type: {type(text)}")
254
+
255
+ kwargs = dict(
256
+ truncation=True,
257
+ max_length=self.max_length,
258
+ padding="max_length",
259
+ return_tensors="pt",
260
+ )
261
+ if tokenize_input_type == "str":
262
+ return self.tokenizer(
263
+ text,
264
+ return_length=False,
265
+ return_overflowing_tokens=False,
266
+ return_attention_mask=True,
267
+ **kwargs,
268
+ )
269
+ elif tokenize_input_type == "list":
270
+ return self.tokenizer.apply_chat_template(
271
+ text,
272
+ add_generation_prompt=True,
273
+ tokenize=True,
274
+ return_dict=True,
275
+ **kwargs,
276
+ )
277
+ else:
278
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
279
+
280
+ def encode(
281
+ self,
282
+ batch_encoding,
283
+ use_attention_mask=None,
284
+ output_hidden_states=False,
285
+ do_sample=None,
286
+ hidden_state_skip_layer=None,
287
+ return_texts=False,
288
+ data_type="image",
289
+ device=None,
290
+ ):
291
+ """
292
+ Args:
293
+ batch_encoding (dict): Batch encoding from tokenizer.
294
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
295
+ Defaults to None.
296
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
297
+ self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
298
+ output_hidden_states will be set True. Defaults to False.
299
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
300
+ When self.produce is False, do_sample is set to True by default.
301
+ hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
302
+ If None, self.output_key will be used. Defaults to None.
303
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
304
+ """
305
+ device = self.model.device if device is None else device
306
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
307
+ hidden_state_skip_layer = use_default(
308
+ hidden_state_skip_layer, self.hidden_state_skip_layer
309
+ )
310
+ do_sample = use_default(do_sample, not self.reproduce)
311
+ attention_mask = (
312
+ batch_encoding["attention_mask"].to(device) if use_attention_mask else None
313
+ )
314
+ outputs = self.model(
315
+ input_ids=batch_encoding["input_ids"].to(device),
316
+ attention_mask=attention_mask,
317
+ output_hidden_states=output_hidden_states
318
+ or hidden_state_skip_layer is not None,
319
+ )
320
+ if hidden_state_skip_layer is not None:
321
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
322
+ # Real last hidden state already has layer norm applied. So here we only apply it
323
+ # for intermediate layers.
324
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
325
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
326
+ else:
327
+ last_hidden_state = outputs[self.output_key]
328
+
329
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
330
+ if self.use_template:
331
+ if data_type == "image":
332
+ crop_start = self.prompt_template.get("crop_start", -1)
333
+ elif data_type == "video":
334
+ crop_start = self.prompt_template_video.get("crop_start", -1)
335
+ else:
336
+ raise ValueError(f"Unsupported data type: {data_type}")
337
+ if crop_start > 0:
338
+ last_hidden_state = last_hidden_state[:, crop_start:]
339
+ attention_mask = (
340
+ attention_mask[:, crop_start:] if use_attention_mask else None
341
+ )
342
+
343
+ if output_hidden_states:
344
+ return TextEncoderModelOutput(
345
+ last_hidden_state, attention_mask, outputs.hidden_states
346
+ )
347
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
348
+
349
+ def forward(
350
+ self,
351
+ text,
352
+ use_attention_mask=None,
353
+ output_hidden_states=False,
354
+ do_sample=False,
355
+ hidden_state_skip_layer=None,
356
+ return_texts=False,
357
+ ):
358
+ batch_encoding = self.text2tokens(text)
359
+ return self.encode(
360
+ batch_encoding,
361
+ use_attention_mask=use_attention_mask,
362
+ output_hidden_states=output_hidden_states,
363
+ do_sample=do_sample,
364
+ hidden_state_skip_layer=hidden_state_skip_layer,
365
+ return_texts=return_texts,
366
+ )
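A usage sketch for `TextEncoder` with the CLIP branch (the checkpoint path and precision here are assumptions for illustration; the real values come from hyvideo/config.py and hyvideo/constants.py):

from hyvideo.text_encoder import TextEncoder

encoder = TextEncoder(
    text_encoder_type="clipL",
    max_length=77,
    text_encoder_precision="fp16",
    text_encoder_path="ckpts/text_encoder_2",  # assumed local path
    device="cuda",
)
out = encoder(["A cat walks on the grass, realistic style."])
print(out.hidden_state.shape)  # pooled CLIP output used as the modulation vector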
hyvideo/utils/__init__.py ADDED
File without changes
hyvideo/utils/data_utils.py ADDED
@@ -0,0 +1,15 @@
1
+ import numpy as np
2
+ import math
3
+
4
+
5
+ def align_to(value, alignment):
6
+ """align hight, width according to alignment
7
+
8
+ Args:
9
+ value (int): height or width
10
+ alignment (int): target alignment factor
11
+
12
+ Returns:
13
+ int: the aligned value
14
+ """
15
+ return int(math.ceil(value / alignment) * alignment)
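For example, `align_to` rounds a resolution up to the nearest multiple of the alignment factor:

from hyvideo.utils.data_utils import align_to

assert align_to(720, 16) == 720  # already aligned
assert align_to(721, 16) == 736  # rounded up to the next multiple of 16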
hyvideo/utils/file_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ from pathlib import Path
3
+ from einops import rearrange
4
+
5
+ import torch
6
+ import torchvision
7
+ import numpy as np
8
+ import imageio
9
+
10
+ CODE_SUFFIXES = {
11
+ ".py", # Python codes
12
+ ".sh", # Shell scripts
13
+ ".yaml",
14
+ ".yml", # Configuration files
15
+ }
16
+
17
+
18
+ def safe_dir(path):
19
+ """
20
+ Create a directory (or the parent directory of a file) if it does not exist.
21
+
22
+ Args:
23
+ path (str or Path): Path to the directory.
24
+
25
+ Returns:
26
+ path (Path): Path object of the directory.
27
+ """
28
+ path = Path(path)
29
+ path.mkdir(exist_ok=True, parents=True)
30
+ return path
31
+
32
+
33
+ def safe_file(path):
34
+ """
35
+ Create the parent directory of a file if it does not exist.
36
+
37
+ Args:
38
+ path (str or Path): Path to the file.
39
+
40
+ Returns:
41
+ path (Path): Path object of the file.
42
+ """
43
+ path = Path(path)
44
+ path.parent.mkdir(exist_ok=True, parents=True)
45
+ return path
46
+
47
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
48
+ """save videos by video tensor
49
+ copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
50
+
51
+ Args:
52
+ videos (torch.Tensor): video tensor predicted by the model
53
+ path (str): path to save video
54
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
55
+ n_rows (int, optional): Defaults to 1.
56
+ fps (int, optional): video save fps. Defaults to 24.
57
+ """
58
+ videos = rearrange(videos, "b c t h w -> t b c h w")
59
+ outputs = []
60
+ for x in videos:
61
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
62
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
63
+ if rescale:
64
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
65
+ x = torch.clamp(x, 0, 1)
66
+ x = (x * 255).numpy().astype(np.uint8)
67
+ outputs.append(x)
68
+
69
+ os.makedirs(os.path.dirname(path), exist_ok=True)
70
+ imageio.mimsave(path, outputs, fps=fps)
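A quick smoke test of `save_videos_grid` with a random tensor (the output path is arbitrary, and writing .mp4 assumes imageio's ffmpeg backend is installed): the expected input layout is (batch, channels, time, height, width) in [0, 1], or in [-1, 1] with rescale=True.

import torch
from hyvideo.utils.file_utils import save_videos_grid

fake_video = torch.rand(1, 3, 8, 64, 64)  # (b, c, t, h, w) in [0, 1]
save_videos_grid(fake_video, "results/demo.mp4", fps=24)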
hyvideo/utils/helpers.py ADDED
@@ -0,0 +1,40 @@
1
+ import collections.abc
2
+
3
+ from itertools import repeat
4
+
5
+
6
+ def _ntuple(n):
7
+ def parse(x):
8
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9
+ x = tuple(x)
10
+ if len(x) == 1:
11
+ x = tuple(repeat(x[0], n))
12
+ return x
13
+ return tuple(repeat(x, n))
14
+ return parse
15
+
16
+
17
+ to_1tuple = _ntuple(1)
18
+ to_2tuple = _ntuple(2)
19
+ to_3tuple = _ntuple(3)
20
+ to_4tuple = _ntuple(4)
21
+
22
+
23
+ def as_tuple(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ return tuple(x)
26
+ if x is None or isinstance(x, (int, float, str)):
27
+ return (x,)
28
+ else:
29
+ raise ValueError(f"Unknown type {type(x)}")
30
+
31
+
32
+ def as_list_of_2tuple(x):
33
+ x = as_tuple(x)
34
+ if len(x) == 1:
35
+ x = (x[0], x[0])
36
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37
+ lst = []
38
+ for i in range(0, len(x), 2):
39
+ lst.append((x[i], x[i + 1]))
40
+ return lst
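Behavior of the tuple helpers, for reference:

from hyvideo.utils.helpers import to_2tuple, as_list_of_2tuple

assert to_2tuple(3) == (3, 3)                          # scalar is repeated
assert to_2tuple((4, 5)) == (4, 5)                     # 2-tuple passes through
assert as_list_of_2tuple((1, 2, 3, 4)) == [(1, 2), (3, 4)]  # pairs consecutive items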
hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py ADDED
@@ -0,0 +1,46 @@
1
+ import argparse
2
+ import torch
3
+ from transformers import (
4
+ AutoProcessor,
5
+ LlavaForConditionalGeneration,
6
+ )
7
+
8
+
9
+ def preprocess_text_encoder_tokenizer(args):
10
+
11
+ processor = AutoProcessor.from_pretrained(args.input_dir)
12
+ model = LlavaForConditionalGeneration.from_pretrained(
13
+ args.input_dir,
14
+ torch_dtype=torch.float16,
15
+ low_cpu_mem_usage=True,
16
+ ).to(0)
17
+
18
+ model.language_model.save_pretrained(
19
+ f"{args.output_dir}"
20
+ )
21
+ processor.tokenizer.save_pretrained(
22
+ f"{args.output_dir}"
23
+ )
24
+
25
+ if __name__ == "__main__":
26
+
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument(
29
+ "--input_dir",
30
+ type=str,
31
+ required=True,
32
+ help="The path to the llava-llama-3-8b-v1_1-transformers.",
33
+ )
34
+ parser.add_argument(
35
+ "--output_dir",
36
+ type=str,
37
+ default="",
38
+ help="The output path of the llava-llama-3-8b-text-encoder-tokenizer."
39
+ "if '', the parent dir of output will be the same as input dir.",
40
+ )
41
+ args = parser.parse_args()
42
+
43
+ if len(args.output_dir) == 0:
44
+ args.output_dir = "/".join(args.input_dir.split("/")[:-1])
45
+
46
+ preprocess_text_encoder_tokenizer(args)
hyvideo/vae/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6
+ from ..constants import VAE_PATH, PRECISION_TO_TYPE
7
+
8
+ def load_vae(vae_type: str="884-16c-hy",
9
+ vae_precision: str=None,
10
+ sample_size: tuple=None,
11
+ vae_path: str=None,
12
+ logger=None,
13
+ device=None
14
+ ):
15
+ """the fucntion to load the 3D VAE model
16
+
17
+ Args:
18
+ vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
19
+ vae_precision (str, optional): the precision to load vae. Defaults to None.
20
+ sample_size (tuple, optional): the tiling size. Defaults to None.
21
+ vae_path (str, optional): the path to vae. Defaults to None.
22
+ logger (_type_, optional): logger. Defaults to None.
23
+ device (_type_, optional): device to load vae. Defaults to None.
24
+ """
25
+ if vae_path is None:
26
+ vae_path = VAE_PATH[vae_type]
27
+
28
+ if logger is not None:
29
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
30
+ config = AutoencoderKLCausal3D.load_config(vae_path)
31
+ if sample_size:
32
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
33
+ else:
34
+ vae = AutoencoderKLCausal3D.from_config(config)
35
+
36
+ vae_ckpt = Path(vae_path) / "pytorch_model.pt"
37
+ assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
38
+
39
+ ckpt = torch.load(vae_ckpt, weights_only=True, map_location=vae.device)
40
+ if "state_dict" in ckpt:
41
+ ckpt = ckpt["state_dict"]
42
+ if any(k.startswith("vae.") for k in ckpt.keys()):
43
+ ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
44
+ vae.load_state_dict(ckpt)
45
+
46
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
47
+ time_compression_ratio = vae.config.time_compression_ratio
48
+
49
+ if vae_precision is not None:
50
+ vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
51
+
52
+ vae.requires_grad_(False)
53
+
54
+ if logger is not None:
55
+ logger.info(f"VAE to dtype: {vae.dtype}")
56
+
57
+ if device is not None:
58
+ vae = vae.to(device)
59
+
60
+ vae.eval()
61
+
62
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
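A minimal sketch of loading the VAE with this helper (assumes the checkpoint has been downloaded to the default VAE_PATH location; the precision and device values are illustrative):

    import torch
    from hyvideo.vae import load_vae

    vae, vae_path, spatial_ratio, time_ratio = load_vae(
        vae_type="884-16c-hy",
        vae_precision="fp16",
        device=torch.device("cuda"),
    )
    print(spatial_ratio, time_ratio)  # compression ratios read from the VAE config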
hyvideo/vae/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,603 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ try:
28
+ # This diffusers is modified and packed in the mirror.
29
+ from diffusers.loaders import FromOriginalVAEMixin
30
+ except ImportError:
31
+ # Use this to be compatible with the original diffusers.
32
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = (
127
+ self.config.sample_size[0]
128
+ if isinstance(self.config.sample_size, (list, tuple))
129
+ else self.config.sample_size
130
+ )
131
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
132
+ self.tile_overlap_factor = 0.25
133
+
134
+ def _set_gradient_checkpointing(self, module, value=False):
135
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
136
+ module.gradient_checkpointing = value
137
+
138
+ def enable_temporal_tiling(self, use_tiling: bool = True):
139
+ self.use_temporal_tiling = use_tiling
140
+
141
+ def disable_temporal_tiling(self):
142
+ self.enable_temporal_tiling(False)
143
+
144
+ def enable_spatial_tiling(self, use_tiling: bool = True):
145
+ self.use_spatial_tiling = use_tiling
146
+
147
+ def disable_spatial_tiling(self):
148
+ self.enable_spatial_tiling(False)
149
+
150
+ def enable_tiling(self, use_tiling: bool = True):
151
+ r"""
152
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
153
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
154
+ processing larger videos.
155
+ """
156
+ self.enable_spatial_tiling(use_tiling)
157
+ self.enable_temporal_tiling(use_tiling)
158
+
159
+ def disable_tiling(self):
160
+ r"""
161
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
162
+ decoding in one step.
163
+ """
164
+ self.disable_spatial_tiling()
165
+ self.disable_temporal_tiling()
166
+
167
+ def enable_slicing(self):
168
+ r"""
169
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
170
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
171
+ """
172
+ self.use_slicing = True
173
+
174
+ def disable_slicing(self):
175
+ r"""
176
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
177
+ decoding in one step.
178
+ """
179
+ self.use_slicing = False
180
+
181
+ @property
182
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
183
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
184
+ r"""
185
+ Returns:
186
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
187
+ indexed by its weight name.
188
+ """
189
+ # set recursively
190
+ processors = {}
191
+
192
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
193
+ if hasattr(module, "get_processor"):
194
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
195
+
196
+ for sub_name, child in module.named_children():
197
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
198
+
199
+ return processors
200
+
201
+ for name, module in self.named_children():
202
+ fn_recursive_add_processors(name, module, processors)
203
+
204
+ return processors
205
+
206
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
207
+ def set_attn_processor(
208
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
209
+ ):
210
+ r"""
211
+ Sets the attention processor to use to compute attention.
212
+
213
+ Parameters:
214
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
215
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
216
+ for **all** `Attention` layers.
217
+
218
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
219
+ processor. This is strongly recommended when setting trainable attention processors.
220
+
221
+ """
222
+ count = len(self.attn_processors.keys())
223
+
224
+ if isinstance(processor, dict) and len(processor) != count:
225
+ raise ValueError(
226
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
227
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
228
+ )
229
+
230
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
231
+ if hasattr(module, "set_processor"):
232
+ if not isinstance(processor, dict):
233
+ module.set_processor(processor, _remove_lora=_remove_lora)
234
+ else:
235
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
236
+
237
+ for sub_name, child in module.named_children():
238
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
239
+
240
+ for name, module in self.named_children():
241
+ fn_recursive_attn_processor(name, module, processor)
242
+
243
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
244
+ def set_default_attn_processor(self):
245
+ """
246
+ Disables custom attention processors and sets the default attention implementation.
247
+ """
248
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
249
+ processor = AttnAddedKVProcessor()
250
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnProcessor()
252
+ else:
253
+ raise ValueError(
254
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
255
+ )
256
+
257
+ self.set_attn_processor(processor, _remove_lora=True)
258
+
259
+ @apply_forward_hook
260
+ def encode(
261
+ self, x: torch.FloatTensor, return_dict: bool = True
262
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
263
+ """
264
+ Encode a batch of images/videos into latents.
265
+
266
+ Args:
267
+ x (`torch.FloatTensor`): Input batch of images/videos.
268
+ return_dict (`bool`, *optional*, defaults to `True`):
269
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
270
+
271
+ Returns:
272
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
273
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
274
+ """
275
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
276
+
277
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
278
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
279
+
280
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
281
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
282
+
283
+ if self.use_slicing and x.shape[0] > 1:
284
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
285
+ h = torch.cat(encoded_slices)
286
+ else:
287
+ h = self.encoder(x)
288
+
289
+ moments = self.quant_conv(h)
290
+ posterior = DiagonalGaussianDistribution(moments)
291
+
292
+ if not return_dict:
293
+ return (posterior,)
294
+
295
+ return AutoencoderKLOutput(latent_dist=posterior)
296
+
297
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
298
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
299
+
300
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
301
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
302
+
303
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
304
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
305
+
306
+ z = self.post_quant_conv(z)
307
+ dec = self.decoder(z)
308
+
309
+ if not return_dict:
310
+ return (dec,)
311
+
312
+ return DecoderOutput(sample=dec)
313
+
314
+ @apply_forward_hook
315
+ def decode(
316
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
317
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
+ def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
362
+ r"""Encode a batch of images/videos using a tiled encoder.
363
+
364
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
365
+ steps. This is useful to keep memory use constant regardless of image/video size. The end result of tiled encoding is
366
+ different from non-tiled encoding because the encoder processes each tile independently. To avoid tiling artifacts, the
367
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
368
+ output, but they should be much less noticeable.
369
+
370
+ Args:
371
+ x (`torch.FloatTensor`): Input batch of images/videos.
372
+ return_dict (`bool`, *optional*, defaults to `True`):
373
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
374
+
375
+ Returns:
376
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
377
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
378
+ `tuple` is returned.
379
+ """
380
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
381
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
382
+ row_limit = self.tile_latent_min_size - blend_extent
383
+
384
+ # Split video into tiles and encode them separately.
385
+ rows = []
386
+ for i in range(0, x.shape[-2], overlap_size):
387
+ row = []
388
+ for j in range(0, x.shape[-1], overlap_size):
389
+ tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
390
+ tile = self.encoder(tile)
391
+ tile = self.quant_conv(tile)
392
+ row.append(tile)
393
+ rows.append(row)
394
+ result_rows = []
395
+ for i, row in enumerate(rows):
396
+ result_row = []
397
+ for j, tile in enumerate(row):
398
+ # blend the above tile and the left tile
399
+ # to the current tile and add the current tile to the result row
400
+ if i > 0:
401
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
402
+ if j > 0:
403
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
404
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
405
+ result_rows.append(torch.cat(result_row, dim=-1))
406
+
407
+ moments = torch.cat(result_rows, dim=-2)
408
+ if return_moments:
409
+ return moments
410
+
411
+ posterior = DiagonalGaussianDistribution(moments)
412
+ if not return_dict:
413
+ return (posterior,)
414
+
415
+ return AutoencoderKLOutput(latent_dist=posterior)
416
+
417
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
418
+ r"""
419
+ Decode a batch of images/videos using a tiled decoder.
420
+
421
+ Args:
422
+ z (`torch.FloatTensor`): Input batch of latent vectors.
423
+ return_dict (`bool`, *optional*, defaults to `True`):
424
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
425
+
426
+ Returns:
427
+ [`~models.vae.DecoderOutput`] or `tuple`:
428
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
429
+ returned.
430
+ """
431
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
432
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
433
+ row_limit = self.tile_sample_min_size - blend_extent
434
+
435
+ # Split z into overlapping tiles and decode them separately.
436
+ # The tiles have an overlap to avoid seams between tiles.
437
+ rows = []
438
+ for i in range(0, z.shape[-2], overlap_size):
439
+ row = []
440
+ for j in range(0, z.shape[-1], overlap_size):
441
+ tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
442
+ tile = self.post_quant_conv(tile)
443
+ decoded = self.decoder(tile)
444
+ row.append(decoded)
445
+ rows.append(row)
446
+ result_rows = []
447
+ for i, row in enumerate(rows):
448
+ result_row = []
449
+ for j, tile in enumerate(row):
450
+ # blend the above tile and the left tile
451
+ # to the current tile and add the current tile to the result row
452
+ if i > 0:
453
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
454
+ if j > 0:
455
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
456
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
457
+ result_rows.append(torch.cat(result_row, dim=-1))
458
+
459
+ dec = torch.cat(result_rows, dim=-2)
460
+ if not return_dict:
461
+ return (dec,)
462
+
463
+ return DecoderOutput(sample=dec)
464
+
465
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
466
+
467
+ B, C, T, H, W = x.shape
468
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
469
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
470
+ t_limit = self.tile_latent_min_tsize - blend_extent
471
+
472
+ # Split the video into tiles and encode them separately.
473
+ row = []
474
+ for i in range(0, T, overlap_size):
475
+ tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
476
+ if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
477
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
478
+ else:
479
+ tile = self.encoder(tile)
480
+ tile = self.quant_conv(tile)
481
+ if i > 0:
482
+ tile = tile[:, :, 1:, :, :]
483
+ row.append(tile)
484
+ result_row = []
485
+ for i, tile in enumerate(row):
486
+ if i > 0:
487
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
488
+ result_row.append(tile[:, :, :t_limit, :, :])
489
+ else:
490
+ result_row.append(tile[:, :, :t_limit + 1, :, :])
491
+
492
+ moments = torch.cat(result_row, dim=2)
493
+ posterior = DiagonalGaussianDistribution(moments)
494
+
495
+ if not return_dict:
496
+ return (posterior,)
497
+
498
+ return AutoencoderKLOutput(latent_dist=posterior)
499
+
500
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
501
+ # Split z into overlapping tiles and decode them separately.
502
+
503
+ B, C, T, H, W = z.shape
504
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
505
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
506
+ t_limit = self.tile_sample_min_tsize - blend_extent
507
+
508
+ row = []
509
+ for i in range(0, T, overlap_size):
510
+ tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
511
+ if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
512
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
513
+ else:
514
+ tile = self.post_quant_conv(tile)
515
+ decoded = self.decoder(tile)
516
+ if i > 0:
517
+ decoded = decoded[:, :, 1:, :, :]
518
+ row.append(decoded)
519
+ result_row = []
520
+ for i, tile in enumerate(row):
521
+ if i > 0:
522
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
523
+ result_row.append(tile[:, :, :t_limit, :, :])
524
+ else:
525
+ result_row.append(tile[:, :, :t_limit + 1, :, :])
526
+
527
+ dec = torch.cat(result_row, dim=2)
528
+ if not return_dict:
529
+ return (dec,)
530
+
531
+ return DecoderOutput(sample=dec)
532
+
533
+ def forward(
534
+ self,
535
+ sample: torch.FloatTensor,
536
+ sample_posterior: bool = False,
537
+ return_dict: bool = True,
538
+ return_posterior: bool = False,
539
+ generator: Optional[torch.Generator] = None,
540
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
541
+ r"""
542
+ Args:
543
+ sample (`torch.FloatTensor`): Input sample.
544
+ sample_posterior (`bool`, *optional*, defaults to `False`):
545
+ Whether to sample from the posterior.
546
+ return_dict (`bool`, *optional*, defaults to `True`):
547
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
548
+ """
549
+ x = sample
550
+ posterior = self.encode(x).latent_dist
551
+ if sample_posterior:
552
+ z = posterior.sample(generator=generator)
553
+ else:
554
+ z = posterior.mode()
555
+ dec = self.decode(z).sample
556
+
557
+ if not return_dict:
558
+ if return_posterior:
559
+ return (dec, posterior)
560
+ else:
561
+ return (dec,)
562
+ if return_posterior:
563
+ return DecoderOutput2(sample=dec, posterior=posterior)
564
+ else:
565
+ return DecoderOutput2(sample=dec)
566
+
567
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
568
+ def fuse_qkv_projections(self):
569
+ """
570
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
571
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
572
+
573
+ <Tip warning={true}>
574
+
575
+ This API is 🧪 experimental.
576
+
577
+ </Tip>
578
+ """
579
+ self.original_attn_processors = None
580
+
581
+ for _, attn_processor in self.attn_processors.items():
582
+ if "Added" in str(attn_processor.__class__.__name__):
583
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
584
+
585
+ self.original_attn_processors = self.attn_processors
586
+
587
+ for module in self.modules():
588
+ if isinstance(module, Attention):
589
+ module.fuse_projections(fuse=True)
590
+
591
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
592
+ def unfuse_qkv_projections(self):
593
+ """Disables the fused QKV projection if enabled.
594
+
595
+ <Tip warning={true}>
596
+
597
+ This API is 🧪 experimental.
598
+
599
+ </Tip>
600
+
601
+ """
602
+ if self.original_attn_processors is not None:
603
+ self.set_attn_processor(self.original_attn_processors)
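A sketch of tiled encode/decode with this autoencoder (dimensions and dtype are illustrative; `enable_tiling()` turns on both the spatial and temporal tiling paths defined above):

    import torch
    from hyvideo.vae import load_vae

    vae, _, _, _ = load_vae(vae_type="884-16c-hy", vae_precision="fp16", device="cuda")
    vae.enable_tiling()  # bounds peak memory for long / high-resolution videos

    # dummy video batch: (B, C, T, H, W)
    video = torch.randn(1, 3, 33, 256, 256, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        posterior = vae.encode(video).latent_dist
        latents = posterior.mode()
        recon = vae.decode(latents).sample  # back to (1, 3, 33, 256, 256)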
hyvideo/vae/unet_causal_3d_blocks.py ADDED
@@ -0,0 +1,764 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+ from einops import rearrange
26
+
27
+ from diffusers.utils import logging
28
+ from diffusers.models.activations import get_activation
29
+ from diffusers.models.attention_processor import SpatialNorm
30
+ from diffusers.models.attention_processor import Attention
31
+ from diffusers.models.normalization import AdaGroupNorm
32
+ from diffusers.models.normalization import RMSNorm
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
38
+ seq_len = n_frame * n_hw
39
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
40
+ for i in range(seq_len):
41
+ i_frame = i // n_hw
42
+ mask[i, : (i_frame + 1) * n_hw] = 0
43
+ if batch_size is not None:
44
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
45
+ return mask
46
+
47
+
48
+ class CausalConv3d(nn.Module):
49
+ """
50
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
51
+ This maintains temporal causality in video generation tasks.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ chan_in,
57
+ chan_out,
58
+ kernel_size: Union[int, Tuple[int, int, int]],
59
+ stride: Union[int, Tuple[int, int, int]] = 1,
60
+ dilation: Union[int, Tuple[int, int, int]] = 1,
61
+ pad_mode='replicate',
62
+ **kwargs
63
+ ):
64
+ super().__init__()
65
+
66
+ self.pad_mode = pad_mode
67
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
68
+ self.time_causal_padding = padding
69
+
70
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
71
+
72
+ def forward(self, x):
73
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
74
+ return self.conv(x)
75
+
76
+
77
+ class UpsampleCausal3D(nn.Module):
78
+ """
79
+ A 3D upsampling layer with an optional convolution.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ channels: int,
85
+ use_conv: bool = False,
86
+ use_conv_transpose: bool = False,
87
+ out_channels: Optional[int] = None,
88
+ name: str = "conv",
89
+ kernel_size: Optional[int] = None,
90
+ padding=1,
91
+ norm_type=None,
92
+ eps=None,
93
+ elementwise_affine=None,
94
+ bias=True,
95
+ interpolate=True,
96
+ upsample_factor=(2, 2, 2),
97
+ ):
98
+ super().__init__()
99
+ self.channels = channels
100
+ self.out_channels = out_channels or channels
101
+ self.use_conv = use_conv
102
+ self.use_conv_transpose = use_conv_transpose
103
+ self.name = name
104
+ self.interpolate = interpolate
105
+ self.upsample_factor = upsample_factor
106
+
107
+ if norm_type == "ln_norm":
108
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
109
+ elif norm_type == "rms_norm":
110
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
111
+ elif norm_type is None:
112
+ self.norm = None
113
+ else:
114
+ raise ValueError(f"unknown norm_type: {norm_type}")
115
+
116
+ conv = None
117
+ if use_conv_transpose:
118
+ raise NotImplementedError
119
+ elif use_conv:
120
+ if kernel_size is None:
121
+ kernel_size = 3
122
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
123
+
124
+ if name == "conv":
125
+ self.conv = conv
126
+ else:
127
+ self.Conv2d_0 = conv
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.FloatTensor,
132
+ output_size: Optional[int] = None,
133
+ scale: float = 1.0,
134
+ ) -> torch.FloatTensor:
135
+ assert hidden_states.shape[1] == self.channels
136
+
137
+ if self.norm is not None:
138
+ raise NotImplementedError
139
+
140
+ if self.use_conv_transpose:
141
+ return self.conv(hidden_states)
142
+
143
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
144
+ dtype = hidden_states.dtype
145
+ if dtype == torch.bfloat16:
146
+ hidden_states = hidden_states.to(torch.float32)
147
+
148
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
149
+ if hidden_states.shape[0] >= 64:
150
+ hidden_states = hidden_states.contiguous()
151
+
152
+ # if `output_size` is passed we force the interpolation output
153
+ # size and do not make use of `scale_factor=2`
154
+ if self.interpolate:
155
+ B, C, T, H, W = hidden_states.shape
156
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
157
+ if output_size is None:
158
+ if T > 1:
159
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
160
+
161
+ first_h = first_h.squeeze(2)
162
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
163
+ first_h = first_h.unsqueeze(2)
164
+ else:
165
+ raise NotImplementedError
166
+
167
+ if T > 1:
168
+ hidden_states = torch.cat((first_h, other_h), dim=2)
169
+ else:
170
+ hidden_states = first_h
171
+
172
+ # If the input is bfloat16, we cast back to bfloat16
173
+ if dtype == torch.bfloat16:
174
+ hidden_states = hidden_states.to(dtype)
175
+
176
+ if self.use_conv:
177
+ if self.name == "conv":
178
+ hidden_states = self.conv(hidden_states)
179
+ else:
180
+ hidden_states = self.Conv2d_0(hidden_states)
181
+
182
+ return hidden_states
183
+
184
+
185
+ class DownsampleCausal3D(nn.Module):
186
+ """
187
+ A 3D downsampling layer with an optional convolution.
188
+ """
189
+
190
+ def __init__(
191
+ self,
192
+ channels: int,
193
+ use_conv: bool = False,
194
+ out_channels: Optional[int] = None,
195
+ padding: int = 1,
196
+ name: str = "conv",
197
+ kernel_size=3,
198
+ norm_type=None,
199
+ eps=None,
200
+ elementwise_affine=None,
201
+ bias=True,
202
+ stride=2,
203
+ ):
204
+ super().__init__()
205
+ self.channels = channels
206
+ self.out_channels = out_channels or channels
207
+ self.use_conv = use_conv
208
+ self.padding = padding
209
+ stride = stride
210
+ self.name = name
211
+
212
+ if norm_type == "ln_norm":
213
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
214
+ elif norm_type == "rms_norm":
215
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
216
+ elif norm_type is None:
217
+ self.norm = None
218
+ else:
219
+ raise ValueError(f"unknown norm_type: {norm_type}")
220
+
221
+ if use_conv:
222
+ conv = CausalConv3d(
223
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
224
+ )
225
+ else:
226
+ raise NotImplementedError
227
+
228
+ if name == "conv":
229
+ self.Conv2d_0 = conv
230
+ self.conv = conv
231
+ elif name == "Conv2d_0":
232
+ self.conv = conv
233
+ else:
234
+ self.conv = conv
235
+
236
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
237
+ assert hidden_states.shape[1] == self.channels
238
+
239
+ if self.norm is not None:
240
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
241
+
242
+ assert hidden_states.shape[1] == self.channels
243
+
244
+ hidden_states = self.conv(hidden_states)
245
+
246
+ return hidden_states
247
+
248
+
249
+ class ResnetBlockCausal3D(nn.Module):
250
+ r"""
251
+ A Resnet block.
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ *,
257
+ in_channels: int,
258
+ out_channels: Optional[int] = None,
259
+ conv_shortcut: bool = False,
260
+ dropout: float = 0.0,
261
+ temb_channels: int = 512,
262
+ groups: int = 32,
263
+ groups_out: Optional[int] = None,
264
+ pre_norm: bool = True,
265
+ eps: float = 1e-6,
266
+ non_linearity: str = "swish",
267
+ skip_time_act: bool = False,
268
+ # default, scale_shift, ada_group, spatial
269
+ time_embedding_norm: str = "default",
270
+ kernel: Optional[torch.FloatTensor] = None,
271
+ output_scale_factor: float = 1.0,
272
+ use_in_shortcut: Optional[bool] = None,
273
+ up: bool = False,
274
+ down: bool = False,
275
+ conv_shortcut_bias: bool = True,
276
+ conv_3d_out_channels: Optional[int] = None,
277
+ ):
278
+ super().__init__()
279
+ self.pre_norm = pre_norm
280
+ self.pre_norm = True
281
+ self.in_channels = in_channels
282
+ out_channels = in_channels if out_channels is None else out_channels
283
+ self.out_channels = out_channels
284
+ self.use_conv_shortcut = conv_shortcut
285
+ self.up = up
286
+ self.down = down
287
+ self.output_scale_factor = output_scale_factor
288
+ self.time_embedding_norm = time_embedding_norm
289
+ self.skip_time_act = skip_time_act
290
+
291
+ linear_cls = nn.Linear
292
+
293
+ if groups_out is None:
294
+ groups_out = groups
295
+
296
+ if self.time_embedding_norm == "ada_group":
297
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
298
+ elif self.time_embedding_norm == "spatial":
299
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
300
+ else:
301
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
302
+
303
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
304
+
305
+ if temb_channels is not None:
306
+ if self.time_embedding_norm == "default":
307
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
308
+ elif self.time_embedding_norm == "scale_shift":
309
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
310
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
311
+ self.time_emb_proj = None
312
+ else:
313
+ raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
314
+ else:
315
+ self.time_emb_proj = None
316
+
317
+ if self.time_embedding_norm == "ada_group":
318
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
319
+ elif self.time_embedding_norm == "spatial":
320
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
321
+ else:
322
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
323
+
324
+ self.dropout = torch.nn.Dropout(dropout)
325
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
326
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
327
+
328
+ self.nonlinearity = get_activation(non_linearity)
329
+
330
+ self.upsample = self.downsample = None
331
+ if self.up:
332
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
333
+ elif self.down:
334
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
335
+
336
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
337
+
338
+ self.conv_shortcut = None
339
+ if self.use_in_shortcut:
340
+ self.conv_shortcut = CausalConv3d(
341
+ in_channels,
342
+ conv_3d_out_channels,
343
+ kernel_size=1,
344
+ stride=1,
345
+ bias=conv_shortcut_bias,
346
+ )
347
+
348
+ def forward(
349
+ self,
350
+ input_tensor: torch.FloatTensor,
351
+ temb: torch.FloatTensor,
352
+ scale: float = 1.0,
353
+ ) -> torch.FloatTensor:
354
+ hidden_states = input_tensor
355
+
356
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
357
+ hidden_states = self.norm1(hidden_states, temb)
358
+ else:
359
+ hidden_states = self.norm1(hidden_states)
360
+
361
+ hidden_states = self.nonlinearity(hidden_states)
362
+
363
+ if self.upsample is not None:
364
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
365
+ if hidden_states.shape[0] >= 64:
366
+ input_tensor = input_tensor.contiguous()
367
+ hidden_states = hidden_states.contiguous()
368
+ input_tensor = (
369
+ self.upsample(input_tensor, scale=scale)
370
+ )
371
+ hidden_states = (
372
+ self.upsample(hidden_states, scale=scale)
373
+ )
374
+ elif self.downsample is not None:
375
+ input_tensor = (
376
+ self.downsample(input_tensor, scale=scale)
377
+ )
378
+ hidden_states = (
379
+ self.downsample(hidden_states, scale=scale)
380
+ )
381
+
382
+ hidden_states = self.conv1(hidden_states)
383
+
384
+ if self.time_emb_proj is not None:
385
+ if not self.skip_time_act:
386
+ temb = self.nonlinearity(temb)
387
+ temb = (
388
+ self.time_emb_proj(temb, scale)[:, :, None, None]
389
+ )
390
+
391
+ if temb is not None and self.time_embedding_norm == "default":
392
+ hidden_states = hidden_states + temb
393
+
394
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
395
+ hidden_states = self.norm2(hidden_states, temb)
396
+ else:
397
+ hidden_states = self.norm2(hidden_states)
398
+
399
+ if temb is not None and self.time_embedding_norm == "scale_shift":
400
+ scale, shift = torch.chunk(temb, 2, dim=1)
401
+ hidden_states = hidden_states * (1 + scale) + shift
402
+
403
+ hidden_states = self.nonlinearity(hidden_states)
404
+
405
+ hidden_states = self.dropout(hidden_states)
406
+ hidden_states = self.conv2(hidden_states)
407
+
408
+ if self.conv_shortcut is not None:
409
+ input_tensor = (
410
+ self.conv_shortcut(input_tensor)
411
+ )
412
+
413
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
414
+
415
+ return output_tensor
416
+
417
+
418
+ def get_down_block3d(
419
+ down_block_type: str,
420
+ num_layers: int,
421
+ in_channels: int,
422
+ out_channels: int,
423
+ temb_channels: int,
424
+ add_downsample: bool,
425
+ downsample_stride: int,
426
+ resnet_eps: float,
427
+ resnet_act_fn: str,
428
+ transformer_layers_per_block: int = 1,
429
+ num_attention_heads: Optional[int] = None,
430
+ resnet_groups: Optional[int] = None,
431
+ cross_attention_dim: Optional[int] = None,
432
+ downsample_padding: Optional[int] = None,
433
+ dual_cross_attention: bool = False,
434
+ use_linear_projection: bool = False,
435
+ only_cross_attention: bool = False,
436
+ upcast_attention: bool = False,
437
+ resnet_time_scale_shift: str = "default",
438
+ attention_type: str = "default",
439
+ resnet_skip_time_act: bool = False,
440
+ resnet_out_scale_factor: float = 1.0,
441
+ cross_attention_norm: Optional[str] = None,
442
+ attention_head_dim: Optional[int] = None,
443
+ downsample_type: Optional[str] = None,
444
+ dropout: float = 0.0,
445
+ ):
446
+ # If attn head dim is not defined, we default it to the number of heads
447
+ if attention_head_dim is None:
448
+ logger.warning(
449
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block3d`. Defaulting `attention_head_dim` to {num_attention_heads}."
450
+ )
451
+ attention_head_dim = num_attention_heads
452
+
453
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
454
+ if down_block_type == "DownEncoderBlockCausal3D":
455
+ return DownEncoderBlockCausal3D(
456
+ num_layers=num_layers,
457
+ in_channels=in_channels,
458
+ out_channels=out_channels,
459
+ dropout=dropout,
460
+ add_downsample=add_downsample,
461
+ downsample_stride=downsample_stride,
462
+ resnet_eps=resnet_eps,
463
+ resnet_act_fn=resnet_act_fn,
464
+ resnet_groups=resnet_groups,
465
+ downsample_padding=downsample_padding,
466
+ resnet_time_scale_shift=resnet_time_scale_shift,
467
+ )
468
+ raise ValueError(f"{down_block_type} does not exist.")
469
+
470
+
471
+ def get_up_block3d(
472
+ up_block_type: str,
473
+ num_layers: int,
474
+ in_channels: int,
475
+ out_channels: int,
476
+ prev_output_channel: int,
477
+ temb_channels: int,
478
+ add_upsample: bool,
479
+ upsample_scale_factor: Tuple,
480
+ resnet_eps: float,
481
+ resnet_act_fn: str,
482
+ resolution_idx: Optional[int] = None,
483
+ transformer_layers_per_block: int = 1,
484
+ num_attention_heads: Optional[int] = None,
485
+ resnet_groups: Optional[int] = None,
486
+ cross_attention_dim: Optional[int] = None,
487
+ dual_cross_attention: bool = False,
488
+ use_linear_projection: bool = False,
489
+ only_cross_attention: bool = False,
490
+ upcast_attention: bool = False,
491
+ resnet_time_scale_shift: str = "default",
492
+ attention_type: str = "default",
493
+ resnet_skip_time_act: bool = False,
494
+ resnet_out_scale_factor: float = 1.0,
495
+ cross_attention_norm: Optional[str] = None,
496
+ attention_head_dim: Optional[int] = None,
497
+ upsample_type: Optional[str] = None,
498
+ dropout: float = 0.0,
499
+ ) -> nn.Module:
500
+ # If attn head dim is not defined, we default it to the number of heads
501
+ if attention_head_dim is None:
502
+ logger.warning(
503
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block3d`. Defaulting `attention_head_dim` to {num_attention_heads}."
504
+ )
505
+ attention_head_dim = num_attention_heads
506
+
507
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
508
+ if up_block_type == "UpDecoderBlockCausal3D":
509
+ return UpDecoderBlockCausal3D(
510
+ num_layers=num_layers,
511
+ in_channels=in_channels,
512
+ out_channels=out_channels,
513
+ resolution_idx=resolution_idx,
514
+ dropout=dropout,
515
+ add_upsample=add_upsample,
516
+ upsample_scale_factor=upsample_scale_factor,
517
+ resnet_eps=resnet_eps,
518
+ resnet_act_fn=resnet_act_fn,
519
+ resnet_groups=resnet_groups,
520
+ resnet_time_scale_shift=resnet_time_scale_shift,
521
+ temb_channels=temb_channels,
522
+ )
523
+ raise ValueError(f"{up_block_type} does not exist.")
524
+
525
+
526
+ class UNetMidBlockCausal3D(nn.Module):
527
+ """
528
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
529
+ """
530
+
531
+ def __init__(
532
+ self,
533
+ in_channels: int,
534
+ temb_channels: int,
535
+ dropout: float = 0.0,
536
+ num_layers: int = 1,
537
+ resnet_eps: float = 1e-6,
538
+ resnet_time_scale_shift: str = "default", # default, spatial
539
+ resnet_act_fn: str = "swish",
540
+ resnet_groups: int = 32,
541
+ attn_groups: Optional[int] = None,
542
+ resnet_pre_norm: bool = True,
543
+ add_attention: bool = True,
544
+ attention_head_dim: int = 1,
545
+ output_scale_factor: float = 1.0,
546
+ ):
547
+ super().__init__()
548
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
549
+ self.add_attention = add_attention
550
+
551
+ if attn_groups is None:
552
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
553
+
554
+ # there is always at least one resnet
555
+ resnets = [
556
+ ResnetBlockCausal3D(
557
+ in_channels=in_channels,
558
+ out_channels=in_channels,
559
+ temb_channels=temb_channels,
560
+ eps=resnet_eps,
561
+ groups=resnet_groups,
562
+ dropout=dropout,
563
+ time_embedding_norm=resnet_time_scale_shift,
564
+ non_linearity=resnet_act_fn,
565
+ output_scale_factor=output_scale_factor,
566
+ pre_norm=resnet_pre_norm,
567
+ )
568
+ ]
569
+ attentions = []
570
+
571
+ if attention_head_dim is None:
572
+ logger.warning(
573
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
574
+ )
575
+ attention_head_dim = in_channels
576
+
577
+ for _ in range(num_layers):
578
+ if self.add_attention:
579
+ attentions.append(
580
+ Attention(
581
+ in_channels,
582
+ heads=in_channels // attention_head_dim,
583
+ dim_head=attention_head_dim,
584
+ rescale_output_factor=output_scale_factor,
585
+ eps=resnet_eps,
586
+ norm_num_groups=attn_groups,
587
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
588
+ residual_connection=True,
589
+ bias=True,
590
+ upcast_softmax=True,
591
+ _from_deprecated_attn_block=True,
592
+ )
593
+ )
594
+ else:
595
+ attentions.append(None)
596
+
597
+ resnets.append(
598
+ ResnetBlockCausal3D(
599
+ in_channels=in_channels,
600
+ out_channels=in_channels,
601
+ temb_channels=temb_channels,
602
+ eps=resnet_eps,
603
+ groups=resnet_groups,
604
+ dropout=dropout,
605
+ time_embedding_norm=resnet_time_scale_shift,
606
+ non_linearity=resnet_act_fn,
607
+ output_scale_factor=output_scale_factor,
608
+ pre_norm=resnet_pre_norm,
609
+ )
610
+ )
611
+
612
+ self.attentions = nn.ModuleList(attentions)
613
+ self.resnets = nn.ModuleList(resnets)
614
+
615
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
616
+ hidden_states = self.resnets[0](hidden_states, temb)
617
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
618
+ if attn is not None:
619
+ B, C, T, H, W = hidden_states.shape
620
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
621
+ attention_mask = prepare_causal_attention_mask(
622
+ T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
623
+ )
624
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
625
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
626
+ hidden_states = resnet(hidden_states, temb)
627
+
628
+ return hidden_states
629
+
630
+
631
+ class DownEncoderBlockCausal3D(nn.Module):
632
+ def __init__(
633
+ self,
634
+ in_channels: int,
635
+ out_channels: int,
636
+ dropout: float = 0.0,
637
+ num_layers: int = 1,
638
+ resnet_eps: float = 1e-6,
639
+ resnet_time_scale_shift: str = "default",
640
+ resnet_act_fn: str = "swish",
641
+ resnet_groups: int = 32,
642
+ resnet_pre_norm: bool = True,
643
+ output_scale_factor: float = 1.0,
644
+ add_downsample: bool = True,
645
+ downsample_stride: int = 2,
646
+ downsample_padding: int = 1,
647
+ ):
648
+ super().__init__()
649
+ resnets = []
650
+
651
+ for i in range(num_layers):
652
+ in_channels = in_channels if i == 0 else out_channels
653
+ resnets.append(
654
+ ResnetBlockCausal3D(
655
+ in_channels=in_channels,
656
+ out_channels=out_channels,
657
+ temb_channels=None,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+
668
+ self.resnets = nn.ModuleList(resnets)
669
+
670
+ if add_downsample:
671
+ self.downsamplers = nn.ModuleList(
672
+ [
673
+ DownsampleCausal3D(
674
+ out_channels,
675
+ use_conv=True,
676
+ out_channels=out_channels,
677
+ padding=downsample_padding,
678
+ name="op",
679
+ stride=downsample_stride,
680
+ )
681
+ ]
682
+ )
683
+ else:
684
+ self.downsamplers = None
685
+
686
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
687
+ for resnet in self.resnets:
688
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
689
+
690
+ if self.downsamplers is not None:
691
+ for downsampler in self.downsamplers:
692
+ hidden_states = downsampler(hidden_states, scale)
693
+
694
+ return hidden_states
695
+
696
+
697
+ class UpDecoderBlockCausal3D(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ resolution_idx: Optional[int] = None,
703
+ dropout: float = 0.0,
704
+ num_layers: int = 1,
705
+ resnet_eps: float = 1e-6,
706
+ resnet_time_scale_shift: str = "default", # default, spatial
707
+ resnet_act_fn: str = "swish",
708
+ resnet_groups: int = 32,
709
+ resnet_pre_norm: bool = True,
710
+ output_scale_factor: float = 1.0,
711
+ add_upsample: bool = True,
712
+ upsample_scale_factor=(2, 2, 2),
713
+ temb_channels: Optional[int] = None,
714
+ ):
715
+ super().__init__()
716
+ resnets = []
717
+
718
+ for i in range(num_layers):
719
+ input_channels = in_channels if i == 0 else out_channels
720
+
721
+ resnets.append(
722
+ ResnetBlockCausal3D(
723
+ in_channels=input_channels,
724
+ out_channels=out_channels,
725
+ temb_channels=temb_channels,
726
+ eps=resnet_eps,
727
+ groups=resnet_groups,
728
+ dropout=dropout,
729
+ time_embedding_norm=resnet_time_scale_shift,
730
+ non_linearity=resnet_act_fn,
731
+ output_scale_factor=output_scale_factor,
732
+ pre_norm=resnet_pre_norm,
733
+ )
734
+ )
735
+
736
+ self.resnets = nn.ModuleList(resnets)
737
+
738
+ if add_upsample:
739
+ self.upsamplers = nn.ModuleList(
740
+ [
741
+ UpsampleCausal3D(
742
+ out_channels,
743
+ use_conv=True,
744
+ out_channels=out_channels,
745
+ upsample_factor=upsample_scale_factor,
746
+ )
747
+ ]
748
+ )
749
+ else:
750
+ self.upsamplers = None
751
+
752
+ self.resolution_idx = resolution_idx
753
+
754
+ def forward(
755
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
756
+ ) -> torch.FloatTensor:
757
+ for resnet in self.resnets:
758
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
759
+
760
+ if self.upsamplers is not None:
761
+ for upsampler in self.upsamplers:
762
+ hidden_states = upsampler(hidden_states)
763
+
764
+ return hidden_states
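A short sketch showing the causality guarantees of these blocks (shapes are illustrative, not from this commit):

    import torch
    from hyvideo.vae.unet_causal_3d_blocks import (
        CausalConv3d,
        prepare_causal_attention_mask,
    )

    # block-causal mask: 3 frames of a 2x2 feature map -> a 12x12 mask in which
    # each position may attend only up to (and including) its own frame
    mask = prepare_causal_attention_mask(n_frame=3, n_hw=4, dtype=torch.float32, device="cpu")
    print(mask.shape)  # torch.Size([12, 12])

    # CausalConv3d pads (kernel_size - 1) frames on the left of the time axis,
    # so output frame t never depends on input frames after t
    conv = CausalConv3d(chan_in=3, chan_out=8, kernel_size=3)
    out = conv(torch.randn(1, 3, 5, 16, 16))
    print(out.shape)  # torch.Size([1, 8, 5, 16, 16])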
hyvideo/vae/vae.py ADDED
@@ -0,0 +1,355 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from diffusers.utils import BaseOutput, is_torch_version
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.models.attention_processor import SpatialNorm
11
+ from .unet_causal_3d_blocks import (
12
+ CausalConv3d,
13
+ UNetMidBlockCausal3D,
14
+ get_down_block3d,
15
+ get_up_block3d,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class DecoderOutput(BaseOutput):
21
+ r"""
22
+ Output of decoding method.
23
+
24
+ Args:
25
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
26
+ The decoded output sample from the last layer of the model.
27
+ """
28
+
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ class EncoderCausal3D(nn.Module):
33
+ r"""
34
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ in_channels: int = 3,
40
+ out_channels: int = 3,
41
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
42
+ block_out_channels: Tuple[int, ...] = (64,),
43
+ layers_per_block: int = 2,
44
+ norm_num_groups: int = 32,
45
+ act_fn: str = "silu",
46
+ double_z: bool = True,
47
+ mid_block_add_attention=True,
48
+ time_compression_ratio: int = 4,
49
+ spatial_compression_ratio: int = 8,
50
+ ):
51
+ super().__init__()
52
+ self.layers_per_block = layers_per_block
53
+
54
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
55
+ self.mid_block = None
56
+ self.down_blocks = nn.ModuleList([])
57
+
58
+ # down
59
+ output_channel = block_out_channels[0]
60
+ for i, down_block_type in enumerate(down_block_types):
61
+ input_channel = output_channel
62
+ output_channel = block_out_channels[i]
63
+ is_final_block = i == len(block_out_channels) - 1
64
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
65
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
66
+
67
+ if time_compression_ratio == 4:
68
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
69
+ add_time_downsample = bool(
70
+ i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
71
+ and not is_final_block
72
+ )
73
+ else:
74
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
75
+
76
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
77
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
78
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
79
+ down_block = get_down_block3d(
80
+ down_block_type,
81
+ num_layers=self.layers_per_block,
82
+ in_channels=input_channel,
83
+ out_channels=output_channel,
84
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
85
+ downsample_stride=downsample_stride,
86
+ resnet_eps=1e-6,
87
+ downsample_padding=0,
88
+ resnet_act_fn=act_fn,
89
+ resnet_groups=norm_num_groups,
90
+ attention_head_dim=output_channel,
91
+ temb_channels=None,
92
+ )
93
+ self.down_blocks.append(down_block)
94
+
95
+ # mid
96
+ self.mid_block = UNetMidBlockCausal3D(
97
+ in_channels=block_out_channels[-1],
98
+ resnet_eps=1e-6,
99
+ resnet_act_fn=act_fn,
100
+ output_scale_factor=1,
101
+ resnet_time_scale_shift="default",
102
+ attention_head_dim=block_out_channels[-1],
103
+ resnet_groups=norm_num_groups,
104
+ temb_channels=None,
105
+ add_attention=mid_block_add_attention,
106
+ )
107
+
108
+ # out
109
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
110
+ self.conv_act = nn.SiLU()
111
+
112
+ conv_out_channels = 2 * out_channels if double_z else out_channels
113
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
114
+
115
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
116
+ r"""The forward method of the `EncoderCausal3D` class."""
117
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
118
+
119
+ sample = self.conv_in(sample)
120
+
121
+ # down
122
+ for down_block in self.down_blocks:
123
+ sample = down_block(sample)
124
+
125
+ # middle
126
+ sample = self.mid_block(sample)
127
+
128
+ # post-process
129
+ sample = self.conv_norm_out(sample)
130
+ sample = self.conv_act(sample)
131
+ sample = self.conv_out(sample)
132
+
133
+ return sample
134
+
135
+
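To make the encoder's downsampling schedule concrete, here is a small standalone sketch that reproduces the per-block decision for a four-level layout (the channel counts are an assumption for illustration; the released VAE config is not shown in this diff):

    import numpy as np

    block_out_channels = (128, 256, 512, 512)  # assumed 4-level layout
    time_compression_ratio, spatial_compression_ratio = 4, 8
    num_spatial = int(np.log2(spatial_compression_ratio))  # 3 spatial downsamples -> 8x
    num_time = int(np.log2(time_compression_ratio))        # 2 temporal downsamples -> 4x

    for i in range(len(block_out_channels)):
        is_final = i == len(block_out_channels) - 1
        add_spatial = i < num_spatial
        add_time = i >= (len(block_out_channels) - 1 - num_time) and not is_final
        stride = ((2,) if add_time else (1,)) + ((2, 2) if add_spatial else (1, 1))
        print(i, stride)
    # 0 (1, 2, 2)  spatial only
    # 1 (2, 2, 2)  spatial + temporal
    # 2 (2, 2, 2)  spatial + temporal
    # 3 (1, 1, 1)  final block: no downsampling
    # overall: 2*2*2 = 8x spatial, 2*2 = 4x temporal, matching the defaults above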
+ class DecoderCausal3D(nn.Module):
+     r"""
+     The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
+     """
+
+     def __init__(
+         self,
+         in_channels: int = 3,
+         out_channels: int = 3,
+         up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
+         block_out_channels: Tuple[int, ...] = (64,),
+         layers_per_block: int = 2,
+         norm_num_groups: int = 32,
+         act_fn: str = "silu",
+         norm_type: str = "group",  # group, spatial
+         mid_block_add_attention=True,
+         time_compression_ratio: int = 4,
+         spatial_compression_ratio: int = 8,
+     ):
+         super().__init__()
+         self.layers_per_block = layers_per_block
+
+         self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
+         self.mid_block = None
+         self.up_blocks = nn.ModuleList([])
+
+         temb_channels = in_channels if norm_type == "spatial" else None
+
+         # mid
+         self.mid_block = UNetMidBlockCausal3D(
+             in_channels=block_out_channels[-1],
+             resnet_eps=1e-6,
+             resnet_act_fn=act_fn,
+             output_scale_factor=1,
+             resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+             attention_head_dim=block_out_channels[-1],
+             resnet_groups=norm_num_groups,
+             temb_channels=temb_channels,
+             add_attention=mid_block_add_attention,
+         )
+
+         # up
+         reversed_block_out_channels = list(reversed(block_out_channels))
+         output_channel = reversed_block_out_channels[0]
+         for i, up_block_type in enumerate(up_block_types):
+             prev_output_channel = output_channel
+             output_channel = reversed_block_out_channels[i]
+             is_final_block = i == len(block_out_channels) - 1
+             num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
+             num_time_upsample_layers = int(np.log2(time_compression_ratio))
+
+             if time_compression_ratio == 4:
+                 add_spatial_upsample = bool(i < num_spatial_upsample_layers)
+                 add_time_upsample = bool(
+                     i >= len(block_out_channels) - 1 - num_time_upsample_layers
+                     and not is_final_block
+                 )
+             else:
+                 raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
+
+             upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
+             upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
+             upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
+             up_block = get_up_block3d(
+                 up_block_type,
+                 num_layers=self.layers_per_block + 1,
+                 in_channels=prev_output_channel,
+                 out_channels=output_channel,
+                 prev_output_channel=None,
+                 add_upsample=bool(add_spatial_upsample or add_time_upsample),
+                 upsample_scale_factor=upsample_scale_factor,
+                 resnet_eps=1e-6,
+                 resnet_act_fn=act_fn,
+                 resnet_groups=norm_num_groups,
+                 attention_head_dim=output_channel,
+                 temb_channels=temb_channels,
+                 resnet_time_scale_shift=norm_type,
+             )
+             self.up_blocks.append(up_block)
+             prev_output_channel = output_channel
+
+         # out
+         if norm_type == "spatial":
+             self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+         else:
+             self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+         self.conv_act = nn.SiLU()
+         self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         sample: torch.FloatTensor,
+         latent_embeds: Optional[torch.FloatTensor] = None,
+     ) -> torch.FloatTensor:
+         r"""The forward method of the `DecoderCausal3D` class."""
+         assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
+
+         sample = self.conv_in(sample)
+
+         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+         if self.training and self.gradient_checkpointing:
+
+             def create_custom_forward(module):
+                 def custom_forward(*inputs):
+                     return module(*inputs)
+
+                 return custom_forward
+
+             if is_torch_version(">=", "1.11.0"):
+                 # middle
+                 sample = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(self.mid_block),
+                     sample,
+                     latent_embeds,
+                     use_reentrant=False,
+                 )
+                 sample = sample.to(upscale_dtype)
+
+                 # up
+                 for up_block in self.up_blocks:
+                     sample = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(up_block),
+                         sample,
+                         latent_embeds,
+                         use_reentrant=False,
+                     )
+             else:
+                 # middle
+                 sample = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(self.mid_block), sample, latent_embeds
+                 )
+                 sample = sample.to(upscale_dtype)
+
+                 # up
+                 for up_block in self.up_blocks:
+                     sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+         else:
+             # middle
+             sample = self.mid_block(sample, latent_embeds)
+             sample = sample.to(upscale_dtype)
+
+             # up
+             for up_block in self.up_blocks:
+                 sample = up_block(sample, latent_embeds)
+
+         # post-process
+         if latent_embeds is None:
+             sample = self.conv_norm_out(sample)
+         else:
+             sample = self.conv_norm_out(sample, latent_embeds)
+         sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+
+         return sample
+
+
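Note that the checkpointed path in `forward` only runs when the module is in training mode *and* the flag is set. A minimal sketch of enabling it (using only the default constructor arguments shown in this file; whether a public setter exists elsewhere is not visible in this diff):

    import torch

    decoder = DecoderCausal3D()              # defaults: 3 -> 3 channels, one up block
    decoder.gradient_checkpointing = True
    decoder.train()                          # self.training must be True for the checkpointed branch
    out = decoder(torch.randn(1, 3, 2, 16, 16))  # intermediate activations are recomputed in backward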
+ class DiagonalGaussianDistribution(object):
+     def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+         if parameters.ndim == 3:
+             dim = 2  # (B, L, C)
+         elif parameters.ndim == 5 or parameters.ndim == 4:
+             dim = 1  # (B, C, T, H, W) / (B, C, H, W)
+         else:
+             raise NotImplementedError
+         self.parameters = parameters
+         self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
+         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+         self.deterministic = deterministic
+         self.std = torch.exp(0.5 * self.logvar)
+         self.var = torch.exp(self.logvar)
+         if self.deterministic:
+             self.var = self.std = torch.zeros_like(
+                 self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+             )
+
+     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+         # make sure the sample is on the same device as the parameters and has the same dtype
+         sample = randn_tensor(
+             self.mean.shape,
+             generator=generator,
+             device=self.parameters.device,
+             dtype=self.parameters.dtype,
+         )
+         x = self.mean + self.std * sample
+         return x
+
+     def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         else:
+             reduce_dim = list(range(1, self.mean.ndim))
+             if other is None:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                     dim=reduce_dim,
+                 )
+             else:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean - other.mean, 2) / other.var
+                     + self.var / other.var
+                     - 1.0
+                     - self.logvar
+                     + other.logvar,
+                     dim=reduce_dim,
+                 )
+
+     def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         logtwopi = np.log(2.0 * np.pi)
+         return 0.5 * torch.sum(
+             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+             dim=dims,
+         )
+
+     def mode(self) -> torch.Tensor:
+         return self.mean
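A brief usage sketch for the posterior class above (shapes are illustrative; when `other` is None, `kl()` implements the closed form 0.5 * sum(mu^2 + var - 1 - logvar) against a standard normal):

    import torch
    from hyvideo.vae.vae import DiagonalGaussianDistribution

    # 2*C channels along dim=1 for 5D input: first half is the mean, second half the log-variance
    params = torch.randn(2, 2 * 16, 4, 8, 8)  # (B, 2C, T, H, W)
    posterior = DiagonalGaussianDistribution(params)

    z = posterior.sample()      # reparameterized draw: mean + std * eps
    z_mode = posterior.mode()   # deterministic code: the mean
    kl = posterior.kl()         # per-sample KL to N(0, I), reduced over all non-batch dims
    print(z.shape, kl.shape)    # torch.Size([2, 16, 4, 8, 8]) torch.Size([2])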
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ torch==2.5.1
+ torchvision==0.20.1
+ torchaudio==2.5.1
+ opencv-python==4.9.0.80
+ diffusers==0.30.2
+ transformers==4.39.3
+ tokenizers==0.15.2
+ accelerate==1.1.1
+ pandas==2.0.3
+ numpy==1.24.4
+ einops==0.7.0
+ tqdm==4.66.2
+ loguru==0.7.2
+ imageio==2.34.0
+ imageio-ffmpeg==0.5.1
+ safetensors==0.4.3
+ mmgp==3.0.3
+ gradio==5.8.0
+ moviepy==1.0.3
+ # flash-attn==2.7.2.post1
+ # sageattention==1.0.6
requirements_xdit.txt ADDED
@@ -0,0 +1,16 @@
+ opencv-python==4.9.0.80
+ diffusers==0.31.0
+ transformers==4.46.3
+ tokenizers==0.20.3
+ accelerate==1.1.1
+ pandas==2.0.3
+ numpy==1.24.4
+ einops==0.7.0
+ tqdm==4.66.2
+ loguru==0.7.2
+ imageio==2.34.0
+ imageio-ffmpeg==0.5.1
+ safetensors==0.4.3
+ ninja
+ flash-attn==2.6.3
+ xfuser==0.4.0
sample_video.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ import time
+ import json
+ from pathlib import Path
+ from datetime import datetime
+
+ from loguru import logger
+
+ from hyvideo.utils.file_utils import save_videos_grid
+ from hyvideo.config import parse_args
+ from hyvideo.inference import HunyuanVideoSampler
+
+
+ def main():
+     args = parse_args()
+     print(args)
+     models_root_path = Path(args.model_base)
+     if not models_root_path.exists():
+         raise ValueError(f"`models_root` does not exist: {models_root_path}")
+
+     # Create the folder the samples will be saved to (check the suffixed path,
+     # not args.save_path, so the directory is always created).
+     save_path = args.save_path if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
+     os.makedirs(save_path, exist_ok=True)
+
+     # This fork loads the transformer weights and text encoder from explicit safetensors files.
+     models_root_path = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors"
+     text_encoder_filename = "ckpts/hunyuan-video-t2v-720p/text_encoder/llava-llama-3-8b_fp16.safetensors"
+     with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "r", encoding="utf-8") as reader:
+         vae_config = json.loads(reader.read())
+     # Reduce the time window used by the VAE for temporal splitting
+     # (the original window of 64 is too large for 24 GB of VRAM).
+     if vae_config["sample_tsize"] == 64:
+         vae_config["sample_tsize"] = 32
+     with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "w", encoding="utf-8") as writer:
+         writer.write(json.dumps(vae_config))
+
+     # Load models
+     hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, text_encoder_filename, args=args)
+
+     from mmgp import offload, profile_type
+     pipe = hunyuan_video_sampler.pipeline
+     offload.profile(pipe, profile_no=profile_type.HighRAM_LowVRAM_Fast)
+
+     # Get the updated args
+     args = hunyuan_video_sampler.args
+
+     # Start sampling
+     # TODO: batch inference check
+     outputs = hunyuan_video_sampler.predict(
+         prompt=args.prompt,
+         height=args.video_size[0],
+         width=args.video_size[1],
+         video_length=args.video_length,
+         seed=args.seed,
+         negative_prompt=args.neg_prompt,
+         infer_steps=args.infer_steps,
+         guidance_scale=args.cfg_scale,
+         num_videos_per_prompt=args.num_videos,
+         flow_shift=args.flow_shift,
+         batch_size=args.batch_size,
+         embedded_guidance_scale=args.embedded_cfg_scale,
+     )
+     samples = outputs["samples"]
+
+     # Save samples
+     if "LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0:
+         for i, sample in enumerate(samples):
+             sample = sample.unsqueeze(0)
+             time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
+             # Build the output path in a separate variable so `save_path` is not
+             # overwritten between iterations.
+             video_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/', '')}.mp4"
+             save_videos_grid(sample, video_path, fps=24)
+             logger.info(f"Sample saved to: {video_path}")
+
+
+ if __name__ == "__main__":
+     main()
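For reference, a typical invocation of the script above might look like the following (the flag names are assumed to mirror the attribute names used via `args` and defined in hyvideo/config.py, e.g. `args.video_size` -> `--video-size`; they are not shown in this part of the diff):

    python sample_video.py --prompt "A cat walks on the grass, realistic style." --video-size 720 1280 --video-length 129 --infer-steps 50 --save-path ./results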