zRzRzRzRzRzRzR commited on
Commit
2e8cf5f
1 Parent(s): 39ab106
.gitignore CHANGED
@@ -1,133 +1 @@
1
- gradio_queue.db*
2
- pretrained/*
3
- icetk_models/*
4
- !*/.gitkeep
5
- # Byte-compiled / optimized / DLL files
6
- __pycache__/
7
- *.py[cod]
8
- *$py.class
9
-
10
- # C extensions
11
- *.so
12
-
13
- # Distribution / packaging
14
- .Python
15
- build/
16
- develop-eggs/
17
- dist/
18
- downloads/
19
- eggs/
20
- .eggs/
21
- lib/
22
- lib64/
23
- parts/
24
- sdist/
25
- var/
26
- wheels/
27
- pip-wheel-metadata/
28
- share/python-wheels/
29
- *.egg-info/
30
- .installed.cfg
31
- *.egg
32
- MANIFEST
33
-
34
- # PyInstaller
35
- # Usually these files are written by a python script from a template
36
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
- *.manifest
38
- *.spec
39
-
40
- # Installer logs
41
- pip-log.txt
42
- pip-delete-this-directory.txt
43
-
44
- # Unit test / coverage reports
45
- htmlcov/
46
- .tox/
47
- .nox/
48
- .coverage
49
- .coverage.*
50
- .cache
51
- nosetests.xml
52
- coverage.xml
53
- *.cover
54
- *.py,cover
55
- .hypothesis/
56
- .pytest_cache/
57
-
58
- # Translations
59
- *.mo
60
- *.pot
61
-
62
- # Django stuff:
63
- *.log
64
- local_settings.py
65
- db.sqlite3
66
- db.sqlite3-journal
67
-
68
- # Flask stuff:
69
- instance/
70
- .webassets-cache
71
-
72
- # Scrapy stuff:
73
- .scrapy
74
-
75
- # Sphinx documentation
76
- docs/_build/
77
-
78
- # PyBuilder
79
- target/
80
-
81
- # Jupyter Notebook
82
- .ipynb_checkpoints
83
-
84
- # IPython
85
- profile_default/
86
- ipython_config.py
87
-
88
- # pyenv
89
- .python-version
90
-
91
- # pipenv
92
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
- # install all needed dependencies.
96
- #Pipfile.lock
97
-
98
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99
- __pypackages__/
100
-
101
- # Celery stuff
102
- celerybeat-schedule
103
- celerybeat.pid
104
-
105
- # SageMath parsed files
106
- *.sage.py
107
-
108
- # Environments
109
- .env
110
- .venv
111
- env/
112
- venv/
113
- ENV/
114
- env.bak/
115
- venv.bak/
116
-
117
- # Spyder project settings
118
- .spyderproject
119
- .spyproject
120
-
121
- # Rope project settings
122
- .ropeproject
123
-
124
- # mkdocs documentation
125
- /site
126
-
127
- # mypy
128
- .mypy_cache/
129
- .dmypy.json
130
- dmypy.json
131
-
132
- # Pyre type checker
133
- .pyre/
 
1
+ .venv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/CogVideo.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="7">
8
+ <item index="0" class="java.lang.String" itemvalue="openai" />
9
+ <item index="1" class="java.lang.String" itemvalue="sse_starlette" />
10
+ <item index="2" class="java.lang.String" itemvalue="fastapi" />
11
+ <item index="3" class="java.lang.String" itemvalue="timm" />
12
+ <item index="4" class="java.lang.String" itemvalue="gradio" />
13
+ <item index="5" class="java.lang.String" itemvalue="uvicorn" />
14
+ <item index="6" class="java.lang.String" itemvalue="diffusers" />
15
+ </list>
16
+ </value>
17
+ </option>
18
+ </inspection_tool>
19
+ </profile>
20
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Remote Python 3.10.14 (sftp://[email protected]:22/share/home/zyx/.conda/envs/cogvideox/bin/python)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.10.14 (sftp://[email protected]:22/share/home/zyx/.conda/envs/cogvideox/bin/python)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/CogVideo.iml" filepath="$PROJECT_DIR$/.idea/CogVideo.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
.pre-commit-config.yaml DELETED
@@ -1,46 +0,0 @@
1
- exclude: ^patch
2
- repos:
3
- - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
- hooks:
6
- - id: check-executables-have-shebangs
7
- - id: check-json
8
- - id: check-merge-conflict
9
- - id: check-shebang-scripts-are-executable
10
- - id: check-toml
11
- - id: check-yaml
12
- - id: double-quote-string-fixer
13
- - id: end-of-file-fixer
14
- - id: mixed-line-ending
15
- args: ['--fix=lf']
16
- - id: requirements-txt-fixer
17
- - id: trailing-whitespace
18
- - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
- hooks:
21
- - id: docformatter
22
- args: ['--in-place']
23
- - repo: https://github.com/pycqa/isort
24
- rev: 5.10.1
25
- hooks:
26
- - id: isort
27
- - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.812
29
- hooks:
30
- - id: mypy
31
- args: ['--ignore-missing-imports']
32
- - repo: https://github.com/google/yapf
33
- rev: v0.32.0
34
- hooks:
35
- - id: yapf
36
- args: ['--parallel', '--in-place']
37
- - repo: https://github.com/kynan/nbstripout
38
- rev: 0.5.0
39
- hooks:
40
- - id: nbstripout
41
- args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
- - repo: https://github.com/nbQA-dev/nbQA
43
- rev: 1.3.1
44
- hooks:
45
- - id: nbqa-isort
46
- - id: nbqa-yapf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
CogVideo DELETED
@@ -1 +0,0 @@
1
- Subproject commit ff423aa169978fb2f636f761e348631fa3178b03
 
 
LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2022 hysts
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LICENSE.CogVideo DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: CogVideo
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.1.6
8
- python_version: 3.9.13
9
- app_file: app.py
10
- pinned: false
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,138 +1,188 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # from model import AppModel
8
-
9
- MAINTENANCE_NOTICE='Sorry, due to computing resources issues, this space is under maintenance, and will be restored as soon as possible. '
10
-
11
- DESCRIPTION = '''# <a href="https://github.com/THUDM/CogVideo">CogVideo</a>
12
- Currently, this Space only supports the first stage of the CogVideo pipeline due to hardware limitations.
13
- The model accepts only Chinese as input.
14
- By checking the "Translate to Chinese" checkbox, the results of English to Chinese translation with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) will be used as input.
15
- Since the translation model may mistranslate, you may want to use the translation results from other translation services.
16
- '''
17
- NOTES = 'This app is adapted from <a href="https://github.com/hysts/CogVideo_demo">https://github.com/hysts/CogVideo_demo</a>. It would be recommended to use the repo if you want to run the app yourself.'
18
- FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogVideo" />'
19
-
20
- import json
21
- import requests
22
- import numpy as np
23
- import imageio.v2 as iio
24
- import base64
25
- import urllib.request
26
-
27
- def post(
28
- text,
29
- translate,
30
- seed,
31
- only_first_stage,
32
- image_prompt
33
- ):
34
- url = 'https://tianqi.aminer.cn/cogvideo/api/generate'
35
- headers = {
36
- "Content-Type": "application/json; charset=UTF-8",
37
- "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.87 Safari/537.36",
38
- }
39
- if image_prompt:
40
- with open(image_prompt, "rb") as image_file:
41
- encoded_img = str(base64.b64encode(image_file.read()), encoding='utf-8')
42
- else:
43
- encoded_img = None
44
- print('开始请求...')
45
- data = json.dumps({'text': text,
46
- 'translate': translate,
47
- 'seed': seed,
48
- 'only_first_stage': only_first_stage,
49
- 'image_prompt': encoded_img
50
- })
51
- r = requests.post(url, data, headers=headers)
52
- print(r)
53
-
54
- print('请求完毕...')
55
- # translated_text = r.json()['data']['translated_text']
56
- frames = r.json()['data']['frames']
57
-
58
- result_video = ["" for i in range(len(frames))]
59
- result_video[0] = "./temp1.mp4"
60
- result_video[1] = "./temp2.mp4"
61
- for i in range(len(result_video)):
62
- url = frames[i]
63
- result_video[i] = "./temp" + str(i) + ".mp4"
64
- urllib.request.urlretrieve(url, result_video[i])
65
-
66
- print('finished')
67
- return result_video[0], result_video[1]
68
- # return result_video[0], result_video[1], result_video[2], result_video[3]
69
-
70
- def main():
71
- only_first_stage = True
72
- # model = AppModel(only_first_stage)
73
-
74
- with gr.Blocks(css='style.css') as demo:
75
- # gr.Markdown(MAINTENANCE_NOTICE)
76
-
77
- gr.Markdown(DESCRIPTION)
78
-
79
- with gr.Row():
80
  with gr.Column():
81
- with gr.Group():
82
- text = gr.Textbox(label='Input Text')
83
- translate = gr.Checkbox(label='Translate to Chinese',
84
- value=False)
85
- seed = gr.Slider(0,
86
- 100000,
87
- step=1,
88
- value=1234,
89
- label='Seed')
90
- only_first_stage = gr.Checkbox(
91
- label='Only First Stage',
92
- value=only_first_stage,
93
- visible=not only_first_stage)
94
- image_prompt = gr.Image(type="filepath",
95
- label="Image Prompt",
96
- value=None)
97
- run_button = gr.Button('Run')
98
 
99
- with gr.Column():
100
- with gr.Group():
101
- #translated_text = gr.Textbox(label='Translated Text')
102
- with gr.Tabs():
103
- with gr.TabItem('Output (Video)'):
104
- result_video1 = gr.Video(show_label=False)
105
- result_video2 = gr.Video(show_label=False)
106
- # result_video3 = gr.Video(show_label=False)
107
- # result_video4 = gr.Video(show_label=False)
108
-
109
-
110
-
111
- # examples = gr.Examples(
112
- # examples=[['骑滑板的皮卡丘', False, 1234, True,None],
113
- # ['a cat playing chess', True, 1253, True,None]],
114
- # fn=model.run_with_translation,
115
- # inputs=[text, translate, seed, only_first_stage,image_prompt],
116
- # outputs=[translated_text, result_video],
117
- # cache_examples=True)
118
-
119
- gr.Markdown(NOTES)
120
- gr.Markdown(FOOTER)
121
- print(gr.__version__)
122
- run_button.click(fn=post,
123
- inputs=[
124
- text,
125
- translate,
126
- seed,
127
- only_first_stage,
128
- image_prompt
129
- ],
130
- outputs=[result_video1, result_video2])
131
- # outputs=[result_video1, result_video2, result_video3, result_video4])
132
- print(gr.__version__)
133
- demo.queue(concurrency_count=6)
134
- demo.launch()
135
-
136
-
137
- if __name__ == '__main__':
138
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
 
2
  import gradio as gr
3
+ import torch
4
+ from diffusers import CogVideoXPipeline
5
+ from diffusers.utils import export_to_video
6
+ from datetime import datetime
7
+ from openai import OpenAI
8
+ import spaces
9
+ import moviepy.editor as mp
10
+
11
+ dtype = torch.float16
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
14
+
15
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
16
+
17
+ For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
18
+ There are a few rules to follow:
19
+
20
+ You will only ever output a single video description per user request.
21
+
22
+ When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
23
+ Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
24
+
25
+ Video descriptions must have the same num of words as examples below. Extra words will be ignored.
26
+ """
27
+
28
+
29
+ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
30
+ if not os.environ.get("OPENAI_API_KEY"):
31
+ return prompt
32
+ client = OpenAI()
33
+ text = prompt.strip()
34
+
35
+ for i in range(retry_times):
36
+ response = client.chat.completions.create(
37
+ messages=[
38
+ {"role": "system", "content": sys_prompt},
39
+ {"role": "user",
40
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"'},
41
+ {"role": "assistant",
42
+ "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance."},
43
+ {"role": "user",
44
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"'},
45
+ {"role": "assistant",
46
+ "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field."},
47
+ {"role": "user",
48
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"'},
49
+ {"role": "assistant",
50
+ "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background."},
51
+ {"role": "user",
52
+ "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"'},
53
+ ],
54
+ model="glm-4-0520",
55
+ temperature=0.01,
56
+ top_p=0.7,
57
+ stream=False,
58
+ max_tokens=250,
59
+ )
60
+ if response.choices:
61
+ return response.choices[0].message.content
62
+ return prompt
63
+
64
+
65
+ @spaces.GPU()
66
+ def infer(
67
+ prompt: str,
68
+ num_inference_steps: int,
69
+ guidance_scale: float,
70
+ progress=gr.Progress(track_tqdm=True)
71
+ ):
72
+ torch.cuda.empty_cache()
73
+
74
+ prompt_embeds, _ = pipe.encode_prompt(
75
+ prompt=prompt,
76
+ negative_prompt=None,
77
+ do_classifier_free_guidance=True,
78
+ num_videos_per_prompt=1,
79
+ max_sequence_length=226,
80
+ device=device,
81
+ dtype=dtype,
82
+ )
83
+
84
+ video = pipe(
85
+ num_inference_steps=num_inference_steps,
86
+ guidance_scale=guidance_scale,
87
+ prompt_embeds=prompt_embeds,
88
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds),
89
+ ).frames[0]
90
+
91
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
92
+ video_path = f"./output/{timestamp}.mp4"
93
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
94
+ export_to_video(video, video_path)
95
+ return video_path
96
+
97
+
98
+ def convert_to_gif(video_path):
99
+ clip = mp.VideoFileClip(video_path)
100
+ clip = clip.set_fps(8)
101
+ clip = clip.resize(height=240)
102
+ gif_path = video_path.replace('.mp4', '.gif')
103
+ clip.write_gif(gif_path, fps=8)
104
+ return gif_path
105
+
106
+
107
+ with gr.Blocks() as demo:
108
+ gr.Markdown("""
109
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
110
+ CogVideoX-2B Huggingface Space🤗
111
+ </div>
112
+ <div style="text-align: center;">
113
+ <a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 Model Hub</a> |
114
+ <a href="https://github.com/THUDM/CogVideo">🌐 Github</a>
115
+ </div>
116
+
117
+ <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
118
+ ⚠️ This demo is for academic research and experiential use only.
119
+ Users should strictly adhere to local laws and ethics.
120
+ </div>
121
+ """)
122
+ with gr.Row():
123
+ with gr.Column():
124
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
125
+ with gr.Row():
126
+ gr.Markdown(
127
+ "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
128
+ enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  with gr.Column():
131
+ gr.Markdown("**Optional Parameters** (default values are recommended)")
132
+ with gr.Row():
133
+ num_inference_steps = gr.Number(label="Inference Steps", value=50)
134
+ guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
135
+ generate_button = gr.Button("🎬 Generate Video")
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ with gr.Column():
138
+ video_output = gr.Video(label="CogVideoX Generate Video", width=720, height=480)
139
+ with gr.Row():
140
+ download_video_button = gr.File(label="📥 Download Video", visible=False)
141
+ download_gif_button = gr.File(label="📥 Download GIF", visible=False)
142
+
143
+
144
+ def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
145
+ video_path = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
146
+ video_update = gr.update(visible=True, value=video_path)
147
+
148
+ gif_path = convert_to_gif(video_path)
149
+ gif_update = gr.update(visible=True, value=gif_path)
150
+
151
+ return video_path, video_update, gif_update
152
+
153
+
154
+ def enhance_prompt_func(prompt):
155
+ return convert_prompt(prompt, retry_times=1)
156
+
157
+
158
+ generate_button.click(
159
+ generate,
160
+ inputs=[prompt, num_inference_steps, guidance_scale],
161
+ outputs=[video_output, download_video_button, download_gif_button]
162
+ )
163
+
164
+ enhance_button.click(
165
+ enhance_prompt_func,
166
+ inputs=[prompt],
167
+ outputs=[prompt]
168
+ )
169
+
170
+
171
+ def enhance_prompt_func(prompt):
172
+ return convert_prompt(prompt, retry_times=1)
173
+
174
+
175
+ generate_button.click(
176
+ generate,
177
+ inputs=[prompt, num_inference_steps, guidance_scale],
178
+ outputs=[video_output, download_video_button, download_gif_button]
179
+ )
180
+
181
+ enhance_button.click(
182
+ enhance_prompt_func,
183
+ inputs=[prompt],
184
+ outputs=[prompt]
185
+ )
186
+
187
+ if __name__ == "__main__":
188
+ demo.launch(server_name="127.0.0.1", server_port=7870, share=True)
model.py DELETED
@@ -1,1243 +0,0 @@
1
- # This code is adapted from https://github.com/THUDM/CogVideo/blob/ff423aa169978fb2f636f761e348631fa3178b03/cogvideo_pipeline.py
2
-
3
- from __future__ import annotations
4
-
5
- import argparse
6
- import logging
7
- import os
8
- import pathlib
9
- import shutil
10
- import subprocess
11
- import sys
12
- import tempfile
13
- import time
14
- import zipfile
15
- from typing import Any
16
-
17
- if os.getenv('SYSTEM') == 'spaces':
18
- subprocess.run('pip install icetk==0.0.4'.split())
19
- subprocess.run('pip install SwissArmyTransformer==0.2.9'.split())
20
- subprocess.run(
21
- 'pip install git+https://github.com/Sleepychord/Image-Local-Attention@43fee31'
22
- .split())
23
- #subprocess.run('git clone https://github.com/NVIDIA/apex'.split())
24
- #subprocess.run('git checkout 1403c21'.split(), cwd='apex')
25
- #with open('patch.apex') as f:
26
- # subprocess.run('patch -p1'.split(), cwd='apex', stdin=f)
27
- #subprocess.run(
28
- # 'pip install -v --disable-pip-version-check --no-cache-dir --global-option --cpp_ext --global-option --cuda_ext ./'
29
- # .split(),
30
- # cwd='apex')
31
- #subprocess.run('rm -rf apex'.split())
32
- with open('patch') as f:
33
- subprocess.run('patch -p1'.split(), cwd='CogVideo', stdin=f)
34
-
35
- from huggingface_hub import hf_hub_download
36
-
37
- def download_and_extract_icetk_models() -> None:
38
- icetk_model_dir = pathlib.Path('/home/user/.icetk_models')
39
- icetk_model_dir.mkdir()
40
- path = hf_hub_download('THUDM/icetk',
41
- 'models.zip',
42
- use_auth_token=os.getenv('HF_TOKEN'))
43
- with zipfile.ZipFile(path) as f:
44
- f.extractall(path=icetk_model_dir.as_posix())
45
-
46
- def download_and_extract_cogvideo_models(name: str) -> None:
47
- path = hf_hub_download('THUDM/CogVideo',
48
- name,
49
- use_auth_token=os.getenv('HF_TOKEN'))
50
- with zipfile.ZipFile(path) as f:
51
- f.extractall('pretrained')
52
- os.remove(path)
53
-
54
- def download_and_extract_cogview2_models(name: str) -> None:
55
- path = hf_hub_download('THUDM/CogView2', name)
56
- with zipfile.ZipFile(path) as f:
57
- f.extractall()
58
- shutil.move('/home/user/app/sharefs/cogview-new/cogview2-dsr',
59
- 'pretrained')
60
- shutil.rmtree('/home/user/app/sharefs/')
61
- os.remove(path)
62
-
63
- download_and_extract_icetk_models()
64
- download_and_extract_cogvideo_models('cogvideo-stage1.zip')
65
- #download_and_extract_cogvideo_models('cogvideo-stage2.zip')
66
- #download_and_extract_cogview2_models('cogview2-dsr.zip')
67
-
68
- os.environ['SAT_HOME'] = '/home/user/app/pretrained'
69
-
70
- import gradio as gr
71
- import imageio.v2 as iio
72
- import numpy as np
73
- import torch
74
- from icetk import IceTokenizer
75
- from SwissArmyTransformer import get_args
76
- from SwissArmyTransformer.arguments import set_random_seed
77
- from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
78
- from SwissArmyTransformer.resources import auto_create
79
-
80
- app_dir = pathlib.Path(__file__).parent
81
- submodule_dir = app_dir / 'CogVideo'
82
- sys.path.insert(0, submodule_dir.as_posix())
83
-
84
- from coglm_strategy import CoglmStrategy
85
- from models.cogvideo_cache_model import CogVideoCacheModel
86
- from sr_pipeline import DirectSuperResolution
87
-
88
- formatter = logging.Formatter(
89
- '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
90
- datefmt='%Y-%m-%d %H:%M:%S')
91
- stream_handler = logging.StreamHandler(stream=sys.stdout)
92
- stream_handler.setLevel(logging.INFO)
93
- stream_handler.setFormatter(formatter)
94
- logger = logging.getLogger(__name__)
95
- logger.setLevel(logging.INFO)
96
- logger.propagate = False
97
- logger.addHandler(stream_handler)
98
-
99
- ICETK_MODEL_DIR = app_dir / 'icetk_models'
100
-
101
-
102
- def get_masks_and_position_ids_stage1(data, textlen, framelen):
103
- # Extract batch size and sequence length.
104
- tokens = data
105
- seq_length = len(data[0])
106
- # Attention mask (lower triangular).
107
- attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
108
- device=data.device)
109
- attention_mask[:, :textlen, textlen:] = 0
110
- attention_mask[:, textlen:, textlen:].tril_()
111
- attention_mask.unsqueeze_(1)
112
- # Unaligned version
113
- position_ids = torch.zeros(seq_length,
114
- dtype=torch.long,
115
- device=data.device)
116
- torch.arange(textlen,
117
- out=position_ids[:textlen],
118
- dtype=torch.long,
119
- device=data.device)
120
- torch.arange(512,
121
- 512 + seq_length - textlen,
122
- out=position_ids[textlen:],
123
- dtype=torch.long,
124
- device=data.device)
125
- position_ids = position_ids.unsqueeze(0)
126
-
127
- return tokens, attention_mask, position_ids
128
-
129
-
130
- def get_masks_and_position_ids_stage2(data, textlen, framelen):
131
- # Extract batch size and sequence length.
132
- tokens = data
133
- seq_length = len(data[0])
134
-
135
- # Attention mask (lower triangular).
136
- attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
137
- device=data.device)
138
- attention_mask[:, :textlen, textlen:] = 0
139
- attention_mask[:, textlen:, textlen:].tril_()
140
- attention_mask.unsqueeze_(1)
141
-
142
- # Unaligned version
143
- position_ids = torch.zeros(seq_length,
144
- dtype=torch.long,
145
- device=data.device)
146
- torch.arange(textlen,
147
- out=position_ids[:textlen],
148
- dtype=torch.long,
149
- device=data.device)
150
- frame_num = (seq_length - textlen) // framelen
151
- assert frame_num == 5
152
- torch.arange(512,
153
- 512 + framelen,
154
- out=position_ids[textlen:textlen + framelen],
155
- dtype=torch.long,
156
- device=data.device)
157
- torch.arange(512 + framelen * 2,
158
- 512 + framelen * 3,
159
- out=position_ids[textlen + framelen:textlen + framelen * 2],
160
- dtype=torch.long,
161
- device=data.device)
162
- torch.arange(512 + framelen * (frame_num - 1),
163
- 512 + framelen * frame_num,
164
- out=position_ids[textlen + framelen * 2:textlen +
165
- framelen * 3],
166
- dtype=torch.long,
167
- device=data.device)
168
- torch.arange(512 + framelen * 1,
169
- 512 + framelen * 2,
170
- out=position_ids[textlen + framelen * 3:textlen +
171
- framelen * 4],
172
- dtype=torch.long,
173
- device=data.device)
174
- torch.arange(512 + framelen * 3,
175
- 512 + framelen * 4,
176
- out=position_ids[textlen + framelen * 4:textlen +
177
- framelen * 5],
178
- dtype=torch.long,
179
- device=data.device)
180
-
181
- position_ids = position_ids.unsqueeze(0)
182
-
183
- return tokens, attention_mask, position_ids
184
-
185
-
186
- def my_update_mems(hiddens, mems_buffers, mems_indexs,
187
- limited_spatial_channel_mem, text_len, frame_len):
188
- if hiddens is None:
189
- return None, mems_indexs
190
- mem_num = len(hiddens)
191
- ret_mem = []
192
- with torch.no_grad():
193
- for id in range(mem_num):
194
- if hiddens[id][0] is None:
195
- ret_mem.append(None)
196
- else:
197
- if id == 0 and limited_spatial_channel_mem and mems_indexs[
198
- id] + hiddens[0][0].shape[1] >= text_len + frame_len:
199
- if mems_indexs[id] == 0:
200
- for layer, hidden in enumerate(hiddens[id]):
201
- mems_buffers[id][
202
- layer, :, :text_len] = hidden.expand(
203
- mems_buffers[id].shape[1], -1,
204
- -1)[:, :text_len]
205
- new_mem_len_part2 = (mems_indexs[id] +
206
- hiddens[0][0].shape[1] -
207
- text_len) % frame_len
208
- if new_mem_len_part2 > 0:
209
- for layer, hidden in enumerate(hiddens[id]):
210
- mems_buffers[id][
211
- layer, :, text_len:text_len +
212
- new_mem_len_part2] = hidden.expand(
213
- mems_buffers[id].shape[1], -1,
214
- -1)[:, -new_mem_len_part2:]
215
- mems_indexs[id] = text_len + new_mem_len_part2
216
- else:
217
- for layer, hidden in enumerate(hiddens[id]):
218
- mems_buffers[id][layer, :,
219
- mems_indexs[id]:mems_indexs[id] +
220
- hidden.shape[1]] = hidden.expand(
221
- mems_buffers[id].shape[1], -1, -1)
222
- mems_indexs[id] += hidden.shape[1]
223
- ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
224
- return ret_mem, mems_indexs
225
-
226
-
227
- def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
228
- # The fisrt token's position id of the frame that the next token belongs to;
229
- if total_len < text_len:
230
- return None
231
- return (total_len - text_len) // frame_len * frame_len + text_len
232
-
233
-
234
- def my_filling_sequence(
235
- model,
236
- tokenizer,
237
- args,
238
- seq,
239
- batch_size,
240
- get_masks_and_position_ids,
241
- text_len,
242
- frame_len,
243
- strategy=BaseStrategy(),
244
- strategy2=BaseStrategy(),
245
- mems=None,
246
- log_text_attention_weights=0, # default to 0: no artificial change
247
- mode_stage1=True,
248
- enforce_no_swin=False,
249
- guider_seq=None,
250
- guider_text_len=0,
251
- guidance_alpha=1,
252
- limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
253
- **kw_args):
254
- '''
255
- seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
256
- mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
257
- cache, should be first mems.shape[1] parts of context_tokens.
258
- mems are the first-level citizens here, but we don't assume what is memorized.
259
- input mems are used when multi-phase generation.
260
- '''
261
- if guider_seq is not None:
262
- logger.debug('Using Guidance In Inference')
263
- if limited_spatial_channel_mem:
264
- logger.debug("Limit spatial-channel's mem to current frame")
265
- assert len(seq.shape) == 2
266
-
267
- # building the initial tokens, attention_mask, and position_ids
268
- actual_context_length = 0
269
-
270
- while seq[-1][
271
- actual_context_length] >= 0: # the last seq has least given tokens
272
- actual_context_length += 1 # [0, context_length-1] are given
273
- assert actual_context_length > 0
274
- current_frame_num = (actual_context_length - text_len) // frame_len
275
- assert current_frame_num >= 0
276
- context_length = text_len + current_frame_num * frame_len
277
-
278
- tokens, attention_mask, position_ids = get_masks_and_position_ids(
279
- seq, text_len, frame_len)
280
- tokens = tokens[..., :context_length]
281
- input_tokens = tokens.clone()
282
-
283
- if guider_seq is not None:
284
- guider_index_delta = text_len - guider_text_len
285
- guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(
286
- guider_seq, guider_text_len, frame_len)
287
- guider_tokens = guider_tokens[..., :context_length -
288
- guider_index_delta]
289
- guider_input_tokens = guider_tokens.clone()
290
-
291
- for fid in range(current_frame_num):
292
- input_tokens[:, text_len + 400 * fid] = tokenizer['<start_of_image>']
293
- if guider_seq is not None:
294
- guider_input_tokens[:, guider_text_len +
295
- 400 * fid] = tokenizer['<start_of_image>']
296
-
297
- attention_mask = attention_mask.type_as(next(
298
- model.parameters())) # if fp16
299
- # initialize generation
300
- counter = context_length - 1 # Last fixed index is ``counter''
301
- index = 0 # Next forward starting index, also the length of cache.
302
- mems_buffers_on_GPU = False
303
- mems_indexs = [0, 0]
304
- mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
305
- 5 * 400 + 74]
306
- mems_buffers = [
307
- torch.zeros(args.num_layers,
308
- batch_size,
309
- mem_len,
310
- args.hidden_size * 2,
311
- dtype=next(model.parameters()).dtype)
312
- for mem_len in mems_len
313
- ]
314
-
315
- if guider_seq is not None:
316
- guider_attention_mask = guider_attention_mask.type_as(
317
- next(model.parameters())) # if fp16
318
- guider_mems_buffers = [
319
- torch.zeros(args.num_layers,
320
- batch_size,
321
- mem_len,
322
- args.hidden_size * 2,
323
- dtype=next(model.parameters()).dtype)
324
- for mem_len in mems_len
325
- ]
326
- guider_mems_indexs = [0, 0]
327
- guider_mems = None
328
-
329
- torch.cuda.empty_cache()
330
- # step-by-step generation
331
- while counter < len(seq[0]) - 1:
332
- # we have generated counter+1 tokens
333
- # Now, we want to generate seq[counter + 1],
334
- # token[:, index: counter+1] needs forwarding.
335
- if index == 0:
336
- group_size = 2 if (input_tokens.shape[0] == batch_size
337
- and not mode_stage1) else batch_size
338
-
339
- logits_all = None
340
- for batch_idx in range(0, input_tokens.shape[0], group_size):
341
- logits, *output_per_layers = model(
342
- input_tokens[batch_idx:batch_idx + group_size, index:],
343
- position_ids[..., index:counter + 1],
344
- attention_mask, # TODO memlen
345
- mems=mems,
346
- text_len=text_len,
347
- frame_len=frame_len,
348
- counter=counter,
349
- log_text_attention_weights=log_text_attention_weights,
350
- enforce_no_swin=enforce_no_swin,
351
- **kw_args)
352
- logits_all = torch.cat(
353
- (logits_all,
354
- logits), dim=0) if logits_all is not None else logits
355
- mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers],
356
- [o['mem_kv'][1] for o in output_per_layers]]
357
- next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
358
- text_len, frame_len, mem_kv01[0][0].shape[1])
359
- for id, mem_kv in enumerate(mem_kv01):
360
- for layer, mem_kv_perlayer in enumerate(mem_kv):
361
- if limited_spatial_channel_mem and id == 0:
362
- mems_buffers[id][
363
- layer, batch_idx:batch_idx + group_size, :
364
- text_len] = mem_kv_perlayer.expand(
365
- min(group_size,
366
- input_tokens.shape[0] - batch_idx), -1,
367
- -1)[:, :text_len]
368
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
369
- mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
370
- else:
371
- mems_buffers[id][
372
- layer, batch_idx:batch_idx +
373
- group_size, :mem_kv_perlayer.
374
- shape[1]] = mem_kv_perlayer.expand(
375
- min(group_size,
376
- input_tokens.shape[0] - batch_idx), -1,
377
- -1)
378
- mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[
379
- 1], mem_kv01[1][0].shape[1]
380
- if limited_spatial_channel_mem:
381
- mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
382
-
383
- mems = [
384
- mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)
385
- ]
386
- logits = logits_all
387
-
388
- # Guider
389
- if guider_seq is not None:
390
- guider_logits_all = None
391
- for batch_idx in range(0, guider_input_tokens.shape[0],
392
- group_size):
393
- guider_logits, *guider_output_per_layers = model(
394
- guider_input_tokens[batch_idx:batch_idx + group_size,
395
- max(index -
396
- guider_index_delta, 0):],
397
- guider_position_ids[
398
- ...,
399
- max(index - guider_index_delta, 0):counter + 1 -
400
- guider_index_delta],
401
- guider_attention_mask,
402
- mems=guider_mems,
403
- text_len=guider_text_len,
404
- frame_len=frame_len,
405
- counter=counter - guider_index_delta,
406
- log_text_attention_weights=log_text_attention_weights,
407
- enforce_no_swin=enforce_no_swin,
408
- **kw_args)
409
- guider_logits_all = torch.cat(
410
- (guider_logits_all, guider_logits), dim=0
411
- ) if guider_logits_all is not None else guider_logits
412
- guider_mem_kv01 = [[
413
- o['mem_kv'][0] for o in guider_output_per_layers
414
- ], [o['mem_kv'][1] for o in guider_output_per_layers]]
415
- for id, guider_mem_kv in enumerate(guider_mem_kv01):
416
- for layer, guider_mem_kv_perlayer in enumerate(
417
- guider_mem_kv):
418
- if limited_spatial_channel_mem and id == 0:
419
- guider_mems_buffers[id][
420
- layer, batch_idx:batch_idx + group_size, :
421
- guider_text_len] = guider_mem_kv_perlayer.expand(
422
- min(group_size,
423
- input_tokens.shape[0] - batch_idx),
424
- -1, -1)[:, :guider_text_len]
425
- guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
426
- guider_text_len, frame_len,
427
- guider_mem_kv_perlayer.shape[1])
428
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
429
- guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
430
- else:
431
- guider_mems_buffers[id][
432
- layer, batch_idx:batch_idx +
433
- group_size, :guider_mem_kv_perlayer.
434
- shape[1]] = guider_mem_kv_perlayer.expand(
435
- min(group_size,
436
- input_tokens.shape[0] - batch_idx),
437
- -1, -1)
438
- guider_mems_indexs[0], guider_mems_indexs[
439
- 1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[
440
- 1][0].shape[1]
441
- if limited_spatial_channel_mem:
442
- guider_mems_indexs[0] -= (
443
- guider_next_tokens_frame_begin_id -
444
- guider_text_len)
445
- guider_mems = [
446
- guider_mems_buffers[id][:, :, :guider_mems_indexs[id]]
447
- for id in range(2)
448
- ]
449
- guider_logits = guider_logits_all
450
- else:
451
- if not mems_buffers_on_GPU:
452
- if not mode_stage1:
453
- torch.cuda.empty_cache()
454
- for idx, mem in enumerate(mems):
455
- mems[idx] = mem.to(next(model.parameters()).device)
456
- if guider_seq is not None:
457
- for idx, mem in enumerate(guider_mems):
458
- guider_mems[idx] = mem.to(
459
- next(model.parameters()).device)
460
- else:
461
- torch.cuda.empty_cache()
462
- for idx, mem_buffer in enumerate(mems_buffers):
463
- mems_buffers[idx] = mem_buffer.to(
464
- next(model.parameters()).device)
465
- mems = [
466
- mems_buffers[id][:, :, :mems_indexs[id]]
467
- for id in range(2)
468
- ]
469
- if guider_seq is not None:
470
- for idx, guider_mem_buffer in enumerate(
471
- guider_mems_buffers):
472
- guider_mems_buffers[idx] = guider_mem_buffer.to(
473
- next(model.parameters()).device)
474
- guider_mems = [
475
- guider_mems_buffers[id]
476
- [:, :, :guider_mems_indexs[id]] for id in range(2)
477
- ]
478
- mems_buffers_on_GPU = True
479
-
480
- logits, *output_per_layers = model(
481
- input_tokens[:, index:],
482
- position_ids[..., index:counter + 1],
483
- attention_mask, # TODO memlen
484
- mems=mems,
485
- text_len=text_len,
486
- frame_len=frame_len,
487
- counter=counter,
488
- log_text_attention_weights=log_text_attention_weights,
489
- enforce_no_swin=enforce_no_swin,
490
- limited_spatial_channel_mem=limited_spatial_channel_mem,
491
- **kw_args)
492
- mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers
493
- ], [o['mem_kv'][1] for o in output_per_layers]
494
-
495
- if guider_seq is not None:
496
- guider_logits, *guider_output_per_layers = model(
497
- guider_input_tokens[:,
498
- max(index - guider_index_delta, 0):],
499
- guider_position_ids[...,
500
- max(index -
501
- guider_index_delta, 0):counter +
502
- 1 - guider_index_delta],
503
- guider_attention_mask,
504
- mems=guider_mems,
505
- text_len=guider_text_len,
506
- frame_len=frame_len,
507
- counter=counter - guider_index_delta,
508
- log_text_attention_weights=0,
509
- enforce_no_swin=enforce_no_swin,
510
- limited_spatial_channel_mem=limited_spatial_channel_mem,
511
- **kw_args)
512
- guider_mem_kv0, guider_mem_kv1 = [
513
- o['mem_kv'][0] for o in guider_output_per_layers
514
- ], [o['mem_kv'][1] for o in guider_output_per_layers]
515
-
516
- if not mems_buffers_on_GPU:
517
- torch.cuda.empty_cache()
518
- for idx, mem_buffer in enumerate(mems_buffers):
519
- mems_buffers[idx] = mem_buffer.to(
520
- next(model.parameters()).device)
521
- if guider_seq is not None:
522
- for idx, guider_mem_buffer in enumerate(
523
- guider_mems_buffers):
524
- guider_mems_buffers[idx] = guider_mem_buffer.to(
525
- next(model.parameters()).device)
526
- mems_buffers_on_GPU = True
527
-
528
- mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
529
- mems_buffers, mems_indexs,
530
- limited_spatial_channel_mem,
531
- text_len, frame_len)
532
- if guider_seq is not None:
533
- guider_mems, guider_mems_indexs = my_update_mems(
534
- [guider_mem_kv0, guider_mem_kv1], guider_mems_buffers,
535
- guider_mems_indexs, limited_spatial_channel_mem,
536
- guider_text_len, frame_len)
537
-
538
- counter += 1
539
- index = counter
540
-
541
- logits = logits[:, -1].expand(batch_size,
542
- -1) # [batch size, vocab size]
543
- tokens = tokens.expand(batch_size, -1)
544
- if guider_seq is not None:
545
- guider_logits = guider_logits[:, -1].expand(batch_size, -1)
546
- guider_tokens = guider_tokens.expand(batch_size, -1)
547
-
548
- if seq[-1][counter].item() < 0:
549
- # sampling
550
- guided_logits = guider_logits + (
551
- logits - guider_logits
552
- ) * guidance_alpha if guider_seq is not None else logits
553
- if mode_stage1 and counter < text_len + 400:
554
- tokens, mems = strategy.forward(guided_logits, tokens, mems)
555
- else:
556
- tokens, mems = strategy2.forward(guided_logits, tokens, mems)
557
- if guider_seq is not None:
558
- guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]),
559
- dim=1)
560
-
561
- if seq[0][counter].item() >= 0:
562
- for si in range(seq.shape[0]):
563
- if seq[si][counter].item() >= 0:
564
- tokens[si, -1] = seq[si, counter]
565
- if guider_seq is not None:
566
- guider_tokens[si,
567
- -1] = guider_seq[si, counter -
568
- guider_index_delta]
569
-
570
- else:
571
- tokens = torch.cat(
572
- (tokens, seq[:, counter:counter + 1].clone().expand(
573
- tokens.shape[0], 1).to(device=tokens.device,
574
- dtype=tokens.dtype)),
575
- dim=1)
576
- if guider_seq is not None:
577
- guider_tokens = torch.cat(
578
- (guider_tokens,
579
- guider_seq[:, counter - guider_index_delta:counter + 1 -
580
- guider_index_delta].clone().expand(
581
- guider_tokens.shape[0], 1).to(
582
- device=guider_tokens.device,
583
- dtype=guider_tokens.dtype)),
584
- dim=1)
585
-
586
- input_tokens = tokens.clone()
587
- if guider_seq is not None:
588
- guider_input_tokens = guider_tokens.clone()
589
- if (index - text_len - 1) // 400 < (input_tokens.shape[-1] - text_len -
590
- 1) // 400:
591
- boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
592
- while boi_idx < input_tokens.shape[-1]:
593
- input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
594
- if guider_seq is not None:
595
- guider_input_tokens[:, boi_idx -
596
- guider_index_delta] = tokenizer[
597
- '<start_of_image>']
598
- boi_idx += 400
599
-
600
- if strategy.is_done:
601
- break
602
- return strategy.finalize(tokens, mems)
603
-
604
-
605
- class InferenceModel_Sequential(CogVideoCacheModel):
606
- def __init__(self, args, transformer=None, parallel_output=True):
607
- super().__init__(args,
608
- transformer=transformer,
609
- parallel_output=parallel_output,
610
- window_size=-1,
611
- cogvideo_stage=1)
612
-
613
- # TODO: check it
614
-
615
- def final_forward(self, logits, **kwargs):
616
- logits_parallel = logits
617
- logits_parallel = torch.nn.functional.linear(
618
- logits_parallel.float(),
619
- self.transformer.word_embeddings.weight[:20000].float())
620
- return logits_parallel
621
-
622
-
623
- class InferenceModel_Interpolate(CogVideoCacheModel):
624
- def __init__(self, args, transformer=None, parallel_output=True):
625
- super().__init__(args,
626
- transformer=transformer,
627
- parallel_output=parallel_output,
628
- window_size=10,
629
- cogvideo_stage=2)
630
-
631
- # TODO: check it
632
-
633
- def final_forward(self, logits, **kwargs):
634
- logits_parallel = logits
635
- logits_parallel = torch.nn.functional.linear(
636
- logits_parallel.float(),
637
- self.transformer.word_embeddings.weight[:20000].float())
638
- return logits_parallel
639
-
640
-
641
- def get_default_args() -> argparse.Namespace:
642
- known = argparse.Namespace(generate_frame_num=5,
643
- coglm_temperature2=0.89,
644
- use_guidance_stage1=True,
645
- use_guidance_stage2=False,
646
- guidance_alpha=3.0,
647
- stage_1=True,
648
- stage_2=False,
649
- both_stages=False,
650
- parallel_size=1,
651
- stage1_max_inference_batch_size=-1,
652
- multi_gpu=False,
653
- layout='64, 464, 2064',
654
- window_size=10,
655
- additional_seqlen=2000,
656
- cogvideo_stage=1)
657
-
658
- args_list = [
659
- '--tokenizer-type',
660
- 'fake',
661
- '--mode',
662
- 'inference',
663
- '--distributed-backend',
664
- 'nccl',
665
- '--fp16',
666
- '--model-parallel-size',
667
- '1',
668
- '--temperature',
669
- '1.05',
670
- '--top_k',
671
- '12',
672
- '--sandwich-ln',
673
- '--seed',
674
- '1234',
675
- '--num-workers',
676
- '0',
677
- '--batch-size',
678
- '1',
679
- '--max-inference-batch-size',
680
- '8',
681
- ]
682
- args = get_args(args_list)
683
- args = argparse.Namespace(**vars(args), **vars(known))
684
- args.layout = [int(x) for x in args.layout.split(',')]
685
- args.do_train = False
686
- return args
687
-
688
-
689
- class Model:
690
- def __init__(self, only_first_stage: bool = False):
691
- self.args = get_default_args()
692
- if only_first_stage:
693
- self.args.stage_1 = True
694
- self.args.both_stages = False
695
- else:
696
- self.args.stage_1 = False
697
- self.args.both_stages = True
698
-
699
- self.tokenizer = self.load_tokenizer()
700
-
701
- self.model_stage1, self.args = self.load_model_stage1()
702
- self.model_stage2, self.args = self.load_model_stage2()
703
-
704
- self.strategy_cogview2, self.strategy_cogvideo = self.load_strategies()
705
- self.dsr = self.load_dsr()
706
-
707
- self.device = torch.device(self.args.device)
708
-
709
- def load_tokenizer(self) -> IceTokenizer:
710
- logger.info('--- load_tokenizer ---')
711
- start = time.perf_counter()
712
-
713
- tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
714
- tokenizer.add_special_tokens(
715
- ['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
716
-
717
- elapsed = time.perf_counter() - start
718
- logger.info(f'--- done ({elapsed=:.3f}) ---')
719
- return tokenizer
720
-
721
- def load_model_stage1(
722
- self) -> tuple[CogVideoCacheModel, argparse.Namespace]:
723
- logger.info('--- load_model_stage1 ---')
724
- start = time.perf_counter()
725
-
726
- args = self.args
727
- model_stage1, args = InferenceModel_Sequential.from_pretrained(
728
- args, 'cogvideo-stage1')
729
- model_stage1.eval()
730
- if args.both_stages:
731
- model_stage1 = model_stage1.cpu()
732
-
733
- elapsed = time.perf_counter() - start
734
- logger.info(f'--- done ({elapsed=:.3f}) ---')
735
- return model_stage1, args
736
-
737
- def load_model_stage2(
738
- self) -> tuple[CogVideoCacheModel | None, argparse.Namespace]:
739
- logger.info('--- load_model_stage2 ---')
740
- start = time.perf_counter()
741
-
742
- args = self.args
743
- if args.both_stages:
744
- model_stage2, args = InferenceModel_Interpolate.from_pretrained(
745
- args, 'cogvideo-stage2')
746
- model_stage2.eval()
747
- if args.both_stages:
748
- model_stage2 = model_stage2.cpu()
749
- else:
750
- model_stage2 = None
751
-
752
- elapsed = time.perf_counter() - start
753
- logger.info(f'--- done ({elapsed=:.3f}) ---')
754
- return model_stage2, args
755
-
756
- def load_strategies(self) -> tuple[CoglmStrategy, CoglmStrategy]:
757
- logger.info('--- load_strategies ---')
758
- start = time.perf_counter()
759
-
760
- invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
761
- strategy_cogview2 = CoglmStrategy(invalid_slices,
762
- temperature=1.0,
763
- top_k=16)
764
- strategy_cogvideo = CoglmStrategy(
765
- invalid_slices,
766
- temperature=self.args.temperature,
767
- top_k=self.args.top_k,
768
- temperature2=self.args.coglm_temperature2)
769
-
770
- elapsed = time.perf_counter() - start
771
- logger.info(f'--- done ({elapsed=:.3f}) ---')
772
- return strategy_cogview2, strategy_cogvideo
773
-
774
- def load_dsr(self) -> DirectSuperResolution | None:
775
- logger.info('--- load_dsr ---')
776
- start = time.perf_counter()
777
-
778
- if self.args.both_stages:
779
- path = auto_create('cogview2-dsr', path=None)
780
- dsr = DirectSuperResolution(self.args,
781
- path,
782
- max_bz=12,
783
- onCUDA=False)
784
- else:
785
- dsr = None
786
-
787
- elapsed = time.perf_counter() - start
788
- logger.info(f'--- done ({elapsed=:.3f}) ---')
789
- return dsr
790
-
791
- @torch.inference_mode()
792
- def process_stage1(self,
793
- model,
794
- seq_text,
795
- duration,
796
- video_raw_text=None,
797
- video_guidance_text='视频',
798
- image_text_suffix='',
799
- batch_size=1,
800
- image_prompt=None):
801
- process_start_time = time.perf_counter()
802
-
803
- generate_frame_num = self.args.generate_frame_num
804
- tokenizer = self.tokenizer
805
- use_guide = self.args.use_guidance_stage1
806
-
807
- if next(model.parameters()).device != self.device:
808
- move_start_time = time.perf_counter()
809
- logger.debug('moving stage 1 model to cuda')
810
-
811
- model = model.to(self.device)
812
-
813
- elapsed = time.perf_counter() - move_start_time
814
- logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
815
-
816
- if video_raw_text is None:
817
- video_raw_text = seq_text
818
- mbz = self.args.stage1_max_inference_batch_size if self.args.stage1_max_inference_batch_size > 0 else self.args.max_inference_batch_size
819
- assert batch_size < mbz or batch_size % mbz == 0
820
- frame_len = 400
821
-
822
- # generate the first frame:
823
- enc_text = tokenizer.encode(seq_text + image_text_suffix)
824
- seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1] * 400
825
- logger.info(
826
- f'[Generating First Frame with CogView2] Raw text: {tokenizer.decode(enc_text):s}'
827
- )
828
- text_len_1st = len(seq_1st) - frame_len * 1 - 1
829
-
830
- seq_1st = torch.tensor(seq_1st, dtype=torch.long,
831
- device=self.device).unsqueeze(0)
832
- if image_prompt is None:
833
- output_list_1st = []
834
- for tim in range(max(batch_size // mbz, 1)):
835
- start_time = time.perf_counter()
836
- output_list_1st.append(
837
- my_filling_sequence(
838
- model,
839
- tokenizer,
840
- self.args,
841
- seq_1st.clone(),
842
- batch_size=min(batch_size, mbz),
843
- get_masks_and_position_ids=
844
- get_masks_and_position_ids_stage1,
845
- text_len=text_len_1st,
846
- frame_len=frame_len,
847
- strategy=self.strategy_cogview2,
848
- strategy2=self.strategy_cogvideo,
849
- log_text_attention_weights=1.4,
850
- enforce_no_swin=True,
851
- mode_stage1=True,
852
- )[0])
853
- elapsed = time.perf_counter() - start_time
854
- logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
855
- output_tokens_1st = torch.cat(output_list_1st, dim=0)
856
- given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
857
- 401].unsqueeze(
858
- 1
859
- ) # given_tokens.shape: [bs, frame_num, 400]
860
- else:
861
- given_tokens = tokenizer.encode(image_path=image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
862
-
863
- # generate subsequent frames:
864
- total_frames = generate_frame_num
865
- enc_duration = tokenizer.encode(f'{float(duration)}秒')
866
- if use_guide:
867
- video_raw_text = video_raw_text + ' 视频'
868
- enc_text_video = tokenizer.encode(video_raw_text)
869
- seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [
870
- tokenizer['<start_of_image>']
871
- ] + [-1] * 400 * generate_frame_num
872
- guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(
873
- video_guidance_text) + [tokenizer['<start_of_image>']
874
- ] + [-1] * 400 * generate_frame_num
875
- logger.info(
876
- f'[Stage1: Generating Subsequent Frames, Frame Rate {4/duration:.1f}] raw text: {tokenizer.decode(enc_text_video):s}'
877
- )
878
-
879
- text_len = len(seq) - frame_len * generate_frame_num - 1
880
- guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
881
- seq = torch.tensor(seq, dtype=torch.long,
882
- device=self.device).unsqueeze(0).repeat(
883
- batch_size, 1)
884
- guider_seq = torch.tensor(guider_seq,
885
- dtype=torch.long,
886
- device=self.device).unsqueeze(0).repeat(
887
- batch_size, 1)
888
-
889
- for given_frame_id in range(given_tokens.shape[1]):
890
- seq[:, text_len + 1 + given_frame_id * 400:text_len + 1 +
891
- (given_frame_id + 1) * 400] = given_tokens[:, given_frame_id]
892
- guider_seq[:, guider_text_len + 1 +
893
- given_frame_id * 400:guider_text_len + 1 +
894
- (given_frame_id + 1) *
895
- 400] = given_tokens[:, given_frame_id]
896
- output_list = []
897
-
898
- if use_guide:
899
- video_log_text_attention_weights = 0
900
- else:
901
- guider_seq = None
902
- video_log_text_attention_weights = 1.4
903
-
904
- for tim in range(max(batch_size // mbz, 1)):
905
- input_seq = seq[:min(batch_size, mbz)].clone(
906
- ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
907
- guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone()
908
- if tim == 0 else guider_seq[mbz * tim:mbz *
909
- (tim + 1)].clone()
910
- ) if guider_seq is not None else None
911
- output_list.append(
912
- my_filling_sequence(
913
- model,
914
- tokenizer,
915
- self.args,
916
- input_seq,
917
- batch_size=min(batch_size, mbz),
918
- get_masks_and_position_ids=
919
- get_masks_and_position_ids_stage1,
920
- text_len=text_len,
921
- frame_len=frame_len,
922
- strategy=self.strategy_cogview2,
923
- strategy2=self.strategy_cogvideo,
924
- log_text_attention_weights=video_log_text_attention_weights,
925
- guider_seq=guider_seq2,
926
- guider_text_len=guider_text_len,
927
- guidance_alpha=self.args.guidance_alpha,
928
- limited_spatial_channel_mem=True,
929
- mode_stage1=True,
930
- )[0])
931
-
932
- output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len:]
933
-
934
- if self.args.both_stages:
935
- move_start_time = time.perf_counter()
936
- logger.debug('moving stage 1 model to cpu')
937
- model = model.cpu()
938
- torch.cuda.empty_cache()
939
- elapsed = time.perf_counter() - move_start_time
940
- logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
941
-
942
- # decoding
943
- res = []
944
- for seq in output_tokens:
945
- decoded_imgs = [
946
- self.postprocess(
947
- torch.nn.functional.interpolate(tokenizer.decode(
948
- image_ids=seq.tolist()[i * 400:(i + 1) * 400]),
949
- size=(480, 480))[0])
950
- for i in range(total_frames)
951
- ]
952
- res.append(decoded_imgs) # only the last image (target)
953
-
954
- assert len(res) == batch_size
955
- tokens = output_tokens[:, :+total_frames * 400].reshape(
956
- -1, total_frames, 400).cpu()
957
-
958
- elapsed = time.perf_counter() - process_start_time
959
- logger.info(f'--- done ({elapsed=:.3f}) ---')
960
- return tokens, res[0]
961
-
962
- @torch.inference_mode()
963
- def process_stage2(self,
964
- model,
965
- seq_text,
966
- duration,
967
- parent_given_tokens,
968
- video_raw_text=None,
969
- video_guidance_text='视频',
970
- gpu_rank=0,
971
- gpu_parallel_size=1):
972
- process_start_time = time.perf_counter()
973
-
974
- generate_frame_num = self.args.generate_frame_num
975
- tokenizer = self.tokenizer
976
- use_guidance = self.args.use_guidance_stage2
977
-
978
- stage2_start_time = time.perf_counter()
979
-
980
- if next(model.parameters()).device != self.device:
981
- move_start_time = time.perf_counter()
982
- logger.debug('moving stage-2 model to cuda')
983
-
984
- model = model.to(self.device)
985
-
986
- elapsed = time.perf_counter() - move_start_time
987
- logger.debug(f'moving in stage-2 model takes time: {elapsed:.2f}')
988
-
989
- try:
990
- sample_num_allgpu = parent_given_tokens.shape[0]
991
- sample_num = sample_num_allgpu // gpu_parallel_size
992
- assert sample_num * gpu_parallel_size == sample_num_allgpu
993
- parent_given_tokens = parent_given_tokens[gpu_rank *
994
- sample_num:(gpu_rank +
995
- 1) *
996
- sample_num]
997
- except:
998
- logger.critical('No frame_tokens found in interpolation, skip')
999
- return False, []
1000
-
1001
- # CogVideo Stage2 Generation
1002
- while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
1003
- parent_given_tokens_num = parent_given_tokens.shape[1]
1004
- generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
1005
- generate_batchsize_total = generate_batchsize_persample * sample_num
1006
- total_frames = generate_frame_num
1007
- frame_len = 400
1008
- enc_text = tokenizer.encode(seq_text)
1009
- enc_duration = tokenizer.encode(str(float(duration)) + '秒')
1010
- seq = enc_duration + [tokenizer['<n>']] + enc_text + [
1011
- tokenizer['<start_of_image>']
1012
- ] + [-1] * 400 * generate_frame_num
1013
- text_len = len(seq) - frame_len * generate_frame_num - 1
1014
-
1015
- logger.info(
1016
- f'[Stage2: Generating Frames, Frame Rate {int(4/duration):d}] raw text: {tokenizer.decode(enc_text):s}'
1017
- )
1018
-
1019
- # generation
1020
- seq = torch.tensor(seq, dtype=torch.long,
1021
- device=self.device).unsqueeze(0).repeat(
1022
- generate_batchsize_total, 1)
1023
- for sample_i in range(sample_num):
1024
- for i in range(generate_batchsize_persample):
1025
- seq[sample_i * generate_batchsize_persample +
1026
- i][text_len + 1:text_len + 1 +
1027
- 400] = parent_given_tokens[sample_i][2 * i]
1028
- seq[sample_i * generate_batchsize_persample +
1029
- i][text_len + 1 + 400:text_len + 1 +
1030
- 800] = parent_given_tokens[sample_i][2 * i + 1]
1031
- seq[sample_i * generate_batchsize_persample +
1032
- i][text_len + 1 + 800:text_len + 1 +
1033
- 1200] = parent_given_tokens[sample_i][2 * i + 2]
1034
-
1035
- if use_guidance:
1036
- guider_seq = enc_duration + [
1037
- tokenizer['<n>']
1038
- ] + tokenizer.encode(video_guidance_text) + [
1039
- tokenizer['<start_of_image>']
1040
- ] + [-1] * 400 * generate_frame_num
1041
- guider_text_len = len(
1042
- guider_seq) - frame_len * generate_frame_num - 1
1043
- guider_seq = torch.tensor(
1044
- guider_seq, dtype=torch.long,
1045
- device=self.device).unsqueeze(0).repeat(
1046
- generate_batchsize_total, 1)
1047
- for sample_i in range(sample_num):
1048
- for i in range(generate_batchsize_persample):
1049
- guider_seq[sample_i * generate_batchsize_persample +
1050
- i][text_len + 1:text_len + 1 +
1051
- 400] = parent_given_tokens[sample_i][2 *
1052
- i]
1053
- guider_seq[sample_i * generate_batchsize_persample +
1054
- i][text_len + 1 + 400:text_len + 1 +
1055
- 800] = parent_given_tokens[sample_i][2 *
1056
- i +
1057
- 1]
1058
- guider_seq[sample_i * generate_batchsize_persample +
1059
- i][text_len + 1 + 800:text_len + 1 +
1060
- 1200] = parent_given_tokens[sample_i][2 *
1061
- i +
1062
- 2]
1063
- video_log_text_attention_weights = 0
1064
- else:
1065
- guider_seq = None
1066
- guider_text_len = 0
1067
- video_log_text_attention_weights = 1.4
1068
-
1069
- mbz = self.args.max_inference_batch_size
1070
-
1071
- assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
1072
- output_list = []
1073
- start_time = time.perf_counter()
1074
- for tim in range(max(generate_batchsize_total // mbz, 1)):
1075
- input_seq = seq[:min(generate_batchsize_total, mbz)].clone(
1076
- ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
1077
- guider_seq2 = (
1078
- guider_seq[:min(generate_batchsize_total, mbz)].clone()
1079
- if tim == 0 else guider_seq[mbz * tim:mbz *
1080
- (tim + 1)].clone()
1081
- ) if guider_seq is not None else None
1082
- output_list.append(
1083
- my_filling_sequence(
1084
- model,
1085
- tokenizer,
1086
- self.args,
1087
- input_seq,
1088
- batch_size=min(generate_batchsize_total, mbz),
1089
- get_masks_and_position_ids=
1090
- get_masks_and_position_ids_stage2,
1091
- text_len=text_len,
1092
- frame_len=frame_len,
1093
- strategy=self.strategy_cogview2,
1094
- strategy2=self.strategy_cogvideo,
1095
- log_text_attention_weights=
1096
- video_log_text_attention_weights,
1097
- mode_stage1=False,
1098
- guider_seq=guider_seq2,
1099
- guider_text_len=guider_text_len,
1100
- guidance_alpha=self.args.guidance_alpha,
1101
- limited_spatial_channel_mem=True,
1102
- )[0])
1103
- elapsed = time.perf_counter() - start_time
1104
- logger.info(f'Duration {duration:.2f}, Elapsed: {elapsed:.2f}\n')
1105
-
1106
- output_tokens = torch.cat(output_list, dim=0)
1107
- output_tokens = output_tokens[:, text_len + 1:text_len + 1 +
1108
- (total_frames) * 400].reshape(
1109
- sample_num, -1,
1110
- 400 * total_frames)
1111
- output_tokens_merge = torch.cat(
1112
- (output_tokens[:, :, :1 * 400], output_tokens[:, :,
1113
- 400 * 3:4 * 400],
1114
- output_tokens[:, :, 400 * 1:2 * 400],
1115
- output_tokens[:, :, 400 * 4:(total_frames) * 400]),
1116
- dim=2).reshape(sample_num, -1, 400)
1117
-
1118
- output_tokens_merge = torch.cat(
1119
- (output_tokens_merge, output_tokens[:, -1:, 400 * 2:3 * 400]),
1120
- dim=1)
1121
- duration /= 2
1122
- parent_given_tokens = output_tokens_merge
1123
-
1124
- if self.args.both_stages:
1125
- move_start_time = time.perf_counter()
1126
- logger.debug('moving stage 2 model to cpu')
1127
- model = model.cpu()
1128
- torch.cuda.empty_cache()
1129
- elapsed = time.perf_counter() - move_start_time
1130
- logger.debug(f'moving out model2 takes time: {elapsed:.2f}')
1131
-
1132
- elapsed = time.perf_counter() - stage2_start_time
1133
- logger.info(f'CogVideo Stage2 completed. Elapsed: {elapsed:.2f}\n')
1134
-
1135
- # direct super-resolution by CogView2
1136
- logger.info('[Direct super-resolution]')
1137
- dsr_start_time = time.perf_counter()
1138
-
1139
- enc_text = tokenizer.encode(seq_text)
1140
- frame_num_per_sample = parent_given_tokens.shape[1]
1141
- parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
1142
- text_seq = torch.tensor(enc_text, dtype=torch.long,
1143
- device=self.device).unsqueeze(0).repeat(
1144
- parent_given_tokens_2d.shape[0], 1)
1145
- sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
1146
-
1147
- decoded_sr_videos = []
1148
- for sample_i in range(sample_num):
1149
- decoded_sr_imgs = []
1150
- for frame_i in range(frame_num_per_sample):
1151
- decoded_sr_img = tokenizer.decode(
1152
- image_ids=sred_tokens[frame_i + sample_i *
1153
- frame_num_per_sample][-3600:])
1154
- decoded_sr_imgs.append(
1155
- self.postprocess(
1156
- torch.nn.functional.interpolate(decoded_sr_img,
1157
- size=(480, 480))[0]))
1158
- decoded_sr_videos.append(decoded_sr_imgs)
1159
-
1160
- elapsed = time.perf_counter() - dsr_start_time
1161
- logger.info(
1162
- f'Direct super-resolution completed. Elapsed: {elapsed:.2f}')
1163
-
1164
- elapsed = time.perf_counter() - process_start_time
1165
- logger.info(f'--- done ({elapsed=:.3f}) ---')
1166
- return True, decoded_sr_videos[0]
1167
-
1168
- @staticmethod
1169
- def postprocess(tensor: torch.Tensor) -> np.ndarray:
1170
- return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
1171
- 1, 2, 0).to(torch.uint8).numpy()
1172
-
1173
- def run(self, text: str, seed: int,
1174
- only_first_stage: bool,image_prompt: None) -> list[np.ndarray]:
1175
- logger.info('==================== run ====================')
1176
- start = time.perf_counter()
1177
-
1178
- set_random_seed(seed)
1179
- self.args.seed = seed
1180
-
1181
- if only_first_stage:
1182
- self.args.stage_1 = True
1183
- self.args.both_stages = False
1184
- else:
1185
- self.args.stage_1 = False
1186
- self.args.both_stages = True
1187
-
1188
- parent_given_tokens, res = self.process_stage1(
1189
- self.model_stage1,
1190
- text,
1191
- duration=4.0,
1192
- video_raw_text=text,
1193
- video_guidance_text='视频',
1194
- image_text_suffix=' 高清摄影',
1195
- batch_size=self.args.batch_size,
1196
- image_prompt=image_prompt)
1197
- if not only_first_stage:
1198
- _, res = self.process_stage2(
1199
- self.model_stage2,
1200
- text,
1201
- duration=2.0,
1202
- parent_given_tokens=parent_given_tokens,
1203
- video_raw_text=text + ' 视频',
1204
- video_guidance_text='视频',
1205
- gpu_rank=0,
1206
- gpu_parallel_size=1) # TODO: 修改
1207
-
1208
- elapsed = time.perf_counter() - start
1209
- logger.info(f'Elapsed: {elapsed:.3f}')
1210
- logger.info('==================== done ====================')
1211
- return res
1212
-
1213
-
1214
- class AppModel(Model):
1215
- def __init__(self, only_first_stage: bool):
1216
- super().__init__(only_first_stage)
1217
- self.translator = gr.Interface.load(
1218
- 'spaces/chinhon/translation_eng2ch')
1219
-
1220
- def to_video(self, frames: list[np.ndarray]) -> str:
1221
- out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
1222
- if self.args.stage_1:
1223
- fps = 4
1224
- else:
1225
- fps = 8
1226
- writer = iio.get_writer(out_file.name, fps=fps)
1227
- for frame in frames:
1228
- writer.append_data(frame)
1229
- writer.close()
1230
- return out_file.name
1231
-
1232
- def run_with_translation(
1233
- self, text: str, translate: bool, seed: int,
1234
- only_first_stage: bool,image_prompt: None) -> tuple[str | None, str | None]:
1235
-
1236
- logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=},{image_prompt=}')
1237
- if translate:
1238
- text = translated_text = self.translator(text)
1239
- else:
1240
- translated_text = None
1241
- frames = self.run(text, seed, only_first_stage,image_prompt)
1242
- video_path = self.to_video(frames)
1243
- return translated_text, video_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
patch DELETED
@@ -1,51 +0,0 @@
1
- diff --git a/coglm_strategy.py b/coglm_strategy.py
2
- index d485715..a9eab3b 100644
3
- --- a/coglm_strategy.py
4
- +++ b/coglm_strategy.py
5
- @@ -8,6 +8,7 @@
6
-
7
- # here put the import lib
8
- import os
9
- +import pathlib
10
- import sys
11
- import math
12
- import random
13
- @@ -58,7 +59,8 @@ class CoglmStrategy:
14
- self._is_done = False
15
- self.outlier_count_down = torch.zeros(16)
16
- self.vis_list = [[]for i in range(16)]
17
- - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
18
- + cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy'
19
- + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
20
- self.start_pos = -1
21
- self.white_cluster = []
22
- # self.fout = open('tmp.txt', 'w')
23
- @@ -98,4 +100,4 @@ class CoglmStrategy:
24
-
25
- def finalize(self, tokens, mems):
26
- self._is_done = False
27
- - return tokens, mems
28
-
29
- + return tokens, mems
30
- diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
31
- index 5b8dded..07e97fd 100644
32
- --- a/sr_pipeline/dsr_sampling.py
33
- +++ b/sr_pipeline/dsr_sampling.py
34
- @@ -8,6 +8,7 @@
35
-
36
- # here put the import lib
37
- import os
38
- +import pathlib
39
- import sys
40
- import math
41
- import random
42
- @@ -28,7 +29,8 @@ class IterativeEntfilterStrategy:
43
- self.invalid_slices = invalid_slices
44
- self.temperature = temperature
45
- self.topk = topk
46
- - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
47
- + cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy'
48
- + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
49
-
50
-
51
- def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,7 +1,4 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu113
2
- imageio==2.19.5
3
- imageio-ffmpeg==0.4.7
4
- numpy==1.22.4
5
- opencv-python-headless==4.6.0.66
6
- torch==1.12.0+cu113
7
- torchvision==0.13.0+cu113
 
1
+ gradio>=4.40.0
2
+ imageio-ffmpeg>=0.5.1
3
+
4
+
 
 
 
samples.txt DELETED
@@ -1,2 +0,0 @@
1
- 骑滑板的皮卡丘
2
- a cat playing chess
 
 
 
style.css DELETED
@@ -1,7 +0,0 @@
1
- h1 {
2
- text-align: center;
3
- }
4
- img#visitor-badge {
5
- display: block;
6
- margin: auto;
7
- }