pyx9913
commited on
Commit
•
4d32fc1
1
Parent(s):
93beeec
feat: 🎸 add paint model code
Browse files- README.md +25 -1
- config.json +26 -0
- configuration_cpmbee.py +132 -0
- feature_extractor/preprocessor_config.json +20 -0
- model_index.json +10 -9
- modeling_cpmbee.py +943 -0
- pipeline_stable_diffusion.py +723 -0
- scheduler/scheduler_config.json +14 -0
- tokenization_viscpmbee.py +1008 -0
- tokenizer_config.json +10 -0
- unet/config.json +45 -0
- vae/config.json +29 -0
- vocab.txt +0 -0
README.md
CHANGED
@@ -27,8 +27,32 @@ language:
|
|
27 |
|
28 |
Similar to `VisCPM-Chat`, we found that due to the bilingual capability of `CPM-Bee`, `VisCPM-Paint` can achieve good Chinese text-to-image generation by training only on English text-image pairs, surpassing the performance of Chinese open-source models. By incorporating an additional 20M cleaned native Chinese text-image pairs and 120M translated text-image pairs in Chinese, the model's Chinese text-to-image generation ability can be further improved. We sample 30,000 images from the standard image generation test set MSCOCO and calculated commonly used evaluation metrics FID (Fréchet Inception Distance) to assess the quality of generated images. Similarly, we provide two versions of the model, namely `VisCPM-Paint-balance` and `VisCPM-Paint-zhplus`. The former has a balanced ability in both English and Chinese, while the latter emphasizes Chinese proficiency. `VisCPM-Paint-balance` is trained only using English text-image pairs, while `VisCPM-Paint-zhplus` incorporates an additional 20M native Chinese text-image pairs and 120M translated text-image pairs in Chinese based on `VisCPM-Paint-balance`.
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
## 📝 License
|
31 |
|
32 |
VisCPM is governed by the [GML License](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E9%9D%9E%E5%95%86%E4%B8%9A%E5%8C%96.md), and permits individual and research usages. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to negotiate commercial licensing.
|
33 |
|
34 |
-
The CPM-Bee base, governed by the [General Model License (GML)](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E5%95%86%E4%B8%9A%E6%8E%88%E6%9D%83.md), permits commercial usage. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to obtain the certificate of authorization.
|
|
|
27 |
|
28 |
Similar to `VisCPM-Chat`, we found that due to the bilingual capability of `CPM-Bee`, `VisCPM-Paint` can achieve good Chinese text-to-image generation by training only on English text-image pairs, surpassing the performance of Chinese open-source models. By incorporating an additional 20M cleaned native Chinese text-image pairs and 120M translated text-image pairs in Chinese, the model's Chinese text-to-image generation ability can be further improved. We sample 30,000 images from the standard image generation test set MSCOCO and calculated commonly used evaluation metrics FID (Fréchet Inception Distance) to assess the quality of generated images. Similarly, we provide two versions of the model, namely `VisCPM-Paint-balance` and `VisCPM-Paint-zhplus`. The former has a balanced ability in both English and Chinese, while the latter emphasizes Chinese proficiency. `VisCPM-Paint-balance` is trained only using English text-image pairs, while `VisCPM-Paint-zhplus` incorporates an additional 20M native Chinese text-image pairs and 120M translated text-image pairs in Chinese based on `VisCPM-Paint-balance`.
|
29 |
|
30 |
+
|
31 |
+
## How to Use
|
32 |
+
|
33 |
+
```python
|
34 |
+
#!/usr/bin/env python
|
35 |
+
# encoding: utf-8
|
36 |
+
from diffusers import DiffusionPipeline
|
37 |
+
from transformers import AutoModel
|
38 |
+
from transformers import AutoTokenizer
|
39 |
+
|
40 |
+
|
41 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/VisCPM-Paint', trust_remote_code=True)
|
42 |
+
text_encoder = AutoModel.from_pretrained('openbmb/VisCPM-Paint', trust_remote_code=True)
|
43 |
+
print('load pipeline')
|
44 |
+
pipeline = DiffusionPipeline.from_pretrained('openbmb/VisCPM-Paint', custom_pipeline="pipeline_stable_diffusion.py", text_encoder=text_encoder, tokenizer=tokenizer)
|
45 |
+
|
46 |
+
pipeline = pipeline.to('cuda')
|
47 |
+
|
48 |
+
prompt = "a photo of an astronaut riding a horse on mars"
|
49 |
+
image = pipeline(prompt).images[0]
|
50 |
+
|
51 |
+
image.save("astronaut_rides_horse.png")
|
52 |
+
```
|
53 |
+
|
54 |
## 📝 License
|
55 |
|
56 |
VisCPM is governed by the [GML License](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E9%9D%9E%E5%95%86%E4%B8%9A%E5%8C%96.md), and permits individual and research usages. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to negotiate commercial licensing.
|
57 |
|
58 |
+
The CPM-Bee base, governed by the [General Model License (GML)](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E5%95%86%E4%B8%9A%E6%8E%88%E6%9D%83.md), permits commercial usage. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to obtain the certificate of authorization.
|
config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"_name_or_path": "openbmb/cpm-bee-10b",
|
4 |
+
"architectures": [
|
5 |
+
"CpmBeeForWithTransform"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_cpmbee.CpmBeeConfig",
|
9 |
+
"AutoModel": "modeling_cpmbee.CpmBeeWithTransform",
|
10 |
+
"AutoTokenizer": "tokenization_viscpmbee.VisCpmBeeTokenizer"
|
11 |
+
},
|
12 |
+
"vocab_size": 86583,
|
13 |
+
"hidden_size": 4096,
|
14 |
+
"dim_ff" : 10240,
|
15 |
+
"num_hidden_layers" : 48,
|
16 |
+
"num_attention_heads": 32,
|
17 |
+
"dim_head" : 128,
|
18 |
+
"dropout_p" : 0.0,
|
19 |
+
"position_bias_num_buckets" : 256,
|
20 |
+
"position_bias_num_segment_buckets": 256,
|
21 |
+
"position_bias_max_distance" : 2048,
|
22 |
+
"eps" : 1e-6,
|
23 |
+
"half" : false,
|
24 |
+
"model_type": "viscpmbee",
|
25 |
+
"unet_cross_attention_dim": 1024
|
26 |
+
}
|
configuration_cpmbee.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" CpmBee model configuration"""
|
16 |
+
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
26 |
+
"openbmb/cpm-bee-10b": "https://huggingface.co/openbmb/cpm-bee-10b/resolve/main/config.json",
|
27 |
+
"openbmb/cpm-bee-5b": "https://huggingface.co/openbmb/cpm-bee-5b/resolve/main/config.json",
|
28 |
+
"openbmb/cpm-bee-2b": "https://huggingface.co/openbmb/cpm-bee-2b/resolve/main/config.json",
|
29 |
+
"openbmb/cpm-bee-1b": "https://huggingface.co/openbmb/cpm-bee-1b/resolve/main/config.json",
|
30 |
+
# See all CpmBee models at https://huggingface.co/models?filter=cpmbee
|
31 |
+
}
|
32 |
+
|
33 |
+
|
34 |
+
class CpmBeeConfig(PretrainedConfig):
|
35 |
+
r"""
|
36 |
+
This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instbeeiate an
|
37 |
+
CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
38 |
+
with the defaults will yield a similar configuration to that of the CPMBee
|
39 |
+
[openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.
|
40 |
+
|
41 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
42 |
+
documentation from [`PretrainedConfig`] for more information.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
vocab_size (`int`, *optional*, defaults to 30720):
|
46 |
+
Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
|
47 |
+
`input` passed when calling [`CpmBeeModel`].
|
48 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
49 |
+
Dimension of the encoder layers.
|
50 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
51 |
+
Number of attention heads in the Transformer encoder.
|
52 |
+
dim_head (`int`, *optional*, defaults to 128):
|
53 |
+
Dimension of attention heads for each attention layer in the Transformer encoder.
|
54 |
+
dim_ff (`int`, *optional*, defaults to 10240):
|
55 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
56 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
57 |
+
Number of layers of the Transformer encoder.
|
58 |
+
dropout_p (`float`, *optional*, defaults to 0.1):
|
59 |
+
The dropout probabilitiy for all fully connected layers in the embeddings, encoder.
|
60 |
+
position_bias_num_buckets (`int`, *optional*, defaults to 512):
|
61 |
+
The number of position_bias buckets.
|
62 |
+
position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
|
63 |
+
The number of segment buckets.
|
64 |
+
position_bias_max_distance (`int`, *optional*, defaults to 2048):
|
65 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
66 |
+
just in case (e.g., 512 or 1024 or 2048).
|
67 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
68 |
+
The epsilon used by the layer normalization layers.
|
69 |
+
init_std (`float`, *optional*, defaults to 1.0):
|
70 |
+
Initialize parameters with std = init_std.
|
71 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
72 |
+
Whether to use cache.
|
73 |
+
distance_scale (`float` or `int`, *optional*, defaults to 16):
|
74 |
+
Scale the rotary embedding.
|
75 |
+
mask_modules (`list` or `tuple`, *optional*, defaults to None):
|
76 |
+
Decides which feedforward block or attention block is pruned.
|
77 |
+
half (`bool`, *optional*, defaults to `False`):
|
78 |
+
Decides the model parameters are half-precision or not.
|
79 |
+
|
80 |
+
Example:
|
81 |
+
|
82 |
+
```python
|
83 |
+
>>> from transformers import CpmBeeModel, CpmBeeConfig
|
84 |
+
|
85 |
+
>>> # Initializing a CPMBee cpm-bee-10b style configuration
|
86 |
+
>>> configuration = CpmBeeConfig()
|
87 |
+
|
88 |
+
>>> # Initializing a model from the cpm-bee-10b style configuration
|
89 |
+
>>> model = CpmBeeModel(configuration)
|
90 |
+
|
91 |
+
>>> # Accessing the model configuration
|
92 |
+
>>> configuration = model.config
|
93 |
+
```"""
|
94 |
+
model_type = "cpmbee"
|
95 |
+
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
vocab_size: int = 30720,
|
99 |
+
hidden_size: int = 4096,
|
100 |
+
num_attention_heads: int = 64,
|
101 |
+
dim_head: int = 64,
|
102 |
+
dim_ff: int = 10240,
|
103 |
+
num_hidden_layers: int = 32,
|
104 |
+
dropout_p: int = 0.0,
|
105 |
+
position_bias_num_buckets: int = 256,
|
106 |
+
position_bias_num_segment_buckets: int = 32,
|
107 |
+
position_bias_max_distance: int = 2048,
|
108 |
+
eps: int = 1e-6,
|
109 |
+
init_std: float = 1.0,
|
110 |
+
use_cache: bool = True,
|
111 |
+
distance_scale: Union[int, float] = 16,
|
112 |
+
mask_modules: Optional[Union[List, Tuple]] = None,
|
113 |
+
half: bool = False,
|
114 |
+
**kwargs,
|
115 |
+
):
|
116 |
+
super().__init__(**kwargs)
|
117 |
+
self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
|
118 |
+
self.hidden_size = hidden_size
|
119 |
+
self.num_attention_heads = num_attention_heads
|
120 |
+
self.dim_head = dim_head
|
121 |
+
self.dim_ff = dim_ff
|
122 |
+
self.num_hidden_layers = num_hidden_layers
|
123 |
+
self.position_bias_num_buckets = position_bias_num_buckets
|
124 |
+
self.position_bias_max_distance = position_bias_max_distance
|
125 |
+
self.dropout_p = dropout_p
|
126 |
+
self.eps = eps
|
127 |
+
self.use_cache = use_cache
|
128 |
+
self.vocab_size = vocab_size
|
129 |
+
self.init_std = init_std
|
130 |
+
self.distance_scale = distance_scale
|
131 |
+
self.half = half
|
132 |
+
self.mask_modules = mask_modules
|
feature_extractor/preprocessor_config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"crop_size": 224,
|
3 |
+
"do_center_crop": true,
|
4 |
+
"do_convert_rgb": true,
|
5 |
+
"do_normalize": true,
|
6 |
+
"do_resize": true,
|
7 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
8 |
+
"image_mean": [
|
9 |
+
0.48145466,
|
10 |
+
0.4578275,
|
11 |
+
0.40821073
|
12 |
+
],
|
13 |
+
"image_std": [
|
14 |
+
0.26862954,
|
15 |
+
0.26130258,
|
16 |
+
0.27577711
|
17 |
+
],
|
18 |
+
"resample": 3,
|
19 |
+
"size": 224
|
20 |
+
}
|
model_index.json
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
{
|
2 |
-
"_class_name": "
|
3 |
"_diffusers_version": "0.3.0",
|
4 |
"feature_extractor": [
|
5 |
"transformers",
|
6 |
"CLIPImageProcessor"
|
7 |
],
|
|
|
8 |
"safety_checker": [
|
9 |
-
|
10 |
-
|
11 |
],
|
12 |
"scheduler": [
|
13 |
"diffusers",
|
@@ -15,7 +16,11 @@
|
|
15 |
],
|
16 |
"text_encoder": [
|
17 |
"transformers",
|
18 |
-
"
|
|
|
|
|
|
|
|
|
19 |
],
|
20 |
"unet": [
|
21 |
"diffusers",
|
@@ -24,9 +29,5 @@
|
|
24 |
"vae": [
|
25 |
"diffusers",
|
26 |
"AutoencoderKL"
|
27 |
-
],
|
28 |
-
"text_safety_checker": [
|
29 |
-
"transformers",
|
30 |
-
"BertForSequenceClassification"
|
31 |
]
|
32 |
-
}
|
|
|
1 |
{
|
2 |
+
"_class_name": "VisCPMPaintBeePipeline",
|
3 |
"_diffusers_version": "0.3.0",
|
4 |
"feature_extractor": [
|
5 |
"transformers",
|
6 |
"CLIPImageProcessor"
|
7 |
],
|
8 |
+
"requires_safety_checker": false,
|
9 |
"safety_checker": [
|
10 |
+
null,
|
11 |
+
null
|
12 |
],
|
13 |
"scheduler": [
|
14 |
"diffusers",
|
|
|
16 |
],
|
17 |
"text_encoder": [
|
18 |
"transformers",
|
19 |
+
"PreTrainedModel"
|
20 |
+
],
|
21 |
+
"tokenizer": [
|
22 |
+
"transformers",
|
23 |
+
"PreTrainedTokenizer"
|
24 |
],
|
25 |
"unet": [
|
26 |
"diffusers",
|
|
|
29 |
"vae": [
|
30 |
"diffusers",
|
31 |
"AutoencoderKL"
|
|
|
|
|
|
|
|
|
32 |
]
|
33 |
+
}
|
modeling_cpmbee.py
ADDED
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch CpmBee model."""
|
16 |
+
import copy
|
17 |
+
import math
|
18 |
+
from collections import UserDict
|
19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
|
24 |
+
from transformers.generation.beam_search import BeamHypotheses, BeamSearchScorer
|
25 |
+
from transformers.generation.streamers import BaseStreamer
|
26 |
+
from transformers.generation.utils import (
|
27 |
+
GenerationConfig,
|
28 |
+
LogitsProcessorList,
|
29 |
+
StoppingCriteriaList,
|
30 |
+
dist,
|
31 |
+
inspect,
|
32 |
+
is_deepspeed_zero3_enabled,
|
33 |
+
warnings,
|
34 |
+
)
|
35 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
38 |
+
from .configuration_cpmbee import CpmBeeConfig
|
39 |
+
from .tokenization_viscpmbee import VisCpmBeeTokenizer
|
40 |
+
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__)
|
43 |
+
|
44 |
+
_CHECKPOINT_FOR_DOC = "openbmb/cpm-bee-10b"
|
45 |
+
_CONFIG_FOR_DOC = "CpmBeeConfig"
|
46 |
+
|
47 |
+
CPMBEE_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
48 |
+
"openbmb/cpm-bee-10b",
|
49 |
+
"openbmb/cpm-bee-5b",
|
50 |
+
"openbmb/cpm-bee-2b",
|
51 |
+
"openbmb/cpm-bee-1b",
|
52 |
+
# See all CPMBee models at https://huggingface.co/models?filter=cpmbee
|
53 |
+
]
|
54 |
+
|
55 |
+
|
56 |
+
class CpmBeeLinear(nn.Linear):
|
57 |
+
def __init__(self, dim_in, dim_out, dtype):
|
58 |
+
"""
|
59 |
+
Construct a linear for CPMBee. It contains a scale operation.
|
60 |
+
"""
|
61 |
+
super().__init__(dim_in, dim_out, bias=False)
|
62 |
+
self.dim_in = self.in_features = dim_in
|
63 |
+
self.dim_out = self.out_features = dim_out
|
64 |
+
|
65 |
+
self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
|
66 |
+
|
67 |
+
def forward(self, x: torch.Tensor):
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
x (`torch.Tensor` of shape `(batch, seq_len, dim_in)`): The input of linear layer
|
71 |
+
Returns:
|
72 |
+
`torch.Tensor` of shape `(batch, seq_len, dim_out)`: The output of the linear transform y.
|
73 |
+
"""
|
74 |
+
x = nn.functional.linear(x, self.weight)
|
75 |
+
x = x / math.sqrt(self.dim_in)
|
76 |
+
return x
|
77 |
+
|
78 |
+
|
79 |
+
class CpmBeeLayerNorm(nn.Module):
|
80 |
+
"""
|
81 |
+
We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details."
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, config: CpmBeeConfig):
|
85 |
+
super().__init__()
|
86 |
+
|
87 |
+
self.eps = config.eps
|
88 |
+
self.dim_norm = config.hidden_size
|
89 |
+
self.weight = nn.Parameter(torch.empty(config.hidden_size, dtype=config.torch_dtype))
|
90 |
+
|
91 |
+
def forward(self, hidden_states: torch.Tensor):
|
92 |
+
"""
|
93 |
+
Args:
|
94 |
+
hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
|
95 |
+
"""
|
96 |
+
if hidden_states.size(-1) != self.dim_norm:
|
97 |
+
raise AssertionError("hidden_states.size(-1) != self.dim_norm")
|
98 |
+
old_dtype = hidden_states.dtype
|
99 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
100 |
+
hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
|
101 |
+
return hidden_states
|
102 |
+
|
103 |
+
|
104 |
+
class CpmBeeAttention(nn.Module):
|
105 |
+
def __init__(self, config: CpmBeeConfig):
|
106 |
+
super().__init__()
|
107 |
+
self.dim_model = config.hidden_size
|
108 |
+
self.num_heads = config.num_attention_heads
|
109 |
+
self.dim_head = config.dim_head
|
110 |
+
|
111 |
+
self.project_q = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
|
112 |
+
self.project_k = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
|
113 |
+
self.project_v = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
|
114 |
+
|
115 |
+
self.attention_out = CpmBeeLinear(self.num_heads * self.dim_head, self.dim_model, dtype=config.torch_dtype)
|
116 |
+
|
117 |
+
self.softmax = torch.nn.Softmax(dim=-1)
|
118 |
+
|
119 |
+
if config.dropout_p is not None:
|
120 |
+
self.dropout = torch.nn.Dropout(p=config.dropout_p)
|
121 |
+
else:
|
122 |
+
self.dropout = None
|
123 |
+
|
124 |
+
def forward(
|
125 |
+
self,
|
126 |
+
hidden_q: torch.Tensor,
|
127 |
+
hidden_kv: torch.Tensor,
|
128 |
+
attention_mask: torch.BoolTensor,
|
129 |
+
position_bias: torch.Tensor,
|
130 |
+
output_attentions: Optional[bool] = False,
|
131 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
132 |
+
use_cache: Optional[bool] = None,
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
Args:
|
136 |
+
hidden_q (`torch.Tensor`):
|
137 |
+
Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
|
138 |
+
hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
|
139 |
+
Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
|
140 |
+
attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
141 |
+
Avoid invalid areas to participate in the calculation of self-attention.
|
142 |
+
position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
143 |
+
Provide positional information to self-attention block.
|
144 |
+
output_attentions (`bool`, *optional*):
|
145 |
+
Whether or not to return the attentions tensors of all attention layers.
|
146 |
+
past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
|
147 |
+
Cached past key and value projection states.
|
148 |
+
use_cache (`bool`, *optional*):
|
149 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
150 |
+
(see `past_key_values`).
|
151 |
+
"""
|
152 |
+
batch_size = hidden_q.size(0)
|
153 |
+
len_q = hidden_q.size(1)
|
154 |
+
len_k = hidden_kv.size(1)
|
155 |
+
|
156 |
+
query = self.project_q(hidden_q)
|
157 |
+
key = self.project_k(hidden_kv)
|
158 |
+
value = self.project_v(hidden_kv)
|
159 |
+
|
160 |
+
query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
161 |
+
key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
162 |
+
value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
163 |
+
|
164 |
+
if past_key_values is not None:
|
165 |
+
key = torch.cat([past_key_values[0], key], dim=-2)
|
166 |
+
value = torch.cat([past_key_values[1], value], dim=-2)
|
167 |
+
len_k = key.size(-2)
|
168 |
+
|
169 |
+
# (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
|
170 |
+
score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
|
171 |
+
score = score + position_bias
|
172 |
+
|
173 |
+
score = torch.masked_fill(
|
174 |
+
score,
|
175 |
+
attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
|
176 |
+
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
177 |
+
)
|
178 |
+
score = self.softmax(score)
|
179 |
+
|
180 |
+
score = torch.masked_fill(
|
181 |
+
score,
|
182 |
+
attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
|
183 |
+
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
184 |
+
)
|
185 |
+
if output_attentions:
|
186 |
+
attn_weights = score
|
187 |
+
else:
|
188 |
+
attn_weights = None
|
189 |
+
|
190 |
+
if self.dropout is not None:
|
191 |
+
score = self.dropout(score)
|
192 |
+
|
193 |
+
# (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
|
194 |
+
score = torch.matmul(score, value)
|
195 |
+
|
196 |
+
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
|
197 |
+
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
|
198 |
+
|
199 |
+
score = self.attention_out(score)
|
200 |
+
|
201 |
+
past_key_values = None
|
202 |
+
if use_cache:
|
203 |
+
past_key_values = (key, value)
|
204 |
+
|
205 |
+
return score, attn_weights, past_key_values
|
206 |
+
|
207 |
+
|
208 |
+
class CpmBeeSelfAttentionBlock(nn.Module):
|
209 |
+
def __init__(self, config: CpmBeeConfig):
|
210 |
+
super().__init__()
|
211 |
+
self.layernorm_before_attention = CpmBeeLayerNorm(config)
|
212 |
+
self.self_attention = CpmBeeAttention(config)
|
213 |
+
if config.dropout_p:
|
214 |
+
self.dropout = torch.nn.Dropout(config.dropout_p)
|
215 |
+
else:
|
216 |
+
self.dropout = None
|
217 |
+
|
218 |
+
def forward(
|
219 |
+
self,
|
220 |
+
hidden_states: torch.Tensor,
|
221 |
+
attention_mask: torch.Tensor,
|
222 |
+
position_bias: Optional[torch.Tensor] = None,
|
223 |
+
output_attentions: Optional[bool] = False,
|
224 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
225 |
+
use_cache: Optional[bool] = None,
|
226 |
+
):
|
227 |
+
"""
|
228 |
+
Args:
|
229 |
+
hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
|
230 |
+
Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
|
231 |
+
attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
232 |
+
Avoid invalid areas to participate in the calculation of self-attention.
|
233 |
+
position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
234 |
+
Provide positional information to self-attention block.
|
235 |
+
output_attentions (`bool`, *optional*):
|
236 |
+
Whether or not to return the attentions tensors of all attention layers.
|
237 |
+
past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
|
238 |
+
Cached past key and value projection states.
|
239 |
+
use_cache (`bool`, *optional*):
|
240 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
241 |
+
(see `past_key_values`).
|
242 |
+
"""
|
243 |
+
outputs = self.layernorm_before_attention(hidden_states)
|
244 |
+
outputs = self.self_attention(
|
245 |
+
outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache
|
246 |
+
)
|
247 |
+
|
248 |
+
outputs, attn_weights, current_key_value = outputs
|
249 |
+
|
250 |
+
if self.dropout is not None:
|
251 |
+
outputs = self.dropout(outputs)
|
252 |
+
hidden_states = (hidden_states + outputs) / 1.05
|
253 |
+
|
254 |
+
return hidden_states, attn_weights, current_key_value
|
255 |
+
|
256 |
+
|
257 |
+
class CpmBeeDenseGatedACT(nn.Module):
|
258 |
+
def __init__(self, config: CpmBeeConfig):
|
259 |
+
super().__init__()
|
260 |
+
self.w_0 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
|
261 |
+
self.w_1 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
|
262 |
+
self.act = torch.nn.GELU()
|
263 |
+
|
264 |
+
def forward(self, hidden_states: torch.Tensor):
|
265 |
+
"""Transform an input tensor from one feature space to another via a nonlinear operation
|
266 |
+
|
267 |
+
Args:
|
268 |
+
hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
|
269 |
+
"""
|
270 |
+
gate_score = self.act(self.w_0(hidden_states))
|
271 |
+
hidden_states = self.w_1(hidden_states)
|
272 |
+
|
273 |
+
hidden_states = gate_score * hidden_states
|
274 |
+
return hidden_states
|
275 |
+
|
276 |
+
|
277 |
+
class CpmBeeFeedForward(nn.Module):
|
278 |
+
def __init__(self, config: CpmBeeConfig):
|
279 |
+
super().__init__()
|
280 |
+
self.w_in = CpmBeeDenseGatedACT(config)
|
281 |
+
if config.dropout_p is not None:
|
282 |
+
self.dropout = torch.nn.Dropout(config.dropout_p)
|
283 |
+
else:
|
284 |
+
self.dropout = None
|
285 |
+
|
286 |
+
self.w_out = CpmBeeLinear(config.dim_ff, config.hidden_size, dtype=config.torch_dtype)
|
287 |
+
|
288 |
+
def forward(self, hidden_states: torch.Tensor):
|
289 |
+
"""
|
290 |
+
Args:
|
291 |
+
hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
|
292 |
+
"""
|
293 |
+
hidden_states = self.w_in(hidden_states)
|
294 |
+
|
295 |
+
if self.dropout is not None:
|
296 |
+
hidden_states = self.dropout(hidden_states)
|
297 |
+
|
298 |
+
hidden_states = self.w_out(hidden_states)
|
299 |
+
|
300 |
+
return hidden_states
|
301 |
+
|
302 |
+
|
303 |
+
class CpmBeeFFNBlock(nn.Module):
|
304 |
+
def __init__(self, config: CpmBeeConfig):
|
305 |
+
super().__init__()
|
306 |
+
self.layernorm_before_ffn = CpmBeeLayerNorm(config)
|
307 |
+
self.ffn = CpmBeeFeedForward(config)
|
308 |
+
if config.dropout_p:
|
309 |
+
self.dropout = torch.nn.Dropout(config.dropout_p)
|
310 |
+
else:
|
311 |
+
self.dropout = None
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self,
|
315 |
+
hidden_states: torch.Tensor,
|
316 |
+
):
|
317 |
+
"""
|
318 |
+
Args:
|
319 |
+
hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
|
320 |
+
Hidden states before feed forward layer.
|
321 |
+
"""
|
322 |
+
ln_outputs = self.layernorm_before_ffn(hidden_states)
|
323 |
+
outputs = self.ffn(ln_outputs)
|
324 |
+
if self.dropout is not None:
|
325 |
+
outputs = self.dropout(outputs)
|
326 |
+
hidden_states = (hidden_states + outputs) / 1.05
|
327 |
+
return hidden_states
|
328 |
+
|
329 |
+
|
330 |
+
class CpmBeeTransformerBlock(nn.Module):
|
331 |
+
def __init__(self, config: CpmBeeConfig, mask_att: bool = False, mask_ffn: bool = False):
|
332 |
+
super().__init__()
|
333 |
+
self.mask_att = mask_att
|
334 |
+
self.mask_ffn = mask_ffn
|
335 |
+
|
336 |
+
if not self.mask_att:
|
337 |
+
self.self_att = CpmBeeSelfAttentionBlock(config)
|
338 |
+
if not self.mask_ffn:
|
339 |
+
self.ffn = CpmBeeFFNBlock(config)
|
340 |
+
|
341 |
+
def forward(
|
342 |
+
self,
|
343 |
+
hidden_states: torch.Tensor,
|
344 |
+
attention_mask: torch.Tensor,
|
345 |
+
position_bias: Optional[torch.Tensor] = None,
|
346 |
+
output_attentions: Optional[bool] = False,
|
347 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
348 |
+
use_cache: Optional[bool] = None,
|
349 |
+
):
|
350 |
+
"""
|
351 |
+
Args:
|
352 |
+
hidden_states (`torch.Tensor`):
|
353 |
+
Input to the layer of shape `(batch, seq_len, dim_model)`
|
354 |
+
attention_mask (`torch.Tensor`):
|
355 |
+
Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
|
356 |
+
position_bias (`torch.Tensor`):
|
357 |
+
Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
|
358 |
+
output_attentions (`bool`, *optional*):
|
359 |
+
Whether or not to return the attentions tensors of all attention layers.
|
360 |
+
past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
|
361 |
+
Cached past key and value projection states
|
362 |
+
use_cache (`bool`, *optional*):
|
363 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
364 |
+
(see `past_key_values`).
|
365 |
+
"""
|
366 |
+
if not self.mask_att:
|
367 |
+
hidden_states = self.self_att(
|
368 |
+
hidden_states,
|
369 |
+
attention_mask=attention_mask,
|
370 |
+
position_bias=position_bias,
|
371 |
+
output_attentions=output_attentions,
|
372 |
+
past_key_values=past_key_values,
|
373 |
+
use_cache=use_cache,
|
374 |
+
)
|
375 |
+
|
376 |
+
hidden_states, attn_weights, current_key_value = hidden_states
|
377 |
+
else:
|
378 |
+
attn_weights, current_key_value = None, (None, None)
|
379 |
+
|
380 |
+
if not self.mask_ffn:
|
381 |
+
hidden_states = self.ffn(hidden_states)
|
382 |
+
|
383 |
+
return hidden_states, attn_weights, current_key_value
|
384 |
+
|
385 |
+
|
386 |
+
class CpmBeeEncoder(nn.Module):
|
387 |
+
def __init__(self, config: CpmBeeConfig):
|
388 |
+
super().__init__()
|
389 |
+
self.num_layers = config.num_hidden_layers
|
390 |
+
if config.mask_modules is not None:
|
391 |
+
assert len(config.mask_modules) == self.num_layers, "The total number of masks should equal to num_layers"
|
392 |
+
for mask_module in config.mask_modules:
|
393 |
+
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
|
394 |
+
else:
|
395 |
+
config.mask_modules = [(False, False)] * self.num_layers
|
396 |
+
|
397 |
+
self.layers = nn.ModuleList(
|
398 |
+
[
|
399 |
+
CpmBeeTransformerBlock(
|
400 |
+
config, mask_att=config.mask_modules[ith][0], mask_ffn=config.mask_modules[ith][1]
|
401 |
+
)
|
402 |
+
for ith in range(self.num_layers)
|
403 |
+
]
|
404 |
+
)
|
405 |
+
|
406 |
+
self.output_layernorm = CpmBeeLayerNorm(config)
|
407 |
+
|
408 |
+
def forward(
|
409 |
+
self,
|
410 |
+
hidden_states: torch.Tensor,
|
411 |
+
attention_mask: torch.Tensor,
|
412 |
+
position_bias: torch.Tensor,
|
413 |
+
output_attentions: Optional[bool] = None,
|
414 |
+
output_hidden_states: Optional[bool] = None,
|
415 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
416 |
+
use_cache: Optional[bool] = None,
|
417 |
+
):
|
418 |
+
"""
|
419 |
+
Args:
|
420 |
+
hidden_states (`torch.Tensor`):
|
421 |
+
Input to the layer of shape `(batch, seq_len, dim_model)`
|
422 |
+
attention_mask (`torch.Tensor`):
|
423 |
+
Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
|
424 |
+
position_bias (`torch.Tensor`):
|
425 |
+
Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
|
426 |
+
output_attentions (`bool`, *optional*):
|
427 |
+
Whether or not to return the attentions tensors of all attention layers.
|
428 |
+
output_hidden_states (`bool`, *optional*):
|
429 |
+
Whether or not to return the hidden states of all layers.
|
430 |
+
past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
|
431 |
+
Cached past key and value projection states
|
432 |
+
use_cache (`bool`, *optional*):
|
433 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
434 |
+
(see `past_key_values`).
|
435 |
+
"""
|
436 |
+
all_hidden_states = () if output_hidden_states else None
|
437 |
+
all_self_attns = () if output_attentions else None
|
438 |
+
current_key_values = () if use_cache else None
|
439 |
+
|
440 |
+
for i, layer in enumerate(self.layers):
|
441 |
+
if output_hidden_states:
|
442 |
+
all_hidden_states += (hidden_states,)
|
443 |
+
layer_outputs = layer(
|
444 |
+
hidden_states,
|
445 |
+
attention_mask,
|
446 |
+
position_bias,
|
447 |
+
output_attentions=output_attentions,
|
448 |
+
past_key_values=past_key_values[i] if past_key_values else None,
|
449 |
+
use_cache=use_cache,
|
450 |
+
)
|
451 |
+
hidden_states, attn_weights, current_key_value = layer_outputs
|
452 |
+
if output_attentions:
|
453 |
+
all_self_attns += (attn_weights,)
|
454 |
+
if current_key_value is not None:
|
455 |
+
current_key_values = current_key_values + (current_key_value,)
|
456 |
+
|
457 |
+
hidden_states = self.output_layernorm(hidden_states)
|
458 |
+
|
459 |
+
if output_hidden_states:
|
460 |
+
all_hidden_states += (hidden_states,)
|
461 |
+
|
462 |
+
return hidden_states, current_key_values, all_hidden_states, all_self_attns
|
463 |
+
|
464 |
+
|
465 |
+
class CpmBeeBucketPositionBias(nn.Module):
|
466 |
+
def __init__(self, config: CpmBeeConfig) -> None:
|
467 |
+
super().__init__()
|
468 |
+
|
469 |
+
self.num_heads = config.num_attention_heads
|
470 |
+
self.num_buckets = config.position_bias_num_buckets
|
471 |
+
self.num_segment_bucket = config.position_bias_num_segment_buckets
|
472 |
+
self.max_distance = config.position_bias_max_distance
|
473 |
+
|
474 |
+
self.relative_attention_bias = nn.Parameter(
|
475 |
+
torch.empty(
|
476 |
+
config.position_bias_num_buckets + config.position_bias_num_segment_buckets,
|
477 |
+
config.num_attention_heads,
|
478 |
+
dtype=config.torch_dtype,
|
479 |
+
),
|
480 |
+
)
|
481 |
+
|
482 |
+
def forward(self, query_pos: torch.Tensor, key_pos: torch.Tensor, rel_buckets: torch.Tensor):
|
483 |
+
with torch.no_grad():
|
484 |
+
batch = key_pos.size(0)
|
485 |
+
keylen = key_pos.size(1)
|
486 |
+
querylen = query_pos.size(1)
|
487 |
+
|
488 |
+
if key_pos.size(0) != query_pos.size(0):
|
489 |
+
raise AssertionError(
|
490 |
+
f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
|
491 |
+
)
|
492 |
+
if rel_buckets.size(0) != batch:
|
493 |
+
raise AssertionError(
|
494 |
+
f"rel_buckets.size(0) should be equal to batch, but got {rel_buckets.size(0)} and {batch}!"
|
495 |
+
)
|
496 |
+
if rel_buckets.size(1) != querylen:
|
497 |
+
raise AssertionError(
|
498 |
+
f"rel_buckets.size(1) should be equal to querylen, but got {rel_buckets.size(1)} and {querylen}!"
|
499 |
+
)
|
500 |
+
if rel_buckets.size(2) != keylen:
|
501 |
+
raise AssertionError(
|
502 |
+
f"rel_buckets.size(2) should be equal to keylen, but got {rel_buckets.size(2)} and {keylen}!"
|
503 |
+
)
|
504 |
+
|
505 |
+
relative_position_bucket = rel_buckets - 1 + self.num_buckets
|
506 |
+
|
507 |
+
inner_segment_bucket = self._position_bucket(
|
508 |
+
key_pos[..., None, :] - query_pos[..., :, None],
|
509 |
+
num_buckets=self.num_buckets,
|
510 |
+
max_distance=self.max_distance,
|
511 |
+
)
|
512 |
+
relative_position_bucket = torch.where(
|
513 |
+
rel_buckets == 0,
|
514 |
+
inner_segment_bucket,
|
515 |
+
relative_position_bucket,
|
516 |
+
)
|
517 |
+
|
518 |
+
embeds = nn.functional.embedding(relative_position_bucket, self.relative_attention_bias)
|
519 |
+
embeds = embeds.permute(0, 3, 1, 2).contiguous()
|
520 |
+
return embeds
|
521 |
+
|
522 |
+
def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
|
523 |
+
relative_buckets = 0
|
524 |
+
num_buckets //= 2
|
525 |
+
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
|
526 |
+
relative_position = torch.abs(relative_position)
|
527 |
+
max_exact = num_buckets // 2
|
528 |
+
is_small = relative_position < max_exact
|
529 |
+
relative_postion_if_large = max_exact + (
|
530 |
+
torch.log(relative_position.float() / max_exact)
|
531 |
+
/ math.log(max_distance / max_exact)
|
532 |
+
* (num_buckets - max_exact)
|
533 |
+
).to(torch.int32)
|
534 |
+
relative_postion_if_large = torch.min(
|
535 |
+
relative_postion_if_large,
|
536 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
537 |
+
)
|
538 |
+
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
|
539 |
+
return relative_buckets
|
540 |
+
|
541 |
+
|
542 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMBee
|
543 |
+
class CpmBeeOutput(nn.Module):
|
544 |
+
def __init__(self, config):
|
545 |
+
super().__init__()
|
546 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
547 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
548 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
549 |
+
|
550 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
551 |
+
hidden_states = self.dense(hidden_states)
|
552 |
+
hidden_states = self.dropout(hidden_states)
|
553 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
554 |
+
return hidden_states
|
555 |
+
|
556 |
+
|
557 |
+
class CpmBeeRotaryEmbedding(nn.Module):
|
558 |
+
"""
|
559 |
+
RotaryEmbedding embeds the unk token and special token. It will embeds the "...<mask>...<mask>...<unk>...<unk>..."
|
560 |
+
to "...<mask_0>...<mask_1>...<unk_0>...<unk_1>..."" to help model to specify different special tokens and unk
|
561 |
+
tokens.
|
562 |
+
"""
|
563 |
+
|
564 |
+
def __init__(self, config: CpmBeeConfig):
|
565 |
+
super().__init__()
|
566 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, config.hidden_size, 2, dtype=torch.float32) / config.hidden_size))
|
567 |
+
self.distance_scale = config.distance_scale
|
568 |
+
self.dtype = config.torch_dtype
|
569 |
+
self.inv_freq = inv_freq.to(config.torch_dtype)
|
570 |
+
|
571 |
+
def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
|
572 |
+
inv_freq = self.inv_freq.to(device=x.device, dtype=self.dtype)
|
573 |
+
|
574 |
+
x_pos = x_pos * self.distance_scale
|
575 |
+
freqs = x_pos[..., None].to(self.dtype) * inv_freq[None, :] # (..., dim/2)
|
576 |
+
|
577 |
+
emb = torch.cat((freqs, freqs), dim=-1) # (..., dim)
|
578 |
+
emb_cos = emb.cos() # (..., dim)
|
579 |
+
emb_sin = emb.sin() # (..., dim)
|
580 |
+
|
581 |
+
rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1) # (..., dim)
|
582 |
+
|
583 |
+
return x * emb_cos + rotate_x * emb_sin
|
584 |
+
|
585 |
+
|
586 |
+
class CpmBeeEmbeddingExt(nn.Embedding):
|
587 |
+
"""
|
588 |
+
Contains a RotaryEmbedding.
|
589 |
+
"""
|
590 |
+
|
591 |
+
def __init__(self, config: CpmBeeConfig):
|
592 |
+
super().__init__(config.vocab_size, config.hidden_size, dtype=config.torch_dtype)
|
593 |
+
self.dim_model = config.hidden_size
|
594 |
+
self.rotary_emb = CpmBeeRotaryEmbedding(config)
|
595 |
+
|
596 |
+
def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
|
597 |
+
embeds = super().forward(ids) / math.sqrt(self.dim_model)
|
598 |
+
return self.rotary_emb(embeds, ids_sub)
|
599 |
+
|
600 |
+
def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
|
601 |
+
logits = nn.functional.linear(x / math.sqrt(self.dim_model), self.weight)
|
602 |
+
if ext_table is not None:
|
603 |
+
logits_ext = nn.functional.linear(x, ext_table)
|
604 |
+
logits = torch.cat([logits, logits_ext], dim=-1)
|
605 |
+
return logits
|
606 |
+
|
607 |
+
|
608 |
+
class CpmBeePreTrainedModel(PreTrainedModel):
|
609 |
+
"""
|
610 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
611 |
+
models.
|
612 |
+
"""
|
613 |
+
|
614 |
+
config_class = CpmBeeConfig
|
615 |
+
base_model_prefix = "cpmbee"
|
616 |
+
supports_gradient_checkpointing = True
|
617 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
618 |
+
|
619 |
+
def _init_weights(self, module):
|
620 |
+
"""Initialize the weights"""
|
621 |
+
if isinstance(module, nn.Linear):
|
622 |
+
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
623 |
+
if module.bias is not None:
|
624 |
+
module.bias.data.zero_()
|
625 |
+
# still needed
|
626 |
+
elif isinstance(module, CpmBeeEmbeddingExt):
|
627 |
+
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
628 |
+
elif isinstance(module, nn.LayerNorm):
|
629 |
+
module.bias.data.zero_()
|
630 |
+
module.weight.data.fill_(1.0)
|
631 |
+
elif isinstance(module, CpmBeeLayerNorm):
|
632 |
+
module.weight.data.fill_(1.0)
|
633 |
+
elif isinstance(module, CpmBeeBucketPositionBias):
|
634 |
+
module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
|
635 |
+
|
636 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
637 |
+
if isinstance(module, CpmBeeEncoder):
|
638 |
+
module.gradient_checkpointing = value
|
639 |
+
|
640 |
+
|
641 |
+
CPMBEE_START_DOCSTRING = r"""
|
642 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
643 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
644 |
+
behavior.
|
645 |
+
|
646 |
+
Parameters
|
647 |
+
config ([`~CpmBeeConfig`]): Model configuration class with all the parameters of the
|
648 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
649 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
650 |
+
"""
|
651 |
+
|
652 |
+
CPMBEE_INPUTS_DOCSTRING = r"""
|
653 |
+
Args:
|
654 |
+
input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
655 |
+
Indices of input sequence tokens in the vocabulary.
|
656 |
+
|
657 |
+
Indices can be obtained using [`CPMBeeTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
658 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
659 |
+
|
660 |
+
[What are input IDs?](../glossary#input-ids)
|
661 |
+
input_id_sub (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
662 |
+
Subscription of input sequence tokens in the vocabulary.
|
663 |
+
|
664 |
+
Subscription of normal text will be zero while the special tokens of each group will be the 0, 1, 2, ...
|
665 |
+
<ans_0>, <ans_1>, <ans_2> ... belongs to group <ans>. <mask_0>, <mask_1>, <mask_2> ... belongs to group
|
666 |
+
<mask>.
|
667 |
+
position (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
668 |
+
The position of input sequence tokens in the vocabulary for each segment. if segment1 is 0, 1, 2 and
|
669 |
+
segment2 is 0, 1, 2, 3, the position will be 0, 1, 2, 0, 1, 2, 3
|
670 |
+
context (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
671 |
+
Whether this token id is context or not. If is context, the value is 1. If not, the value is 0. If a token
|
672 |
+
id is context, it does not need to be predicted.
|
673 |
+
sample_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
674 |
+
Give a sample id to every token id. The token ids with same sample ids belongs to the same sample.
|
675 |
+
num_segments (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
676 |
+
Total number of segments in the current input.
|
677 |
+
segment (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
678 |
+
Give a segment id to every token id. The token ids with same segment ids belongs to the same sample.
|
679 |
+
|
680 |
+
Generally, a string key or value in input data will be a segment. For example, input {"input": "hello, ",
|
681 |
+
"<ans>": ""}, the segments includes: "input", "hello, ", "<ans>" and "".
|
682 |
+
segment_rel_offset (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
683 |
+
The offset of segment rel.
|
684 |
+
segment_rel (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
685 |
+
The segment relevance. A relative implementation of measuring the importance of segments.
|
686 |
+
past_states (`Dict[str, Union[torch.Tensor, List]]`):
|
687 |
+
Store the history information including position, context, sample_ids, num_segments, segment and
|
688 |
+
past_key_values.
|
689 |
+
output_attentions (`bool`, *optional*):
|
690 |
+
Whether or not to return the attentions tensors of all attention layers.
|
691 |
+
output_hidden_states (`bool`, *optional*):
|
692 |
+
Whether or not to return the hidden states of all layers.
|
693 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
694 |
+
A dummy arguments for CPMBee. The `past_states` contains pre-computed hidden-states (key and values in the
|
695 |
+
self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) and
|
696 |
+
other history arguments to speed up sequential decoding.
|
697 |
+
use_cache (`bool`, *optional*):
|
698 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
699 |
+
`past_key_values`).
|
700 |
+
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
701 |
+
Labels for computing the masked language modeling loss.
|
702 |
+
return_dict (`bool`, *optional*):
|
703 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
704 |
+
"""
|
705 |
+
|
706 |
+
|
707 |
+
@add_start_docstrings(
|
708 |
+
"The bare CPMBee Model outputting raw hidden-states without any specific head on top.",
|
709 |
+
CPMBEE_START_DOCSTRING,
|
710 |
+
)
|
711 |
+
class CpmBeeModel(CpmBeePreTrainedModel):
|
712 |
+
def __init__(self, config: CpmBeeConfig):
|
713 |
+
super().__init__(config)
|
714 |
+
if config.half:
|
715 |
+
config.torch_dtype = torch.half
|
716 |
+
else:
|
717 |
+
config.torch_dtype = torch.float
|
718 |
+
self.encoder = CpmBeeEncoder(config)
|
719 |
+
self.input_embedding = CpmBeeEmbeddingExt(config)
|
720 |
+
self.position_bias = CpmBeeBucketPositionBias(config)
|
721 |
+
self.vocab_size = config.vocab_size
|
722 |
+
self.post_init()
|
723 |
+
|
724 |
+
def get_input_embeddings(self):
|
725 |
+
return self.input_embedding
|
726 |
+
|
727 |
+
def set_input_embeddings(self, embeddings, **kwargs):
|
728 |
+
self.input_embedding = embeddings
|
729 |
+
|
730 |
+
@add_start_docstrings_to_model_forward(CPMBEE_INPUTS_DOCSTRING)
|
731 |
+
@add_code_sample_docstrings(
|
732 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
733 |
+
output_type=BaseModelOutputWithPast,
|
734 |
+
config_class=_CONFIG_FOR_DOC,
|
735 |
+
)
|
736 |
+
def forward(
|
737 |
+
self,
|
738 |
+
input_ids: torch.Tensor,
|
739 |
+
input_id_sub: Optional[torch.Tensor] = None,
|
740 |
+
position: Optional[torch.Tensor] = None,
|
741 |
+
context: Optional[torch.Tensor] = None,
|
742 |
+
sample_ids: Optional[torch.Tensor] = None,
|
743 |
+
num_segments: Optional[torch.Tensor] = None,
|
744 |
+
segment: Optional[torch.Tensor] = None,
|
745 |
+
segment_rel_offset: Optional[torch.Tensor] = None,
|
746 |
+
segment_rel: Optional[torch.Tensor] = None,
|
747 |
+
past_states: Optional[Dict] = None,
|
748 |
+
output_attentions: Optional[bool] = None,
|
749 |
+
output_hidden_states: Optional[bool] = None,
|
750 |
+
past_key_values: Optional[List] = None,
|
751 |
+
use_cache: Optional[bool] = None,
|
752 |
+
return_dict: Optional[bool] = None,
|
753 |
+
**kwargs,
|
754 |
+
):
|
755 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
756 |
+
output_hidden_states = (
|
757 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
758 |
+
)
|
759 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
760 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
761 |
+
|
762 |
+
# dummy setting for common tests
|
763 |
+
if input_id_sub is None:
|
764 |
+
dtype, device = input_ids.dtype, input_ids.device
|
765 |
+
batch, seq_length = input_ids.size()
|
766 |
+
segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
|
767 |
+
context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
|
768 |
+
position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
|
769 |
+
input_id_sub = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
|
770 |
+
segment_rel_offset = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
|
771 |
+
segment_rel = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
|
772 |
+
num_segments = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
|
773 |
+
sample_ids = torch.zeros_like(input_ids)
|
774 |
+
|
775 |
+
with torch.no_grad():
|
776 |
+
if past_states is None:
|
777 |
+
present_position = position
|
778 |
+
present_context = context
|
779 |
+
present_sample_ids = sample_ids
|
780 |
+
present_num_segments = num_segments
|
781 |
+
present_segments = segment
|
782 |
+
present_buffer = None
|
783 |
+
else:
|
784 |
+
present_position = torch.cat([past_states["buffer_position"], position], dim=-1)
|
785 |
+
present_context = torch.cat([past_states["buffer_context"], context], dim=-1)
|
786 |
+
present_sample_ids = torch.cat([past_states["buffer_sample_ids"], sample_ids], dim=-1)
|
787 |
+
present_num_segments = torch.cat([past_states["buffer_num_segments"], num_segments], dim=-1)
|
788 |
+
present_segments = torch.cat([past_states["buffer_segments"], segment], dim=-1)
|
789 |
+
present_buffer = past_states["buffer"]
|
790 |
+
|
791 |
+
batch = input_ids.size(0)
|
792 |
+
len_q = input_ids.size(1)
|
793 |
+
len_buffer = present_position.size(1)
|
794 |
+
|
795 |
+
segment_rel_2d = torch.masked_fill(
|
796 |
+
segment[:, :, None] * num_segments[:, :, None]
|
797 |
+
+ present_segments[:, None, :]
|
798 |
+
+ segment_rel_offset[:, :, None],
|
799 |
+
~((sample_ids[:, :, None] == present_sample_ids[:, None, :])), # not in the same sample
|
800 |
+
0, # avoid torch.gather overflow
|
801 |
+
).view(batch, len_q * len_buffer)
|
802 |
+
|
803 |
+
segment_bucket = torch.gather(
|
804 |
+
input=segment_rel,
|
805 |
+
dim=1,
|
806 |
+
index=segment_rel_2d.long(),
|
807 |
+
).view(batch, len_q, len_buffer)
|
808 |
+
|
809 |
+
segment_bucket.masked_fill_(
|
810 |
+
~((sample_ids[:, :, None] == present_sample_ids[:, None, :])), # not in the same span or sample
|
811 |
+
1, # bucket is used for in-context samples
|
812 |
+
)
|
813 |
+
|
814 |
+
# directional mask
|
815 |
+
directional_mask_2d = present_position[:, None, :] <= position[:, :, None]
|
816 |
+
# sample mask
|
817 |
+
sample_mask_2d = (sample_ids[:, :, None] == 0) | (sample_ids[:, :, None] == present_sample_ids[:, None, :])
|
818 |
+
# context mask
|
819 |
+
attention_mask = present_context[:, None, :] | (
|
820 |
+
context[:, :, None].logical_not() & directional_mask_2d.view(batch, len_q, len_buffer)
|
821 |
+
)
|
822 |
+
# span mask
|
823 |
+
attention_mask = attention_mask & sample_mask_2d
|
824 |
+
# length mask
|
825 |
+
mask_1d = present_num_segments != 0
|
826 |
+
attention_mask = mask_1d.view(batch, 1, len_buffer) & attention_mask
|
827 |
+
|
828 |
+
hidden_states = self.input_embedding(input_ids, input_id_sub)
|
829 |
+
position_bias = self.position_bias(position, present_position, segment_bucket)
|
830 |
+
hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
|
831 |
+
hidden_states,
|
832 |
+
attention_mask,
|
833 |
+
position_bias,
|
834 |
+
output_attentions,
|
835 |
+
output_hidden_states,
|
836 |
+
present_buffer,
|
837 |
+
use_cache,
|
838 |
+
)
|
839 |
+
|
840 |
+
if not return_dict:
|
841 |
+
return tuple(
|
842 |
+
v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
|
843 |
+
)
|
844 |
+
|
845 |
+
return BaseModelOutputWithPast(
|
846 |
+
last_hidden_state=hidden_states,
|
847 |
+
past_key_values=present_key_values,
|
848 |
+
hidden_states=all_hidden_states,
|
849 |
+
attentions=all_attentions,
|
850 |
+
)
|
851 |
+
|
852 |
+
|
853 |
+
class CpmBeeBeamHypotheses(BeamHypotheses):
|
854 |
+
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
|
855 |
+
"""
|
856 |
+
Override BeamHypotheses for CpmBee. The hyp to add is list but not tensor.
|
857 |
+
"""
|
858 |
+
super().__init__(num_beams, length_penalty, early_stopping, max_length)
|
859 |
+
|
860 |
+
def add(self, hyp: List, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
|
861 |
+
"""
|
862 |
+
Add a new hypothesis to the list.
|
863 |
+
"""
|
864 |
+
score = sum_logprobs / (len(hyp) ** self.length_penalty)
|
865 |
+
if len(self) < self.num_beams or score > self.worst_score:
|
866 |
+
self.beams.append((score, hyp, beam_indices))
|
867 |
+
if len(self) > self.num_beams:
|
868 |
+
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
869 |
+
del self.beams[sorted_next_scores[0][1]]
|
870 |
+
self.worst_score = sorted_next_scores[1][0]
|
871 |
+
else:
|
872 |
+
self.worst_score = min(score, self.worst_score)
|
873 |
+
|
874 |
+
|
875 |
+
class CPMBeeTransBlock(torch.nn.Module):
|
876 |
+
def __init__(
|
877 |
+
self,
|
878 |
+
dim_model=4096,
|
879 |
+
dim_ff=1024,
|
880 |
+
dim_out=768,
|
881 |
+
dtype=torch.float,
|
882 |
+
eps=1e-6,
|
883 |
+
dropout_p=0,
|
884 |
+
):
|
885 |
+
super().__init__()
|
886 |
+
if dropout_p is not None:
|
887 |
+
self.dropout = torch.nn.Dropout(dropout_p)
|
888 |
+
else:
|
889 |
+
self.dropout = None
|
890 |
+
self.w_out_res = torch.nn.Linear(dim_model, dim_out, bias=False)
|
891 |
+
self.layernorm = torch.nn.LayerNorm(
|
892 |
+
dim_out,
|
893 |
+
dtype=dtype,
|
894 |
+
eps=eps,
|
895 |
+
)
|
896 |
+
|
897 |
+
def forward(self, hidden_states: torch.Tensor):
|
898 |
+
x_res = self.w_out_res(hidden_states)
|
899 |
+
if self.dropout is not None:
|
900 |
+
x_res = self.dropout(x_res)
|
901 |
+
hidden_states = self.layernorm(x_res)
|
902 |
+
return hidden_states
|
903 |
+
|
904 |
+
|
905 |
+
class CpmBeeWithTransform(CpmBeePreTrainedModel):
|
906 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
907 |
+
|
908 |
+
def __init__(self, config: CpmBeeConfig):
|
909 |
+
super().__init__(config)
|
910 |
+
self.llm = CpmBeeModel(config)
|
911 |
+
|
912 |
+
self.trans_block = CPMBeeTransBlock(config.hidden_size, config.hidden_size // 4, config.unet_cross_attention_dim)
|
913 |
+
|
914 |
+
def forward(
|
915 |
+
self,
|
916 |
+
input_ids: torch.Tensor,
|
917 |
+
input_id_sub: Optional[torch.Tensor] = None,
|
918 |
+
position: Optional[torch.Tensor] = None,
|
919 |
+
context: Optional[torch.Tensor] = None,
|
920 |
+
sample_ids: Optional[torch.Tensor] = None,
|
921 |
+
num_segments: Optional[torch.Tensor] = None,
|
922 |
+
segment: Optional[torch.Tensor] = None,
|
923 |
+
segment_rel_offset: Optional[torch.Tensor] = None,
|
924 |
+
segment_rel: Optional[torch.Tensor] = None,
|
925 |
+
past_states: Optional[Dict] = None,
|
926 |
+
output_attentions: Optional[bool] = None,
|
927 |
+
output_hidden_states: Optional[bool] = None,
|
928 |
+
past_key_values: Optional[List] = None,
|
929 |
+
use_cache: Optional[bool] = None,
|
930 |
+
return_dict: Optional[bool] = None,
|
931 |
+
**kwargs,
|
932 |
+
):
|
933 |
+
outputs = self.llm(input_ids, input_id_sub, position, context,
|
934 |
+
sample_ids, num_segments, segment, segment_rel_offset,
|
935 |
+
segment_rel, past_states, output_attentions, output_hidden_states,
|
936 |
+
past_key_values, use_cache, return_dict, **kwargs,)
|
937 |
+
if return_dict:
|
938 |
+
hidden_states = outputs.last_hidden_state
|
939 |
+
else:
|
940 |
+
hidden_states = outputs[0]
|
941 |
+
#if self.trans_block is not None:
|
942 |
+
# hidden_states = self.trans_block(hidden_states)
|
943 |
+
return outputs, hidden_states
|
pipeline_stable_diffusion.py
ADDED
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import inspect
|
15 |
+
import warnings
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch.utils.data.dataloader import default_collate
|
21 |
+
from packaging import version
|
22 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
23 |
+
|
24 |
+
from diffusers.configuration_utils import FrozenDict
|
25 |
+
from diffusers.image_processor import VaeImageProcessor
|
26 |
+
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
27 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
28 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
29 |
+
from diffusers.utils import (
|
30 |
+
deprecate,
|
31 |
+
is_accelerate_available,
|
32 |
+
is_accelerate_version,
|
33 |
+
logging,
|
34 |
+
randn_tensor,
|
35 |
+
replace_example_docstring,
|
36 |
+
)
|
37 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
38 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
39 |
+
|
40 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
41 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, StableDiffusionPipeline
|
42 |
+
from .modeling_cpmbee import CpmBeeModel
|
43 |
+
from .tokenization_viscpmbee import VisCpmBeeTokenizer
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
46 |
+
|
47 |
+
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
|
48 |
+
items = []
|
49 |
+
if isinstance(orig_items[0][key], list):
|
50 |
+
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
51 |
+
for it in orig_items:
|
52 |
+
for tr in it[key]:
|
53 |
+
items.append({key: tr})
|
54 |
+
else:
|
55 |
+
assert isinstance(orig_items[0][key], torch.Tensor)
|
56 |
+
items = orig_items
|
57 |
+
|
58 |
+
batch_size = len(items)
|
59 |
+
shape = items[0][key].shape
|
60 |
+
dim = len(shape)
|
61 |
+
assert dim <= 3
|
62 |
+
if max_length is None:
|
63 |
+
max_length = 0
|
64 |
+
max_length = max(max_length, max(item[key].shape[-1] for item in items))
|
65 |
+
min_length = min(item[key].shape[-1] for item in items)
|
66 |
+
dtype = items[0][key].dtype
|
67 |
+
|
68 |
+
if dim == 1:
|
69 |
+
return torch.cat([item[key] for item in items], dim=0)
|
70 |
+
elif dim == 2:
|
71 |
+
if max_length == min_length:
|
72 |
+
return torch.cat([item[key] for item in items], dim=0)
|
73 |
+
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
74 |
+
else:
|
75 |
+
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
76 |
+
|
77 |
+
for i, item in enumerate(items):
|
78 |
+
if dim == 2:
|
79 |
+
if padding_side == "left":
|
80 |
+
tensor[i, -len(item[key][0]):] = item[key][0].clone()
|
81 |
+
else:
|
82 |
+
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
83 |
+
elif dim == 3:
|
84 |
+
if padding_side == "left":
|
85 |
+
tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
|
86 |
+
else:
|
87 |
+
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
88 |
+
|
89 |
+
return tensor
|
90 |
+
|
91 |
+
|
92 |
+
class CPMBeeCollater:
|
93 |
+
"""
|
94 |
+
针对 cpmbee 输入数据 collate, 对应 cpm-live 的 _MixedDatasetBatchPacker
|
95 |
+
目前利用 torch 的原生 Dataloader 不太适合改造 in-context-learning
|
96 |
+
并且原来实现为了最大化提高有效 token 比比例, 会有一个 best_fit 操作, 这个目前也不支持
|
97 |
+
todo: 重写一下 Dataloader or BatchPacker
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, tokenizer: VisCpmBeeTokenizer, max_len):
|
101 |
+
self.tokenizer = tokenizer
|
102 |
+
self._max_length = max_len
|
103 |
+
self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
|
104 |
+
'segment_rel', 'sample_ids', 'num_segments']
|
105 |
+
|
106 |
+
def __call__(self, batch):
|
107 |
+
batch_size = len(batch)
|
108 |
+
|
109 |
+
tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
|
110 |
+
# 目前没有 best_fit, span 为全 0
|
111 |
+
span = np.zeros((batch_size, self._max_length), dtype=np.int32)
|
112 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
113 |
+
|
114 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
115 |
+
batch_ext_table_ids: List[int] = []
|
116 |
+
batch_ext_table_sub: List[int] = []
|
117 |
+
raw_data_list: List[Any] = []
|
118 |
+
|
119 |
+
for i in range(batch_size):
|
120 |
+
instance_length = batch[i]['input_ids'][0].shape[0]
|
121 |
+
length[i] = instance_length
|
122 |
+
raw_data_list.extend(batch[i]['raw_data'])
|
123 |
+
|
124 |
+
for j in range(instance_length):
|
125 |
+
idx, idx_sub = batch[i]['input_ids'][0, j], batch[i]['input_id_subs'][0, j]
|
126 |
+
tgt_idx = idx
|
127 |
+
if idx_sub > 0:
|
128 |
+
# need to be in ext table
|
129 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
130 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
131 |
+
batch_ext_table_ids.append(idx)
|
132 |
+
batch_ext_table_sub.append(idx_sub)
|
133 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
|
134 |
+
if j > 1 and batch[i]['context'][0, j - 1] == 0:
|
135 |
+
if idx != self.tokenizer.bos_id:
|
136 |
+
tgt[i, j - 1] = tgt_idx
|
137 |
+
else:
|
138 |
+
tgt[i, j - 1] = self.tokenizer.eos_id
|
139 |
+
if batch[i]['context'][0, instance_length - 1] == 0:
|
140 |
+
tgt[i, instance_length - 1] = self.tokenizer.eos_id
|
141 |
+
|
142 |
+
if len(batch_ext_table_map) == 0:
|
143 |
+
# placeholder
|
144 |
+
batch_ext_table_ids.append(0)
|
145 |
+
batch_ext_table_sub.append(1)
|
146 |
+
|
147 |
+
# image
|
148 |
+
if 'pixel_values' in batch[0]:
|
149 |
+
data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
|
150 |
+
else:
|
151 |
+
data = {}
|
152 |
+
|
153 |
+
# image_bound
|
154 |
+
if 'image_bound' in batch[0]:
|
155 |
+
data['image_bound'] = default_collate([i['image_bound'] for i in batch])
|
156 |
+
|
157 |
+
# bee inp
|
158 |
+
for key in self.pad_keys:
|
159 |
+
data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')
|
160 |
+
|
161 |
+
data['context'] = data['context'] > 0
|
162 |
+
data['length'] = torch.from_numpy(length)
|
163 |
+
data['span'] = torch.from_numpy(span)
|
164 |
+
data['target'] = torch.from_numpy(tgt)
|
165 |
+
data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids))
|
166 |
+
data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub))
|
167 |
+
data['raw_data'] = raw_data_list
|
168 |
+
|
169 |
+
return data
|
170 |
+
|
171 |
+
|
172 |
+
class VisCPMPaintBeePipeline(StableDiffusionPipeline):
|
173 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
174 |
+
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
vae: AutoencoderKL,
|
178 |
+
text_encoder: CpmBeeModel,
|
179 |
+
tokenizer: VisCpmBeeTokenizer,
|
180 |
+
unet: UNet2DConditionModel,
|
181 |
+
scheduler: KarrasDiffusionSchedulers,
|
182 |
+
safety_checker: StableDiffusionSafetyChecker,
|
183 |
+
feature_extractor: CLIPImageProcessor,
|
184 |
+
requires_safety_checker: bool = True,
|
185 |
+
):
|
186 |
+
super().__init__(
|
187 |
+
vae=vae,
|
188 |
+
text_encoder=text_encoder,
|
189 |
+
tokenizer=tokenizer,
|
190 |
+
unet=unet,
|
191 |
+
scheduler=scheduler,
|
192 |
+
safety_checker=safety_checker,
|
193 |
+
feature_extractor=feature_extractor,
|
194 |
+
requires_safety_checker=requires_safety_checker
|
195 |
+
)
|
196 |
+
|
197 |
+
def build_input(
|
198 |
+
self,
|
199 |
+
prompt: str,
|
200 |
+
negative_prompt: Optional[str] = None,
|
201 |
+
image_size: int = 512
|
202 |
+
):
|
203 |
+
data_input = {'caption': prompt, 'objects': ''}
|
204 |
+
(
|
205 |
+
input_ids,
|
206 |
+
input_id_subs,
|
207 |
+
context,
|
208 |
+
segment_ids,
|
209 |
+
segment_rel,
|
210 |
+
n_segments,
|
211 |
+
table_states,
|
212 |
+
image_bound
|
213 |
+
) = self.tokenizer.convert_data_to_id(data=data_input, shuffle_answer=False, max_depth=8)
|
214 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
215 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
|
216 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
|
217 |
+
data = {
|
218 |
+
'pixel_values': torch.zeros(3, image_size, image_size).unsqueeze(0),
|
219 |
+
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
|
220 |
+
'input_id_subs': torch.from_numpy(input_id_subs).unsqueeze(0),
|
221 |
+
'context': torch.from_numpy(context).unsqueeze(0),
|
222 |
+
'segment_ids': torch.from_numpy(segment_ids).unsqueeze(0),
|
223 |
+
'segment_rel_offset': torch.from_numpy(segment_rel_offset).unsqueeze(0),
|
224 |
+
'segment_rel': torch.from_numpy(segment_rel).unsqueeze(0),
|
225 |
+
'sample_ids': torch.from_numpy(sample_ids).unsqueeze(0),
|
226 |
+
'num_segments': torch.from_numpy(num_segments).unsqueeze(0),
|
227 |
+
'image_bound': image_bound,
|
228 |
+
'raw_data': prompt,
|
229 |
+
}
|
230 |
+
|
231 |
+
uncond_data_input = {
|
232 |
+
'caption': "" if negative_prompt is None else negative_prompt,
|
233 |
+
'objects': ''
|
234 |
+
}
|
235 |
+
(
|
236 |
+
input_ids,
|
237 |
+
input_id_subs,
|
238 |
+
context,
|
239 |
+
segment_ids,
|
240 |
+
segment_rel,
|
241 |
+
n_segments,
|
242 |
+
table_states,
|
243 |
+
image_bound
|
244 |
+
) = self.tokenizer.convert_data_to_id(data=uncond_data_input, shuffle_answer=False, max_depth=8)
|
245 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
246 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
|
247 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
|
248 |
+
uncond_data = {
|
249 |
+
'pixel_values': torch.zeros(3, image_size, image_size).unsqueeze(0),
|
250 |
+
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
|
251 |
+
'input_id_subs': torch.from_numpy(input_id_subs).unsqueeze(0),
|
252 |
+
'context': torch.from_numpy(context).unsqueeze(0),
|
253 |
+
'segment_ids': torch.from_numpy(segment_ids).unsqueeze(0),
|
254 |
+
'segment_rel_offset': torch.from_numpy(segment_rel_offset).unsqueeze(0),
|
255 |
+
'segment_rel': torch.from_numpy(segment_rel).unsqueeze(0),
|
256 |
+
'sample_ids': torch.from_numpy(sample_ids).unsqueeze(0),
|
257 |
+
'num_segments': torch.from_numpy(num_segments).unsqueeze(0),
|
258 |
+
'image_bound': image_bound,
|
259 |
+
'raw_data': "" if negative_prompt is None else negative_prompt,
|
260 |
+
}
|
261 |
+
packer = CPMBeeCollater(
|
262 |
+
tokenizer=self.tokenizer,
|
263 |
+
max_len=max(data['input_ids'].size(-1), uncond_data['input_ids'].size(-1))
|
264 |
+
)
|
265 |
+
data = packer([data])
|
266 |
+
uncond_data = packer([uncond_data])
|
267 |
+
return data, uncond_data
|
268 |
+
|
269 |
+
def _encode_prompt(
|
270 |
+
self,
|
271 |
+
prompt,
|
272 |
+
device,
|
273 |
+
num_images_per_prompt,
|
274 |
+
do_classifier_free_guidance,
|
275 |
+
negative_prompt=None,
|
276 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
277 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
278 |
+
lora_scale: Optional[float] = None,
|
279 |
+
):
|
280 |
+
r"""
|
281 |
+
Encodes the prompt into text encoder hidden states.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
prompt (`str` or `List[str]`, *optional*):
|
285 |
+
prompt to be encoded
|
286 |
+
device: (`torch.device`):
|
287 |
+
torch device
|
288 |
+
num_images_per_prompt (`int`):
|
289 |
+
number of images that should be generated per prompt
|
290 |
+
do_classifier_free_guidance (`bool`):
|
291 |
+
whether to use classifier free guidance or not
|
292 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
293 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
294 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
295 |
+
less than `1`).
|
296 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
297 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
298 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
299 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
300 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
301 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
302 |
+
argument.
|
303 |
+
lora_scale (`float`, *optional*):
|
304 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
305 |
+
"""
|
306 |
+
# set lora scale so that monkey patched LoRA
|
307 |
+
# function of text encoder can correctly access it
|
308 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
309 |
+
self._lora_scale = lora_scale
|
310 |
+
|
311 |
+
data, uncond_data = self.build_input(prompt, negative_prompt, image_size=512)
|
312 |
+
for key, value in data.items():
|
313 |
+
if isinstance(value, torch.Tensor):
|
314 |
+
data[key] = value.to(self.device)
|
315 |
+
for key, value in uncond_data.items():
|
316 |
+
if isinstance(value, torch.Tensor):
|
317 |
+
uncond_data[key] = value.to(self.device)
|
318 |
+
|
319 |
+
batch, seq_length = data['input_ids'].size()
|
320 |
+
dtype, device = data['input_ids'].dtype, data['input_ids'].device
|
321 |
+
data['position'] = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
|
322 |
+
|
323 |
+
batch, seq_length = uncond_data['input_ids'].size()
|
324 |
+
dtype, device = uncond_data['input_ids'].dtype, uncond_data['input_ids'].device
|
325 |
+
uncond_data['position'] = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
|
326 |
+
|
327 |
+
with torch.no_grad():
|
328 |
+
# llm_hidden_state = self.text_encoder.llm.input_embedding(data['input_ids'], data['input_id_subs'])
|
329 |
+
_, hidden_states = self.text_encoder(
|
330 |
+
input_ids=data['input_ids'],
|
331 |
+
input_id_sub=data['input_id_subs'],
|
332 |
+
position=data['position'],
|
333 |
+
#length=data['length'],
|
334 |
+
context=data['context'],
|
335 |
+
sample_ids=data['sample_ids'],
|
336 |
+
num_segments=data['num_segments'],
|
337 |
+
segment=data['segment_ids'],
|
338 |
+
segment_rel_offset=data['segment_rel_offset'],
|
339 |
+
segment_rel=data['segment_rel'],
|
340 |
+
#span=data['span'],
|
341 |
+
#ext_table_ids=data['ext_table_ids'],
|
342 |
+
#ext_table_sub=data['ext_table_sub'],
|
343 |
+
#hidden_states=llm_hidden_state
|
344 |
+
)
|
345 |
+
|
346 |
+
with torch.no_grad():
|
347 |
+
# uncond_llm_hidden_state = self.text_encoder.llm.input_embedding(uncond_data['input_ids'], uncond_data['input_id_subs'])
|
348 |
+
_, uncond_hidden_states = self.text_encoder(
|
349 |
+
input_ids=uncond_data['input_ids'],
|
350 |
+
input_id_sub=uncond_data['input_id_subs'],
|
351 |
+
position=uncond_data['position'],
|
352 |
+
#length=uncond_data['length'],
|
353 |
+
context=uncond_data['context'],
|
354 |
+
sample_ids=uncond_data['sample_ids'],
|
355 |
+
num_segments=uncond_data['num_segments'],
|
356 |
+
segment=uncond_data['segment_ids'],
|
357 |
+
segment_rel_offset=uncond_data['segment_rel_offset'],
|
358 |
+
segment_rel=uncond_data['segment_rel'],
|
359 |
+
#span=uncond_data['span'],
|
360 |
+
#ext_table_ids=uncond_data['ext_table_ids'],
|
361 |
+
#ext_table_sub=uncond_data['ext_table_sub'],
|
362 |
+
#hidden_states=uncond_llm_hidden_state
|
363 |
+
)
|
364 |
+
|
365 |
+
text_hidden_states, uncond_text_hidden_states = hidden_states, uncond_hidden_states
|
366 |
+
if self.text_encoder.trans_block is not None:
|
367 |
+
text_hidden_states = self.text_encoder.trans_block(text_hidden_states)
|
368 |
+
uncond_text_hidden_states = self.text_encoder.trans_block(uncond_text_hidden_states)
|
369 |
+
bs_embed, seq_len, _ = text_hidden_states.shape
|
370 |
+
text_hidden_states = text_hidden_states.repeat(1, num_images_per_prompt, 1)
|
371 |
+
text_hidden_states = text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
372 |
+
|
373 |
+
bs_embed, seq_len, _ = uncond_text_hidden_states.shape
|
374 |
+
uncond_text_hidden_states = uncond_text_hidden_states.repeat(1, num_images_per_prompt, 1)
|
375 |
+
uncond_text_hidden_states = uncond_text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
376 |
+
|
377 |
+
prompt_embeds = torch.cat([uncond_text_hidden_states, text_hidden_states])
|
378 |
+
return prompt_embeds
|
379 |
+
|
380 |
+
# if prompt is not None and isinstance(prompt, str):
|
381 |
+
# batch_size = 1
|
382 |
+
# elif prompt is not None and isinstance(prompt, list):
|
383 |
+
# batch_size = len(prompt)
|
384 |
+
# else:
|
385 |
+
# batch_size = prompt_embeds.shape[0]
|
386 |
+
|
387 |
+
# if prompt_embeds is None:
|
388 |
+
# # textual inversion: procecss multi-vector tokens if necessary
|
389 |
+
# if isinstance(self, TextualInversionLoaderMixin):
|
390 |
+
# prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
391 |
+
|
392 |
+
# text_inputs = self.tokenizer(
|
393 |
+
# prompt,
|
394 |
+
# padding="max_length",
|
395 |
+
# max_length=self.tokenizer.model_max_length,
|
396 |
+
# truncation=True,
|
397 |
+
# return_tensors="pt",
|
398 |
+
# )
|
399 |
+
# text_input_ids = text_inputs.input_ids
|
400 |
+
# untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
401 |
+
|
402 |
+
# if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
403 |
+
# text_input_ids, untruncated_ids
|
404 |
+
# ):
|
405 |
+
# removed_text = self.tokenizer.batch_decode(
|
406 |
+
# untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
407 |
+
# )
|
408 |
+
# logger.warning(
|
409 |
+
# "The following part of your input was truncated because CLIP can only handle sequences up to"
|
410 |
+
# f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
411 |
+
# )
|
412 |
+
|
413 |
+
# if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
414 |
+
# attention_mask = text_inputs.attention_mask.to(device)
|
415 |
+
# else:
|
416 |
+
# attention_mask = None
|
417 |
+
|
418 |
+
# prompt_embeds = self.text_encoder(
|
419 |
+
# text_input_ids.to(device),
|
420 |
+
# attention_mask=attention_mask,
|
421 |
+
# )
|
422 |
+
# prompt_embeds = prompt_embeds[0]
|
423 |
+
|
424 |
+
# prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
425 |
+
|
426 |
+
# bs_embed, seq_len, _ = prompt_embeds.shape
|
427 |
+
# # duplicate text embeddings for each generation per prompt, using mps friendly method
|
428 |
+
# prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
429 |
+
# prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
430 |
+
|
431 |
+
# # get unconditional embeddings for classifier free guidance
|
432 |
+
# if do_classifier_free_guidance and negative_prompt_embeds is None:
|
433 |
+
# uncond_tokens: List[str]
|
434 |
+
# if negative_prompt is None:
|
435 |
+
# uncond_tokens = [""] * batch_size
|
436 |
+
# elif prompt is not None and type(prompt) is not type(negative_prompt):
|
437 |
+
# raise TypeError(
|
438 |
+
# f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
439 |
+
# f" {type(prompt)}."
|
440 |
+
# )
|
441 |
+
# elif isinstance(negative_prompt, str):
|
442 |
+
# uncond_tokens = [negative_prompt]
|
443 |
+
# elif batch_size != len(negative_prompt):
|
444 |
+
# raise ValueError(
|
445 |
+
# f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
446 |
+
# f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
447 |
+
# " the batch size of `prompt`."
|
448 |
+
# )
|
449 |
+
# else:
|
450 |
+
# uncond_tokens = negative_prompt
|
451 |
+
|
452 |
+
# # textual inversion: procecss multi-vector tokens if necessary
|
453 |
+
# if isinstance(self, TextualInversionLoaderMixin):
|
454 |
+
# uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
455 |
+
|
456 |
+
# max_length = prompt_embeds.shape[1]
|
457 |
+
# uncond_input = self.tokenizer(
|
458 |
+
# uncond_tokens,
|
459 |
+
# padding="max_length",
|
460 |
+
# max_length=max_length,
|
461 |
+
# truncation=True,
|
462 |
+
# return_tensors="pt",
|
463 |
+
# )
|
464 |
+
|
465 |
+
# if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
466 |
+
# attention_mask = uncond_input.attention_mask.to(device)
|
467 |
+
# else:
|
468 |
+
# attention_mask = None
|
469 |
+
|
470 |
+
# negative_prompt_embeds = self.text_encoder(
|
471 |
+
# uncond_input.input_ids.to(device),
|
472 |
+
# attention_mask=attention_mask,
|
473 |
+
# )
|
474 |
+
# negative_prompt_embeds = negative_prompt_embeds[0]
|
475 |
+
|
476 |
+
# if do_classifier_free_guidance:
|
477 |
+
# # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
478 |
+
# seq_len = negative_prompt_embeds.shape[1]
|
479 |
+
|
480 |
+
# negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
481 |
+
|
482 |
+
# negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
483 |
+
# negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
484 |
+
|
485 |
+
# # For classifier free guidance, we need to do two forward passes.
|
486 |
+
# # Here we concatenate the unconditional and text embeddings into a single batch
|
487 |
+
# # to avoid doing two forward passes
|
488 |
+
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
489 |
+
|
490 |
+
# return prompt_embeds
|
491 |
+
|
492 |
+
def decode_latents(self, latents):
|
493 |
+
warnings.warn(
|
494 |
+
"The decode_latents method is deprecated and will be removed in a future version. Please"
|
495 |
+
" use VaeImageProcessor instead",
|
496 |
+
FutureWarning,
|
497 |
+
)
|
498 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
499 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
500 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
501 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
502 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
503 |
+
return image
|
504 |
+
|
505 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
506 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
507 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
508 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
509 |
+
# and should be between [0, 1]
|
510 |
+
|
511 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
512 |
+
extra_step_kwargs = {}
|
513 |
+
if accepts_eta:
|
514 |
+
extra_step_kwargs["eta"] = eta
|
515 |
+
|
516 |
+
# check if the scheduler accepts generator
|
517 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
518 |
+
if accepts_generator:
|
519 |
+
extra_step_kwargs["generator"] = generator
|
520 |
+
return extra_step_kwargs
|
521 |
+
|
522 |
+
def check_inputs(
|
523 |
+
self,
|
524 |
+
prompt,
|
525 |
+
height,
|
526 |
+
width,
|
527 |
+
callback_steps,
|
528 |
+
negative_prompt=None,
|
529 |
+
prompt_embeds=None,
|
530 |
+
negative_prompt_embeds=None,
|
531 |
+
):
|
532 |
+
if height % 8 != 0 or width % 8 != 0:
|
533 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
534 |
+
|
535 |
+
if (callback_steps is None) or (
|
536 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
537 |
+
):
|
538 |
+
raise ValueError(
|
539 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
540 |
+
f" {type(callback_steps)}."
|
541 |
+
)
|
542 |
+
|
543 |
+
if prompt is not None and prompt_embeds is not None:
|
544 |
+
raise ValueError(
|
545 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
546 |
+
" only forward one of the two."
|
547 |
+
)
|
548 |
+
elif prompt is None and prompt_embeds is None:
|
549 |
+
raise ValueError(
|
550 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
551 |
+
)
|
552 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
553 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
554 |
+
|
555 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
556 |
+
raise ValueError(
|
557 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
558 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
559 |
+
)
|
560 |
+
|
561 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
562 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
563 |
+
raise ValueError(
|
564 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
565 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
566 |
+
f" {negative_prompt_embeds.shape}."
|
567 |
+
)
|
568 |
+
|
569 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
570 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
571 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
572 |
+
raise ValueError(
|
573 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
574 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
575 |
+
)
|
576 |
+
|
577 |
+
if latents is None:
|
578 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
579 |
+
else:
|
580 |
+
latents = latents.to(device)
|
581 |
+
|
582 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
583 |
+
latents = latents * self.scheduler.init_noise_sigma
|
584 |
+
return latents
|
585 |
+
|
586 |
+
@torch.no_grad()
|
587 |
+
def __call__(
|
588 |
+
self,
|
589 |
+
prompt: Union[str, List[str]] = None,
|
590 |
+
height: Optional[int] = None,
|
591 |
+
width: Optional[int] = None,
|
592 |
+
num_inference_steps: int = 50,
|
593 |
+
guidance_scale: float = 7.5,
|
594 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
595 |
+
num_images_per_prompt: Optional[int] = 1,
|
596 |
+
eta: float = 0.0,
|
597 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
598 |
+
latents: Optional[torch.FloatTensor] = None,
|
599 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
600 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
601 |
+
output_type: Optional[str] = "pil",
|
602 |
+
return_dict: bool = True,
|
603 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
604 |
+
callback_steps: int = 1,
|
605 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
606 |
+
guidance_rescale: float = 0.0,
|
607 |
+
):
|
608 |
+
# 0. Default height and width to unet
|
609 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
610 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
611 |
+
|
612 |
+
# 1. Check inputs. Raise error if not correct
|
613 |
+
self.check_inputs(
|
614 |
+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
615 |
+
)
|
616 |
+
|
617 |
+
# 2. Define call parameters
|
618 |
+
if prompt is not None and isinstance(prompt, str):
|
619 |
+
batch_size = 1
|
620 |
+
elif prompt is not None and isinstance(prompt, list):
|
621 |
+
batch_size = len(prompt)
|
622 |
+
else:
|
623 |
+
batch_size = prompt_embeds.shape[0]
|
624 |
+
|
625 |
+
device = self._execution_device
|
626 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
627 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
628 |
+
# corresponds to doing no classifier free guidance.
|
629 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
630 |
+
|
631 |
+
# 3. Encode input prompt
|
632 |
+
text_encoder_lora_scale = (
|
633 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
634 |
+
)
|
635 |
+
|
636 |
+
prompt_embeds = self._encode_prompt(
|
637 |
+
prompt,
|
638 |
+
device,
|
639 |
+
num_images_per_prompt,
|
640 |
+
do_classifier_free_guidance,
|
641 |
+
negative_prompt,
|
642 |
+
prompt_embeds=prompt_embeds,
|
643 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
644 |
+
lora_scale=text_encoder_lora_scale,
|
645 |
+
)
|
646 |
+
|
647 |
+
# 4. Prepare timesteps
|
648 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
649 |
+
timesteps = self.scheduler.timesteps
|
650 |
+
|
651 |
+
# 5. Prepare latent variables
|
652 |
+
num_channels_latents = self.unet.config.in_channels
|
653 |
+
latents = self.prepare_latents(
|
654 |
+
batch_size * num_images_per_prompt,
|
655 |
+
num_channels_latents,
|
656 |
+
height,
|
657 |
+
width,
|
658 |
+
prompt_embeds.dtype,
|
659 |
+
device,
|
660 |
+
generator,
|
661 |
+
latents,
|
662 |
+
)
|
663 |
+
|
664 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
665 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
666 |
+
|
667 |
+
# 7. Denoising loop
|
668 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
669 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
670 |
+
for i, t in enumerate(timesteps):
|
671 |
+
# expand the latents if we are doing classifier free guidance
|
672 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
673 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
674 |
+
|
675 |
+
# predict the noise residual
|
676 |
+
noise_pred = self.unet(
|
677 |
+
latent_model_input,
|
678 |
+
t,
|
679 |
+
encoder_hidden_states=prompt_embeds,
|
680 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
681 |
+
return_dict=False,
|
682 |
+
)[0]
|
683 |
+
|
684 |
+
# perform guidance
|
685 |
+
if do_classifier_free_guidance:
|
686 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
687 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
688 |
+
|
689 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
690 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
691 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
692 |
+
|
693 |
+
# compute the previous noisy sample x_t -> x_t-1
|
694 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
695 |
+
|
696 |
+
# call the callback, if provided
|
697 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
698 |
+
progress_bar.update()
|
699 |
+
if callback is not None and i % callback_steps == 0:
|
700 |
+
callback(i, t, latents)
|
701 |
+
|
702 |
+
if not output_type == "latent":
|
703 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
704 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
705 |
+
else:
|
706 |
+
image = latents
|
707 |
+
has_nsfw_concept = None
|
708 |
+
|
709 |
+
if has_nsfw_concept is None:
|
710 |
+
do_denormalize = [True] * image.shape[0]
|
711 |
+
else:
|
712 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
713 |
+
|
714 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
715 |
+
|
716 |
+
# Offload last model to CPU
|
717 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
718 |
+
self.final_offload_hook.offload()
|
719 |
+
|
720 |
+
if not return_dict:
|
721 |
+
return (image, has_nsfw_concept)
|
722 |
+
|
723 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "DDIMScheduler",
|
3 |
+
"_diffusers_version": "0.8.0",
|
4 |
+
"beta_end": 0.012,
|
5 |
+
"beta_schedule": "scaled_linear",
|
6 |
+
"beta_start": 0.00085,
|
7 |
+
"clip_sample": false,
|
8 |
+
"num_train_timesteps": 1000,
|
9 |
+
"prediction_type": "epsilon",
|
10 |
+
"set_alpha_to_one": false,
|
11 |
+
"skip_prk_steps": true,
|
12 |
+
"steps_offset": 1,
|
13 |
+
"trained_betas": null
|
14 |
+
}
|
tokenization_viscpmbee.py
ADDED
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for CpmBee."""
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from numpy.typing import NDArray
|
22 |
+
from typing_extensions import TypedDict
|
23 |
+
|
24 |
+
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
25 |
+
from transformers.tokenization_utils_base import AddedToken, BatchEncoding, TextInput, TruncationStrategy
|
26 |
+
from transformers.utils import logging
|
27 |
+
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
32 |
+
|
33 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
34 |
+
"vocab_file": {
|
35 |
+
"openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/blob/main/vocab.txt",
|
36 |
+
},
|
37 |
+
}
|
38 |
+
|
39 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
40 |
+
"openbmb/viscpmchat-bee-10b": 4096,
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
class _PrevExtTableStates(TypedDict):
|
45 |
+
ext_table: Dict[int, str]
|
46 |
+
token_id_table: Dict[str, Dict[int, int]]
|
47 |
+
|
48 |
+
|
49 |
+
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
|
50 |
+
|
51 |
+
|
52 |
+
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
|
53 |
+
ret = n_up * max_depth + n_down
|
54 |
+
if ret == 0:
|
55 |
+
return ret
|
56 |
+
else:
|
57 |
+
# bucket 1 is reserved for incontext samples
|
58 |
+
return ret + 1
|
59 |
+
|
60 |
+
|
61 |
+
class _DictTree(TypedDict):
|
62 |
+
value: str
|
63 |
+
children: List["_DictTree"]
|
64 |
+
depth: int
|
65 |
+
segment_id: int
|
66 |
+
need_predict: bool
|
67 |
+
is_image: bool
|
68 |
+
|
69 |
+
|
70 |
+
class VisCpmBeeTokenizer(PreTrainedTokenizer):
|
71 |
+
"""
|
72 |
+
Construct a CPMBee tokenizer.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
vocab_file (`str`):
|
76 |
+
Path to the vocabulary file.
|
77 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
78 |
+
The beginning of sequence token.
|
79 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
80 |
+
The end of sequence token.
|
81 |
+
line_token (`str`, *optional*, defaults to `"\n"`):
|
82 |
+
The line token.
|
83 |
+
space_token (`str`, *optional*, defaults to `" "`):
|
84 |
+
The space token.
|
85 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
86 |
+
The unknown token.
|
87 |
+
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
88 |
+
The mask token.
|
89 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
90 |
+
The token used for padding.
|
91 |
+
padding_side (`str`, *optional*, defaults to `"left"`):
|
92 |
+
The padding side. CPM-Bee will use left padding by default.
|
93 |
+
"""
|
94 |
+
|
95 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
96 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
97 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
98 |
+
model_input_names: List[str] = [
|
99 |
+
"input_ids",
|
100 |
+
"attention_mask",
|
101 |
+
"input_id_sub",
|
102 |
+
"position",
|
103 |
+
"context",
|
104 |
+
"sample_ids",
|
105 |
+
"num_segments",
|
106 |
+
"segment",
|
107 |
+
"segment_rel_offset",
|
108 |
+
"segment_rel",
|
109 |
+
]
|
110 |
+
add_prefix_space = False
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
vocab_file,
|
115 |
+
bos_token="<s>",
|
116 |
+
eos_token="</s>",
|
117 |
+
line_token="\n",
|
118 |
+
space_token=" ",
|
119 |
+
unk_token="<unk>",
|
120 |
+
mask_token="<mask>",
|
121 |
+
pad_token="<pad>",
|
122 |
+
padding_side="left",
|
123 |
+
**kwargs,
|
124 |
+
):
|
125 |
+
super().__init__(
|
126 |
+
bos_token=bos_token,
|
127 |
+
eos_token=eos_token,
|
128 |
+
line_token=line_token,
|
129 |
+
space_token=space_token,
|
130 |
+
unk_token=unk_token,
|
131 |
+
mask_token=mask_token,
|
132 |
+
pad_token=pad_token,
|
133 |
+
padding_side=padding_side,
|
134 |
+
**kwargs,
|
135 |
+
)
|
136 |
+
|
137 |
+
self.encoder: Dict[str, int] = {}
|
138 |
+
|
139 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
140 |
+
for token in reader.readlines():
|
141 |
+
token = token.rstrip("\n")
|
142 |
+
if len(token) == 0:
|
143 |
+
continue
|
144 |
+
self.encoder[token] = len(self.encoder)
|
145 |
+
|
146 |
+
self.encoder[" "] = self.encoder["</_>"]
|
147 |
+
self.encoder["\n"] = self.encoder["</n>"]
|
148 |
+
del self.encoder["</_>"]
|
149 |
+
del self.encoder["</n>"]
|
150 |
+
|
151 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
152 |
+
|
153 |
+
self._max_word_len = max([len(x) for x in self.encoder.keys()])
|
154 |
+
self.cpmbee_special_tokens = {k: v for k, v in self.encoder.items() if k.startswith("<") and k.endswith(">")}
|
155 |
+
|
156 |
+
self.ext_table: Dict[int, str] = {}
|
157 |
+
self.ext_table_rev: Dict[str, int] = {}
|
158 |
+
|
159 |
+
self.token_id_table: Dict[str, Dict[int, int]] = {}
|
160 |
+
self.ext_special_tokens = []
|
161 |
+
|
162 |
+
self.ext_args_for_model = [
|
163 |
+
"input_id_subs",
|
164 |
+
"input_pos",
|
165 |
+
"context",
|
166 |
+
"segment_ids",
|
167 |
+
"segment_rel_offset",
|
168 |
+
"segment_rel",
|
169 |
+
"sample_ids",
|
170 |
+
"num_segments",
|
171 |
+
"predict_segments",
|
172 |
+
"answer_placeholders",
|
173 |
+
"ext_table",
|
174 |
+
"token_id_table",
|
175 |
+
"image_bound"
|
176 |
+
]
|
177 |
+
|
178 |
+
@property
|
179 |
+
def bod_token_id(self):
|
180 |
+
return self.encoder[self.bod_token]
|
181 |
+
|
182 |
+
@property
|
183 |
+
def eod_token_id(self):
|
184 |
+
return self.encoder[self.eod_token]
|
185 |
+
|
186 |
+
@property
|
187 |
+
def newline_id(self):
|
188 |
+
return self.encoder[self.line_token]
|
189 |
+
|
190 |
+
@property
|
191 |
+
def vocab_size(self) -> int:
|
192 |
+
return len(self.encoder)
|
193 |
+
|
194 |
+
def __len__(self):
|
195 |
+
"""
|
196 |
+
Size of the full vocabulary with the added tokens.
|
197 |
+
"""
|
198 |
+
return self.vocab_size + len(self.added_tokens_encoder)
|
199 |
+
|
200 |
+
def get_vocab(self):
|
201 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
202 |
+
|
203 |
+
def get_piece(self, text: str) -> str:
|
204 |
+
"""
|
205 |
+
Match with maximum length.
|
206 |
+
"""
|
207 |
+
len_text = len(text)
|
208 |
+
for i in range(len(text)):
|
209 |
+
sub = text[: len_text - i]
|
210 |
+
if (sub in self.encoder) or (sub in self.added_tokens_encoder):
|
211 |
+
return sub
|
212 |
+
return text[0]
|
213 |
+
|
214 |
+
def tokenize(self, text: TextInput, **kwargs) -> List[str]:
|
215 |
+
r"""
|
216 |
+
Override the `tokenize` to meet the needs of CPMBee:
|
217 |
+
1. Mark the special token with `<` and `>`. The `<>` will be ignored.
|
218 |
+
2. Split sentences by the marked special tokens.
|
219 |
+
3. Record the marked special token by `ext_table` and `ext_table_rev`.
|
220 |
+
4. Tokenize the sentence without special tokens.
|
221 |
+
"""
|
222 |
+
for_cpmbee = kwargs.get("for_cpmbee", False)
|
223 |
+
all_special_tokens_extended = {
|
224 |
+
str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
225 |
+
}
|
226 |
+
|
227 |
+
sentence_split = [""]
|
228 |
+
is_special_token = False
|
229 |
+
for i, c in enumerate(text):
|
230 |
+
if is_special_token:
|
231 |
+
if c == "<":
|
232 |
+
tail = sentence_split.pop(-1)
|
233 |
+
sentence_split[-1] += tail
|
234 |
+
sentence_split.append(c)
|
235 |
+
elif c == ">":
|
236 |
+
# end of special token
|
237 |
+
sentence_split[-1] += c
|
238 |
+
if sentence_split[-1] == "<>":
|
239 |
+
continue
|
240 |
+
is_special_token = False
|
241 |
+
sentence_split.append("")
|
242 |
+
else:
|
243 |
+
sentence_split[-1] += c
|
244 |
+
else:
|
245 |
+
if c == "<":
|
246 |
+
is_special_token = True
|
247 |
+
sentence_split.append(c)
|
248 |
+
else:
|
249 |
+
sentence_split[-1] += c
|
250 |
+
if is_special_token:
|
251 |
+
tail = sentence_split.pop(-1)
|
252 |
+
sentence_split[-1] += tail
|
253 |
+
|
254 |
+
output_tokens = []
|
255 |
+
for i, part in enumerate(sentence_split):
|
256 |
+
if (i & 1) == 1:
|
257 |
+
# special token
|
258 |
+
output_tokens.append(part)
|
259 |
+
if for_cpmbee and (part not in self.encoder) and (part not in self.ext_table_rev):
|
260 |
+
self.ext_table_rev[part] = len(self.ext_table_rev) + self.vocab_size
|
261 |
+
self.ext_table[self.ext_table_rev[part]] = part
|
262 |
+
else:
|
263 |
+
output_tokens.extend(self._tokenize(part, for_cpmbee=for_cpmbee))
|
264 |
+
|
265 |
+
# drop spaces
|
266 |
+
for i, token in enumerate(output_tokens):
|
267 |
+
if token in self.added_tokens_encoder:
|
268 |
+
token = all_special_tokens_extended.get(token, None)
|
269 |
+
left = output_tokens[i - 1] if i > 0 else None
|
270 |
+
right = output_tokens[i + 1] if i < len(output_tokens) - 1 else None
|
271 |
+
if isinstance(token, AddedToken):
|
272 |
+
if token.rstrip and right:
|
273 |
+
# A bit counter-intuitive but we strip the left of the string
|
274 |
+
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
275 |
+
output_tokens[i + 1] = right.lstrip()
|
276 |
+
# Strip white spaces on the left
|
277 |
+
if token.lstrip and left:
|
278 |
+
output_tokens[i - 1] = left.rstrip() # Opposite here
|
279 |
+
else:
|
280 |
+
if right:
|
281 |
+
output_tokens[i + 1] = right.lstrip()
|
282 |
+
if left:
|
283 |
+
output_tokens[i - 1] = left.rstrip()
|
284 |
+
|
285 |
+
skipped_tokens = []
|
286 |
+
for token in output_tokens:
|
287 |
+
if not token:
|
288 |
+
continue
|
289 |
+
else:
|
290 |
+
skipped_tokens.append(token)
|
291 |
+
|
292 |
+
return skipped_tokens
|
293 |
+
|
294 |
+
def _tokenize(self, text, **kwargs):
|
295 |
+
"""
|
296 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
297 |
+
vocabulary.
|
298 |
+
|
299 |
+
Do NOT take care of added tokens. Record the unk tokens and special tokens in `ext_table` and `ext_table_rev`.
|
300 |
+
"""
|
301 |
+
for_cpmbee = kwargs.get("for_cpmbee", False)
|
302 |
+
output_tokens = []
|
303 |
+
|
304 |
+
part_st = 0
|
305 |
+
last_unk = None
|
306 |
+
while part_st < len(text):
|
307 |
+
piece = self.get_piece(text[part_st:])
|
308 |
+
if piece in self.encoder or self.added_tokens_encoder:
|
309 |
+
if last_unk is None:
|
310 |
+
output_tokens.append(piece)
|
311 |
+
else:
|
312 |
+
if for_cpmbee and (last_unk not in self.ext_table_rev):
|
313 |
+
self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
|
314 |
+
self.ext_table[self.ext_table_rev[last_unk]] = last_unk
|
315 |
+
output_tokens.append(last_unk)
|
316 |
+
output_tokens.append(piece)
|
317 |
+
last_unk = None
|
318 |
+
else:
|
319 |
+
if last_unk is None:
|
320 |
+
last_unk = piece
|
321 |
+
else:
|
322 |
+
last_unk += piece
|
323 |
+
part_st += len(piece)
|
324 |
+
if last_unk is not None:
|
325 |
+
# part end with UNK
|
326 |
+
if for_cpmbee and (last_unk not in self.ext_table_rev):
|
327 |
+
self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
|
328 |
+
self.ext_table[self.ext_table_rev[last_unk]] = last_unk
|
329 |
+
output_tokens.append(last_unk)
|
330 |
+
|
331 |
+
return output_tokens
|
332 |
+
|
333 |
+
def check(self, token):
|
334 |
+
return token in self.encoder
|
335 |
+
|
336 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
337 |
+
return "".join(tokens)
|
338 |
+
|
339 |
+
def _convert_token_to_id(self, token: str):
|
340 |
+
"""Converts a token (str) in an id using the vocab and ext_table."""
|
341 |
+
if token in self.encoder:
|
342 |
+
return self.encoder.get(token)
|
343 |
+
elif token in self.ext_table_rev:
|
344 |
+
return self.ext_table_rev[token]
|
345 |
+
elif token in self.added_tokens_encoder:
|
346 |
+
return self.added_tokens_encoder[token]
|
347 |
+
else:
|
348 |
+
return self.unk_token_id
|
349 |
+
|
350 |
+
def _convert_id_to_token(self, index):
|
351 |
+
"""Converts an index (integer) in a token (str) using the vocab and ext_table."""
|
352 |
+
if index in self.ext_table:
|
353 |
+
return self.ext_table[index]
|
354 |
+
elif index in self.added_tokens_decoder:
|
355 |
+
return self.added_tokens_decoder[index]
|
356 |
+
else:
|
357 |
+
if index >= 0:
|
358 |
+
return self.decoder[index]
|
359 |
+
|
360 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
361 |
+
if os.path.isdir(save_directory):
|
362 |
+
vocab_file = os.path.join(
|
363 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
367 |
+
index = 0
|
368 |
+
self.encoder["</n>"] = self.encoder["\n"]
|
369 |
+
del self.encoder["\n"]
|
370 |
+
self.encoder["</_>"] = self.encoder[" "]
|
371 |
+
del self.encoder[" "]
|
372 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
373 |
+
for token, token_index in sorted(self.encoder.items(), key=lambda x: x[1]):
|
374 |
+
if index != token_index:
|
375 |
+
logger.warning(
|
376 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
377 |
+
" Please check that the vocabulary is not corrupted!"
|
378 |
+
)
|
379 |
+
index = token_index
|
380 |
+
writer.write(token + "\n")
|
381 |
+
index += 1
|
382 |
+
return (vocab_file,)
|
383 |
+
|
384 |
+
def __call__(self, text, *args, **kwargs):
|
385 |
+
r"""
|
386 |
+
CPMBee `call` method will use `_tokenize_cpmbee` when the input type is dict.
|
387 |
+
"""
|
388 |
+
if isinstance(text, dict):
|
389 |
+
return self._batch_tokenize_cpmbee([text], *args, **kwargs)
|
390 |
+
elif isinstance(text, (list, tuple)):
|
391 |
+
if isinstance(text[0], dict):
|
392 |
+
return self._batch_tokenize_cpmbee(text, *args, **kwargs)
|
393 |
+
else:
|
394 |
+
return super().__call__(text, *args, **kwargs)
|
395 |
+
else:
|
396 |
+
return super().__call__(text, *args, **kwargs)
|
397 |
+
|
398 |
+
# 分词
|
399 |
+
def _tokenize_cpmbee(self, data: TextInput, *args, **kwargs) -> List[str]:
|
400 |
+
"""
|
401 |
+
A tokenize method to process dict data. Exclusive for CPMBee.
|
402 |
+
"""
|
403 |
+
if isinstance(data, str):
|
404 |
+
data = json.loads(data)
|
405 |
+
if not isinstance(data, Dict):
|
406 |
+
raise TypeError(
|
407 |
+
"CpmBeeTokenizer input data should be dict or str in dict format, but got {}".format(type(data))
|
408 |
+
)
|
409 |
+
|
410 |
+
# 1. prepare answer placeholder
|
411 |
+
answer_placeholders = []
|
412 |
+
|
413 |
+
def _put_placeholder(data: Any, path: List[str] = []):
|
414 |
+
if isinstance(data, dict):
|
415 |
+
ret = {}
|
416 |
+
for k, v in data.items():
|
417 |
+
ret[k] = _put_placeholder(v, path + [k])
|
418 |
+
return ret
|
419 |
+
else:
|
420 |
+
answer_placeholders.append(path)
|
421 |
+
return "<ans_{}>".format(len(answer_placeholders))
|
422 |
+
|
423 |
+
data["<ans>"] = _put_placeholder(data["<ans>"])
|
424 |
+
|
425 |
+
(
|
426 |
+
input_ids,
|
427 |
+
input_id_subs,
|
428 |
+
context,
|
429 |
+
segment_ids,
|
430 |
+
segment_rel,
|
431 |
+
n_segments,
|
432 |
+
table_states,
|
433 |
+
image_bound
|
434 |
+
) = self.convert_data_to_id(data, shuffle_answer=False, max_depth=8)
|
435 |
+
|
436 |
+
# <ans> mapping from sub to id
|
437 |
+
sub_ans_map: Dict[int, int] = {}
|
438 |
+
for fake_id, token_sub in table_states["token_id_table"]["<ans>"].items():
|
439 |
+
token = table_states["ext_table"][fake_id]
|
440 |
+
if token.startswith("<ans_") and token.endswith(">"):
|
441 |
+
ans_id = int(token[5:-1])
|
442 |
+
sub_ans_map[token_sub] = ans_id
|
443 |
+
|
444 |
+
tmp_input_ids = []
|
445 |
+
tmp_input_sub = []
|
446 |
+
tmp_input_seg = []
|
447 |
+
|
448 |
+
# get predict segments
|
449 |
+
predict_segments: List[Tuple[int, int]] = []
|
450 |
+
for i in range(input_ids.shape[0]):
|
451 |
+
if context[i] == 0:
|
452 |
+
if input_ids[i] == self.encoder["<ans>"]:
|
453 |
+
# is ans
|
454 |
+
# (segment_id, ans_id)
|
455 |
+
predict_segments.append((segment_ids[i], sub_ans_map[input_id_subs[i]]))
|
456 |
+
else:
|
457 |
+
tmp_input_ids.append(input_ids[i])
|
458 |
+
tmp_input_sub.append(input_id_subs[i])
|
459 |
+
tmp_input_seg.append(segment_ids[i])
|
460 |
+
|
461 |
+
if len(predict_segments) == 0:
|
462 |
+
raise ValueError("No answer to predict")
|
463 |
+
|
464 |
+
input_ids = np.array(tmp_input_ids, dtype=np.int32) # all context
|
465 |
+
input_id_subs = np.array(tmp_input_sub, dtype=np.int32) # [0, 0, 0, 0, 1, 0, 0, 2, 0, ...]
|
466 |
+
context = np.full_like(tmp_input_ids, 1, dtype=np.int8) # [1, 1, 1, ...]
|
467 |
+
segment_ids = np.array(tmp_input_seg, dtype=np.int32) # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, ...]
|
468 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, 0, ...]
|
469 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, ...]
|
470 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32) # [n_seg, n_seg, n_seg, ...]
|
471 |
+
input_pos = np.arange(input_ids.shape[0], dtype=np.int32) # [0, 1, 2, 3, 4, ...]
|
472 |
+
image_bound = np.array(image_bound)
|
473 |
+
|
474 |
+
return (
|
475 |
+
self.prepare_for_model(
|
476 |
+
input_ids.tolist(),
|
477 |
+
input_id_subs=input_id_subs.tolist(),
|
478 |
+
input_pos=input_pos.tolist(),
|
479 |
+
context=context.tolist(),
|
480 |
+
segment_ids=segment_ids.tolist(),
|
481 |
+
segment_rel_offset=segment_rel_offset.tolist(),
|
482 |
+
segment_rel=segment_rel.tolist(),
|
483 |
+
sample_ids=sample_ids.tolist(),
|
484 |
+
num_segments=num_segments.tolist(),
|
485 |
+
image_bound=image_bound,
|
486 |
+
**kwargs,
|
487 |
+
),
|
488 |
+
predict_segments,
|
489 |
+
answer_placeholders,
|
490 |
+
table_states["ext_table"],
|
491 |
+
table_states["token_id_table"],
|
492 |
+
)
|
493 |
+
|
494 |
+
def _batch_tokenize_cpmbee(self, data_lst, *args, **kwargs):
|
495 |
+
"""
|
496 |
+
Batched _token_cpmbee.
|
497 |
+
"""
|
498 |
+
device = kwargs.get("device", "cpu")
|
499 |
+
return_tensors = kwargs.get("return_tensors", None)
|
500 |
+
batch_outputs = {}
|
501 |
+
segment_rel_pack = []
|
502 |
+
other_info = []
|
503 |
+
|
504 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
505 |
+
batch_ext_table_ids: List[int] = []
|
506 |
+
batch_ext_table_sub: List[int] = []
|
507 |
+
|
508 |
+
for data in data_lst:
|
509 |
+
self.ext_table = {}
|
510 |
+
self.ext_table_rev = {}
|
511 |
+
self.token_id_table = {}
|
512 |
+
(outputs, predict_segments, answer_placeholders, ext_table, token_id_table) = self._tokenize_cpmbee(
|
513 |
+
data,
|
514 |
+
truncation=None,
|
515 |
+
padding=PaddingStrategy.DO_NOT_PAD.value,
|
516 |
+
max_length=None,
|
517 |
+
pad_to_multiple_of=None,
|
518 |
+
return_attention_mask=False,
|
519 |
+
return_tensors=None,
|
520 |
+
)
|
521 |
+
rev_ext_table = {}
|
522 |
+
for token, mp in token_id_table.items():
|
523 |
+
if token == "<ans>":
|
524 |
+
continue
|
525 |
+
token_id = self.encoder[token]
|
526 |
+
for fake_id, token_sub in mp.items():
|
527 |
+
if token_sub > 0:
|
528 |
+
if (token_id, token_sub) not in batch_ext_table_map:
|
529 |
+
batch_ext_table_map[(token_id, token_sub)] = len(batch_ext_table_ids) + self.vocab_size
|
530 |
+
batch_ext_table_ids.append(token_id)
|
531 |
+
batch_ext_table_sub.append(token_sub)
|
532 |
+
rev_ext_table[batch_ext_table_map[(token_id, token_sub)]] = ext_table[fake_id]
|
533 |
+
else:
|
534 |
+
rev_ext_table[token_id] = ext_table[fake_id]
|
535 |
+
|
536 |
+
segment_rel_pack.append(np.array(outputs.pop("segment_rel")))
|
537 |
+
other_info.append(
|
538 |
+
{
|
539 |
+
"predict_segments": predict_segments,
|
540 |
+
"answer_placeholders": answer_placeholders,
|
541 |
+
"ext_table": rev_ext_table,
|
542 |
+
}
|
543 |
+
)
|
544 |
+
|
545 |
+
for key, value in outputs.items():
|
546 |
+
if key not in batch_outputs:
|
547 |
+
batch_outputs[key] = []
|
548 |
+
batch_outputs[key].append(value)
|
549 |
+
|
550 |
+
max_length = max([len(item) for item in batch_outputs[self.model_input_names[0]]])
|
551 |
+
batch_size = len(batch_outputs[self.model_input_names[0]])
|
552 |
+
for i in range(batch_size):
|
553 |
+
inputs = {k: v[i] for k, v in batch_outputs.items()}
|
554 |
+
|
555 |
+
for k, v in inputs.items():
|
556 |
+
required_input = v
|
557 |
+
|
558 |
+
needs_to_be_padded = len(required_input) != max_length and k != 'image_bound'
|
559 |
+
|
560 |
+
if needs_to_be_padded:
|
561 |
+
difference = max_length - len(required_input)
|
562 |
+
batch_outputs[k][i] = [self.pad_token_id] * difference + required_input
|
563 |
+
|
564 |
+
max_num_rels = 0
|
565 |
+
for rel in segment_rel_pack:
|
566 |
+
max_num_rels = max(max_num_rels, rel.shape[0])
|
567 |
+
padded_rels = np.zeros((len(segment_rel_pack), max_num_rels), dtype=np.int32)
|
568 |
+
for i, rel in enumerate(segment_rel_pack):
|
569 |
+
padded_rels[i, : rel.shape[0]] = rel
|
570 |
+
batch_outputs["segment_rel"] = padded_rels
|
571 |
+
batch_outputs["batch_ext_table_ids"] = np.array(batch_ext_table_ids, dtype=np.int32)
|
572 |
+
batch_outputs["batch_ext_table_sub"] = np.array(batch_ext_table_sub, dtype=np.int32)
|
573 |
+
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
574 |
+
if return_tensors == "pt":
|
575 |
+
batch_outputs = batch_outputs.to(device=device)
|
576 |
+
batch_outputs["other_info"] = other_info
|
577 |
+
|
578 |
+
return batch_outputs
|
579 |
+
|
580 |
+
def convert_data_to_id(
|
581 |
+
self,
|
582 |
+
data: Any,
|
583 |
+
prev_ext_states: Optional[_PrevExtTableStates] = None,
|
584 |
+
shuffle_answer: bool = True,
|
585 |
+
max_depth: int = 8,
|
586 |
+
):
|
587 |
+
"""
|
588 |
+
Parse a dict to data ids. Exclusive for CPMBee. It will
|
589 |
+
1. parse the dict to segments and get segment_rel, which for calculating of position_bias.
|
590 |
+
2. tokenize every segment.
|
591 |
+
"""
|
592 |
+
root: _DictTree = {
|
593 |
+
"value": "<root>",
|
594 |
+
"children": [],
|
595 |
+
"depth": 0,
|
596 |
+
"segment_id": 0,
|
597 |
+
"need_predict": False,
|
598 |
+
"is_image": False
|
599 |
+
}
|
600 |
+
|
601 |
+
segments = [root]
|
602 |
+
|
603 |
+
def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
|
604 |
+
if isinstance(data, dict):
|
605 |
+
ret_list: List[_DictTree] = []
|
606 |
+
curr_items = list(data.items())
|
607 |
+
if need_predict and shuffle_answer:
|
608 |
+
access_idx = np.arange(len(curr_items))
|
609 |
+
np.random.shuffle(access_idx)
|
610 |
+
curr_items = [curr_items[idx] for idx in access_idx]
|
611 |
+
for k, v in curr_items:
|
612 |
+
child_info: _DictTree = {
|
613 |
+
"value": k,
|
614 |
+
"children": [],
|
615 |
+
"depth": depth,
|
616 |
+
"segment_id": len(segments),
|
617 |
+
"need_predict": False, # only leaves are contexts
|
618 |
+
"is_image": False,
|
619 |
+
}
|
620 |
+
segments.append(child_info)
|
621 |
+
child_info["children"] = _build_dict_tree(
|
622 |
+
v, depth + 1,
|
623 |
+
need_predict=need_predict or (depth == 1 and k == "<ans>"),
|
624 |
+
is_image=is_image or (depth == 1 and k == "image")
|
625 |
+
) # elements in <root>.<ans>
|
626 |
+
|
627 |
+
ret_list.append(child_info)
|
628 |
+
return ret_list
|
629 |
+
else:
|
630 |
+
assert isinstance(data, str), "Invalid data {}".format(data)
|
631 |
+
ret: _DictTree = {
|
632 |
+
"value": data,
|
633 |
+
"children": [],
|
634 |
+
"depth": depth,
|
635 |
+
"segment_id": len(segments),
|
636 |
+
"need_predict": need_predict,
|
637 |
+
"is_image": is_image,
|
638 |
+
}
|
639 |
+
segments.append(ret)
|
640 |
+
return [ret]
|
641 |
+
|
642 |
+
root["children"] = _build_dict_tree(data, 1, False, False)
|
643 |
+
|
644 |
+
num_segments = len(segments)
|
645 |
+
segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
|
646 |
+
|
647 |
+
def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
|
648 |
+
ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
|
649 |
+
for child in node["children"]:
|
650 |
+
sub = _build_segment_rel(child)
|
651 |
+
for seg_id_1, depth_1 in sub:
|
652 |
+
for seg_id_2, depth_2 in ret:
|
653 |
+
n_up = min(depth_1 - node["depth"], max_depth - 1)
|
654 |
+
n_down = min(depth_2 - node["depth"], max_depth - 1)
|
655 |
+
segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
|
656 |
+
n_up, n_down, max_depth=max_depth
|
657 |
+
)
|
658 |
+
segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
|
659 |
+
n_down, n_up, max_depth=max_depth
|
660 |
+
)
|
661 |
+
ret.extend(sub)
|
662 |
+
return ret
|
663 |
+
|
664 |
+
_build_segment_rel(root)
|
665 |
+
|
666 |
+
input_ids: List[int] = []
|
667 |
+
input_id_subs: List[int] = []
|
668 |
+
segment_bound: List[Tuple[int, int]] = []
|
669 |
+
image_bound: List[Tuple[int, int]] = []
|
670 |
+
|
671 |
+
|
672 |
+
if prev_ext_states is not None:
|
673 |
+
self.ext_table = prev_ext_states["ext_table"]
|
674 |
+
self.token_id_table = prev_ext_states["token_id_table"]
|
675 |
+
|
676 |
+
for seg in segments:
|
677 |
+
# tokenize
|
678 |
+
tokens = self.convert_tokens_to_ids(self.tokenize(seg["value"], for_cpmbee=True))
|
679 |
+
|
680 |
+
token_id_subs = []
|
681 |
+
reid_token_ids = []
|
682 |
+
for idx in tokens:
|
683 |
+
if idx in self.ext_table:
|
684 |
+
# unk or special token
|
685 |
+
token = self.ext_table[idx]
|
686 |
+
if token.startswith("<") and token.endswith(">"):
|
687 |
+
# special token
|
688 |
+
if "_" in token:
|
689 |
+
token_name = token[1:-1].split("_", maxsplit=1)[0]
|
690 |
+
else:
|
691 |
+
token_name = token[1:-1]
|
692 |
+
token_name = "<{}>".format(token_name)
|
693 |
+
else:
|
694 |
+
token_name = "<unk>"
|
695 |
+
|
696 |
+
if token_name not in self.token_id_table:
|
697 |
+
self.token_id_table[token_name] = {}
|
698 |
+
if idx not in self.token_id_table[token_name]:
|
699 |
+
self.token_id_table[token_name][idx] = len(self.token_id_table[token_name])
|
700 |
+
if token_name not in self.encoder:
|
701 |
+
raise ValueError("Invalid token {}".format(token))
|
702 |
+
reid_token_ids.append(self.encoder[token_name])
|
703 |
+
token_id_subs.append(self.token_id_table[token_name][idx])
|
704 |
+
else:
|
705 |
+
reid_token_ids.append(idx)
|
706 |
+
token_id_subs.append(0)
|
707 |
+
tokens = [self.bos_token_id] + reid_token_ids
|
708 |
+
token_id_subs = [0] + token_id_subs
|
709 |
+
# eos_id 表示 no need_predict
|
710 |
+
if not seg["need_predict"]: # eos
|
711 |
+
tokens = tokens + [self.eos_token_id]
|
712 |
+
token_id_subs = token_id_subs + [0]
|
713 |
+
else:
|
714 |
+
# no eos
|
715 |
+
pass
|
716 |
+
begin = len(input_ids)
|
717 |
+
input_ids.extend(tokens)
|
718 |
+
input_id_subs.extend(token_id_subs)
|
719 |
+
end = len(input_ids)
|
720 |
+
segment_bound.append((begin, end))
|
721 |
+
|
722 |
+
ids = np.array(input_ids, dtype=np.int32)
|
723 |
+
id_subs = np.array(input_id_subs, dtype=np.int32)
|
724 |
+
segs = np.zeros((ids.shape[0],), dtype=np.int32) # 按segment_bound对seg编号
|
725 |
+
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
726 |
+
for i, (begin, end) in enumerate(segment_bound):
|
727 |
+
if not segments[i]["need_predict"]:
|
728 |
+
context[begin:end] = 1
|
729 |
+
if segments[i]["is_image"]:
|
730 |
+
image_bound.append((begin + 1, end - 1))
|
731 |
+
segs[begin:end] = i
|
732 |
+
|
733 |
+
curr_ext_table_states: _PrevExtTableStates = {
|
734 |
+
"ext_table": self.ext_table,
|
735 |
+
"token_id_table": self.token_id_table,
|
736 |
+
}
|
737 |
+
image_bound = np.array(image_bound, dtype=np.int32)
|
738 |
+
return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
|
739 |
+
|
740 |
+
def prepare_for_model(
|
741 |
+
self,
|
742 |
+
ids: List[int],
|
743 |
+
pair_ids: Optional[List[int]] = None,
|
744 |
+
add_special_tokens: bool = True,
|
745 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
746 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
747 |
+
max_length: Optional[int] = None,
|
748 |
+
stride: int = 0,
|
749 |
+
pad_to_multiple_of: Optional[int] = None,
|
750 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
751 |
+
return_token_type_ids: Optional[bool] = None,
|
752 |
+
return_attention_mask: Optional[bool] = None,
|
753 |
+
return_overflowing_tokens: bool = False,
|
754 |
+
return_special_tokens_mask: bool = False,
|
755 |
+
return_length: bool = False,
|
756 |
+
verbose: bool = True,
|
757 |
+
prepend_batch_axis: bool = False,
|
758 |
+
**kwargs,
|
759 |
+
) -> BatchEncoding:
|
760 |
+
"""
|
761 |
+
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
|
762 |
+
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
763 |
+
manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
|
764 |
+
different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
|
765 |
+
overflowing tokens. Such a combination of arguments will raise an error.
|
766 |
+
|
767 |
+
Args:
|
768 |
+
ids (`List[int]`):
|
769 |
+
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
|
770 |
+
`convert_tokens_to_ids` methods.
|
771 |
+
pair_ids (`List[int]`, *optional*):
|
772 |
+
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
|
773 |
+
and `convert_tokens_to_ids` methods.
|
774 |
+
"""
|
775 |
+
|
776 |
+
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
777 |
+
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
778 |
+
padding=padding,
|
779 |
+
truncation=truncation,
|
780 |
+
max_length=max_length,
|
781 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
782 |
+
verbose=verbose,
|
783 |
+
**kwargs,
|
784 |
+
)
|
785 |
+
|
786 |
+
pair = bool(pair_ids is not None)
|
787 |
+
len_ids = len(ids)
|
788 |
+
len_pair_ids = len(pair_ids) if pair else 0
|
789 |
+
|
790 |
+
if return_token_type_ids and not add_special_tokens:
|
791 |
+
raise ValueError(
|
792 |
+
"Asking to return token_type_ids while setting add_special_tokens to False "
|
793 |
+
"results in an undefined behavior. Please set add_special_tokens to True or "
|
794 |
+
"set return_token_type_ids to None."
|
795 |
+
)
|
796 |
+
|
797 |
+
if (
|
798 |
+
return_overflowing_tokens
|
799 |
+
and truncation_strategy == TruncationStrategy.LONGEST_FIRST
|
800 |
+
and pair_ids is not None
|
801 |
+
):
|
802 |
+
raise ValueError(
|
803 |
+
"Not possible to return overflowing tokens for pair of sequences with the "
|
804 |
+
"`longest_first`. Please select another truncation strategy than `longest_first`, "
|
805 |
+
"for instance `only_second` or `only_first`."
|
806 |
+
)
|
807 |
+
|
808 |
+
# Load from model defaults
|
809 |
+
if return_token_type_ids is None:
|
810 |
+
return_token_type_ids = "token_type_ids" in self.model_input_names
|
811 |
+
if return_attention_mask is None:
|
812 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
813 |
+
|
814 |
+
encoded_inputs = {}
|
815 |
+
|
816 |
+
# Compute the total size of the returned encodings
|
817 |
+
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
818 |
+
|
819 |
+
# Truncation: Handle max sequence length
|
820 |
+
overflowing_tokens = []
|
821 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
822 |
+
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
823 |
+
ids,
|
824 |
+
pair_ids=pair_ids,
|
825 |
+
num_tokens_to_remove=total_len - max_length,
|
826 |
+
truncation_strategy=truncation_strategy,
|
827 |
+
stride=stride,
|
828 |
+
)
|
829 |
+
|
830 |
+
if return_overflowing_tokens:
|
831 |
+
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
832 |
+
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
833 |
+
|
834 |
+
# Add special tokens
|
835 |
+
if add_special_tokens:
|
836 |
+
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
837 |
+
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
838 |
+
else:
|
839 |
+
sequence = ids + pair_ids if pair else ids
|
840 |
+
token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
|
841 |
+
|
842 |
+
# Build output dictionary
|
843 |
+
encoded_inputs["input_ids"] = sequence
|
844 |
+
if return_token_type_ids:
|
845 |
+
encoded_inputs["token_type_ids"] = token_type_ids
|
846 |
+
if return_special_tokens_mask:
|
847 |
+
if add_special_tokens:
|
848 |
+
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
849 |
+
else:
|
850 |
+
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
|
851 |
+
|
852 |
+
# Check lengths
|
853 |
+
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
854 |
+
|
855 |
+
# Padding
|
856 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
857 |
+
encoded_inputs = self.pad(
|
858 |
+
encoded_inputs,
|
859 |
+
max_length=max_length,
|
860 |
+
padding=padding_strategy.value,
|
861 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
862 |
+
return_attention_mask=return_attention_mask,
|
863 |
+
)
|
864 |
+
|
865 |
+
if return_length:
|
866 |
+
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
|
867 |
+
|
868 |
+
# for CPMBee, encode all the model arguments
|
869 |
+
for arg in self.ext_args_for_model:
|
870 |
+
v = kwargs.get(arg, None)
|
871 |
+
if v is not None:
|
872 |
+
encoded_inputs[arg] = v
|
873 |
+
|
874 |
+
batch_outputs = BatchEncoding(
|
875 |
+
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
|
876 |
+
)
|
877 |
+
|
878 |
+
return batch_outputs
|
879 |
+
|
880 |
+
def prepare_for_finetune(
|
881 |
+
self,
|
882 |
+
data_list: List[Dict],
|
883 |
+
max_length: int = 2048
|
884 |
+
):
|
885 |
+
_inputs: List[NDArray[np.int32]] = []
|
886 |
+
_inputs_sub: List[NDArray[np.int32]] = []
|
887 |
+
_context: List[NDArray[np.int8]] = []
|
888 |
+
_sample_ids: List[NDArray[np.int32]] = []
|
889 |
+
_segments: List[NDArray[np.int32]] = []
|
890 |
+
_num_segments: List[NDArray[np.int32]] = []
|
891 |
+
_segment_rel_offset: List[NDArray[np.int32]] = []
|
892 |
+
_segment_rel: List[NDArray[np.int32]] = []
|
893 |
+
_spans: List[List[int]] = []
|
894 |
+
_raw_data: List[List[Any]] = []
|
895 |
+
|
896 |
+
raw_data = {}
|
897 |
+
for data in data_list:
|
898 |
+
(
|
899 |
+
input_ids,
|
900 |
+
input_id_subs,
|
901 |
+
context,
|
902 |
+
segment_ids,
|
903 |
+
segment_rel,
|
904 |
+
n_segments,
|
905 |
+
_
|
906 |
+
) = self.convert_data_to_id(data)
|
907 |
+
|
908 |
+
input_ids = input_ids[: max_length]
|
909 |
+
context = context[: max_length]
|
910 |
+
segment_ids = segment_ids[: max_length]
|
911 |
+
raw_data["input"] = data
|
912 |
+
raw_data["samples"] = []
|
913 |
+
|
914 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
915 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
|
916 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
|
917 |
+
|
918 |
+
_inputs.append(input_ids)
|
919 |
+
_inputs_sub.append(input_id_subs)
|
920 |
+
_context.append(context)
|
921 |
+
_sample_ids.append(sample_ids)
|
922 |
+
_segments.append(segment_ids)
|
923 |
+
_num_segments.append(num_segments)
|
924 |
+
_segment_rel_offset.append(segment_rel_offset)
|
925 |
+
_segment_rel.append(segment_rel)
|
926 |
+
_spans.append([input_ids.shape[0]])
|
927 |
+
_raw_data.append([raw_data])
|
928 |
+
|
929 |
+
batch_size = len(_inputs)
|
930 |
+
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
931 |
+
inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
|
932 |
+
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
933 |
+
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
934 |
+
segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
935 |
+
num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
936 |
+
segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
|
937 |
+
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
938 |
+
|
939 |
+
max_rel = 0
|
940 |
+
for i in range(batch_size):
|
941 |
+
max_rel = max(max_rel, _segment_rel[i].shape[0])
|
942 |
+
segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
|
943 |
+
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
944 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
945 |
+
|
946 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
947 |
+
batch_ext_table_ids: List[int] = []
|
948 |
+
batch_ext_table_sub: List[int] = []
|
949 |
+
raw_data_list: List[Any] = []
|
950 |
+
|
951 |
+
for i in range(batch_size):
|
952 |
+
instance_length = _inputs[i].shape[0]
|
953 |
+
rel_size = _segment_rel[i].shape[0]
|
954 |
+
inputs[i, :instance_length] = _inputs[i]
|
955 |
+
inputs_sub[i, :instance_length] = _inputs_sub[i]
|
956 |
+
context[i, :instance_length] = _context[i]
|
957 |
+
sample_ids[i, :instance_length] = _sample_ids[i]
|
958 |
+
segments[i, :instance_length] = _segments[i]
|
959 |
+
num_segments[i, :instance_length] = _num_segments[i]
|
960 |
+
segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
|
961 |
+
segment_rel[i, :rel_size] = _segment_rel[i]
|
962 |
+
|
963 |
+
span_begin = 0
|
964 |
+
for span_id, span_end in enumerate(_spans[i]):
|
965 |
+
spans[i, span_begin:span_end] = span_id
|
966 |
+
span_begin = span_end
|
967 |
+
length[i] = instance_length
|
968 |
+
raw_data_list.extend(_raw_data[i])
|
969 |
+
|
970 |
+
for j in range(instance_length):
|
971 |
+
idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
|
972 |
+
tgt_idx = idx
|
973 |
+
if idx_sub > 0:
|
974 |
+
# need to be in ext table
|
975 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
976 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
977 |
+
batch_ext_table_ids.append(idx)
|
978 |
+
batch_ext_table_sub.append(idx_sub)
|
979 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
|
980 |
+
if j > 1 and context[i, j - 1] == 0:
|
981 |
+
if idx != self.bos_token_id:
|
982 |
+
tgt[i, j - 1] = tgt_idx
|
983 |
+
else:
|
984 |
+
tgt[i, j - 1] = self.eos_token_id
|
985 |
+
if context[i, instance_length - 1] == 0:
|
986 |
+
tgt[i, instance_length - 1] = self.eos_token_id
|
987 |
+
|
988 |
+
if len(batch_ext_table_map) == 0:
|
989 |
+
# placeholder
|
990 |
+
batch_ext_table_ids.append(0)
|
991 |
+
batch_ext_table_sub.append(1)
|
992 |
+
|
993 |
+
return BatchEncoding({
|
994 |
+
"input_ids": inputs,
|
995 |
+
"input_id_sub": inputs_sub,
|
996 |
+
"length": length,
|
997 |
+
"context": context > 0,
|
998 |
+
"sample_ids": sample_ids,
|
999 |
+
"num_segments": num_segments,
|
1000 |
+
"segment": segments,
|
1001 |
+
"segment_rel_offset": segment_rel_offset,
|
1002 |
+
"segment_rel": segment_rel,
|
1003 |
+
"span": spans,
|
1004 |
+
"labels": tgt,
|
1005 |
+
"ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
|
1006 |
+
"ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
|
1007 |
+
}, tensor_type="pt")
|
1008 |
+
|
tokenizer_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name_or_path": "openbmb/cpm-bee-10b",
|
3 |
+
"tokenizer_class": "CpmBeeTokenizer",
|
4 |
+
"auto_map": {
|
5 |
+
"AutoTokenizer": [
|
6 |
+
"tokenization_viscpmbee.VisCpmBeeTokenizer",
|
7 |
+
null
|
8 |
+
]
|
9 |
+
}
|
10 |
+
}
|
unet/config.json
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.10.0.dev0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"attention_head_dim": [
|
6 |
+
5,
|
7 |
+
10,
|
8 |
+
20,
|
9 |
+
20
|
10 |
+
],
|
11 |
+
"block_out_channels": [
|
12 |
+
320,
|
13 |
+
640,
|
14 |
+
1280,
|
15 |
+
1280
|
16 |
+
],
|
17 |
+
"center_input_sample": false,
|
18 |
+
"cross_attention_dim": 1024,
|
19 |
+
"down_block_types": [
|
20 |
+
"CrossAttnDownBlock2D",
|
21 |
+
"CrossAttnDownBlock2D",
|
22 |
+
"CrossAttnDownBlock2D",
|
23 |
+
"DownBlock2D"
|
24 |
+
],
|
25 |
+
"downsample_padding": 1,
|
26 |
+
"dual_cross_attention": false,
|
27 |
+
"flip_sin_to_cos": true,
|
28 |
+
"freq_shift": 0,
|
29 |
+
"in_channels": 4,
|
30 |
+
"layers_per_block": 2,
|
31 |
+
"mid_block_scale_factor": 1,
|
32 |
+
"norm_eps": 1e-05,
|
33 |
+
"norm_num_groups": 32,
|
34 |
+
"num_class_embeds": null,
|
35 |
+
"only_cross_attention": false,
|
36 |
+
"out_channels": 4,
|
37 |
+
"sample_size": 64,
|
38 |
+
"up_block_types": [
|
39 |
+
"UpBlock2D",
|
40 |
+
"CrossAttnUpBlock2D",
|
41 |
+
"CrossAttnUpBlock2D",
|
42 |
+
"CrossAttnUpBlock2D"
|
43 |
+
],
|
44 |
+
"use_linear_projection": true
|
45 |
+
}
|
vae/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "AutoencoderKL",
|
3 |
+
"_diffusers_version": "0.8.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"block_out_channels": [
|
6 |
+
128,
|
7 |
+
256,
|
8 |
+
512,
|
9 |
+
512
|
10 |
+
],
|
11 |
+
"down_block_types": [
|
12 |
+
"DownEncoderBlock2D",
|
13 |
+
"DownEncoderBlock2D",
|
14 |
+
"DownEncoderBlock2D",
|
15 |
+
"DownEncoderBlock2D"
|
16 |
+
],
|
17 |
+
"in_channels": 3,
|
18 |
+
"latent_channels": 4,
|
19 |
+
"layers_per_block": 2,
|
20 |
+
"norm_num_groups": 32,
|
21 |
+
"out_channels": 3,
|
22 |
+
"sample_size": 768,
|
23 |
+
"up_block_types": [
|
24 |
+
"UpDecoderBlock2D",
|
25 |
+
"UpDecoderBlock2D",
|
26 |
+
"UpDecoderBlock2D",
|
27 |
+
"UpDecoderBlock2D"
|
28 |
+
]
|
29 |
+
}
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|