pyx9913 committed on
Commit
4d32fc1
1 Parent(s): 93beeec

feat: 🎸 add paint model code

README.md CHANGED
@@ -27,8 +27,32 @@ language:
 
Similar to `VisCPM-Chat`, we found that, thanks to the bilingual capability of `CPM-Bee`, `VisCPM-Paint` can achieve good Chinese text-to-image generation by training only on English text-image pairs, surpassing the performance of Chinese open-source models. By incorporating an additional 20M cleaned native Chinese text-image pairs and 120M text-image pairs translated into Chinese, the model's Chinese text-to-image generation ability can be further improved. We sampled 30,000 images from the standard image generation test set MSCOCO and calculated the commonly used evaluation metric FID (Fréchet Inception Distance) to assess the quality of the generated images. As with `VisCPM-Chat`, we provide two versions of the model, namely `VisCPM-Paint-balance` and `VisCPM-Paint-zhplus`. The former has a balanced ability in both English and Chinese, while the latter emphasizes Chinese proficiency. `VisCPM-Paint-balance` is trained using only English text-image pairs, while `VisCPM-Paint-zhplus` additionally incorporates the 20M native Chinese text-image pairs and 120M translated Chinese text-image pairs on top of `VisCPM-Paint-balance`.
 
+
+ ## How to Use
+
+ ```python
+ #!/usr/bin/env python
+ # encoding: utf-8
+ from diffusers import DiffusionPipeline
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
+
+
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/VisCPM-Paint', trust_remote_code=True)
+ text_encoder = AutoModel.from_pretrained('openbmb/VisCPM-Paint', trust_remote_code=True)
+ print('load pipeline')
+ pipeline = DiffusionPipeline.from_pretrained('openbmb/VisCPM-Paint', custom_pipeline="pipeline_stable_diffusion.py", text_encoder=text_encoder, tokenizer=tokenizer)
+
+ pipeline = pipeline.to('cuda')
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ image = pipeline(prompt).images[0]
+
+ image.save("astronaut_rides_horse.png")
+ ```
+
 ## 📝 License
 
 VisCPM is governed by the [GML License](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E9%9D%9E%E5%95%86%E4%B8%9A%E5%8C%96.md), and permits individual and research usages. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to negotiate commercial licensing.
 
- The CPM-Bee base, governed by the [General Model License (GML)](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E5%95%86%E4%B8%9A%E6%8E%88%E6%9D%83.md), permits commercial usage. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to obtain the certificate of authorization.
+ The CPM-Bee base, governed by the [General Model License (GML)](https://github.com/OpenBMB/General-Model-License/blob/main/%E9%80%9A%E7%94%A8%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE-%E6%9D%A5%E6%BA%90%E8%AF%B4%E6%98%8E-%E5%AE%A3%E4%BC%A0%E9%99%90%E5%88%B6-%E5%95%86%E4%B8%9A%E6%8E%88%E6%9D%83.md), permits commercial usage. If you intend to utilize the model for commercial purposes, please reach out to [email protected] to obtain the certificate of authorization.
config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "_name_or_path": "openbmb/cpm-bee-10b",
4
+ "architectures": [
5
+ "CpmBeeForWithTransform"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_cpmbee.CpmBeeConfig",
9
+ "AutoModel": "modeling_cpmbee.CpmBeeWithTransform",
10
+ "AutoTokenizer": "tokenization_viscpmbee.VisCpmBeeTokenizer"
11
+ },
12
+ "vocab_size": 86583,
13
+ "hidden_size": 4096,
14
+ "dim_ff" : 10240,
15
+ "num_hidden_layers" : 48,
16
+ "num_attention_heads": 32,
17
+ "dim_head" : 128,
18
+ "dropout_p" : 0.0,
19
+ "position_bias_num_buckets" : 256,
20
+ "position_bias_num_segment_buckets": 256,
21
+ "position_bias_max_distance" : 2048,
22
+ "eps" : 1e-6,
23
+ "half" : false,
24
+ "model_type": "viscpmbee",
25
+ "unet_cross_attention_dim": 1024
26
+ }
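
The `auto_map` above routes the `transformers` Auto classes to the custom modules in this repository (`configuration_cpmbee.CpmBeeConfig`, `modeling_cpmbee.CpmBeeWithTransform`, `tokenization_viscpmbee.VisCpmBeeTokenizer`), which is why `trust_remote_code=True` is needed in the README example. A minimal sketch of inspecting the resulting config, assuming the checkpoint id `openbmb/VisCPM-Paint` from the README example:

```python
from transformers import AutoConfig

# trust_remote_code=True lets transformers import CpmBeeConfig from configuration_cpmbee.py in this repo.
config = AutoConfig.from_pretrained("openbmb/VisCPM-Paint", trust_remote_code=True)
print(config.hidden_size, config.num_attention_heads, config.vocab_size)  # 4096 32 86583, per config.json above
```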
configuration_cpmbee.py ADDED
@@ -0,0 +1,132 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ CpmBee model configuration"""
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "openbmb/cpm-bee-10b": "https://huggingface.co/openbmb/cpm-bee-10b/resolve/main/config.json",
27
+ "openbmb/cpm-bee-5b": "https://huggingface.co/openbmb/cpm-bee-5b/resolve/main/config.json",
28
+ "openbmb/cpm-bee-2b": "https://huggingface.co/openbmb/cpm-bee-2b/resolve/main/config.json",
29
+ "openbmb/cpm-bee-1b": "https://huggingface.co/openbmb/cpm-bee-1b/resolve/main/config.json",
30
+ # See all CpmBee models at https://huggingface.co/models?filter=cpmbee
31
+ }
32
+
33
+
34
+ class CpmBeeConfig(PretrainedConfig):
35
+ r"""
36
+ This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instantiate a
37
+ CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
38
+ with the defaults will yield a similar configuration to that of the CPMBee
39
+ [openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.
40
+
41
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
42
+ documentation from [`PretrainedConfig`] for more information.
43
+
44
+ Args:
45
+ vocab_size (`int`, *optional*, defaults to 30720):
46
+ Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
47
+ `input` passed when calling [`CpmBeeModel`].
48
+ hidden_size (`int`, *optional*, defaults to 4096):
49
+ Dimension of the encoder layers.
50
+ num_attention_heads (`int`, *optional*, defaults to 64):
51
+ Number of attention heads in the Transformer encoder.
52
+ dim_head (`int`, *optional*, defaults to 64):
53
+ Dimension of attention heads for each attention layer in the Transformer encoder.
54
+ dim_ff (`int`, *optional*, defaults to 10240):
55
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
56
+ num_hidden_layers (`int`, *optional*, defaults to 32):
57
+ Number of layers of the Transformer encoder.
58
+ dropout_p (`float`, *optional*, defaults to 0.0):
59
+ The dropout probability for all fully connected layers in the embeddings and encoder.
60
+ position_bias_num_buckets (`int`, *optional*, defaults to 256):
61
+ The number of position_bias buckets.
62
+ position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
63
+ The number of segment buckets.
64
+ position_bias_max_distance (`int`, *optional*, defaults to 2048):
65
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
66
+ just in case (e.g., 512 or 1024 or 2048).
67
+ eps (`float`, *optional*, defaults to 1e-6):
68
+ The epsilon used by the layer normalization layers.
69
+ init_std (`float`, *optional*, defaults to 1.0):
70
+ Initialize parameters with std = init_std.
71
+ use_cache (`bool`, *optional*, defaults to `True`):
72
+ Whether to use cache.
73
+ distance_scale (`float` or `int`, *optional*, defaults to 16):
74
+ Scale the rotary embedding.
75
+ mask_modules (`list` or `tuple`, *optional*, defaults to None):
76
+ Decides which feedforward block or attention block is pruned.
77
+ half (`bool`, *optional*, defaults to `False`):
78
+ Decides whether the model parameters are half-precision or not.
79
+
80
+ Example:
81
+
82
+ ```python
83
+ >>> from transformers import CpmBeeModel, CpmBeeConfig
84
+
85
+ >>> # Initializing a CPMBee cpm-bee-10b style configuration
86
+ >>> configuration = CpmBeeConfig()
87
+
88
+ >>> # Initializing a model from the cpm-bee-10b style configuration
89
+ >>> model = CpmBeeModel(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+ model_type = "cpmbee"
95
+
96
+ def __init__(
97
+ self,
98
+ vocab_size: int = 30720,
99
+ hidden_size: int = 4096,
100
+ num_attention_heads: int = 64,
101
+ dim_head: int = 64,
102
+ dim_ff: int = 10240,
103
+ num_hidden_layers: int = 32,
104
+ dropout_p: float = 0.0,
105
+ position_bias_num_buckets: int = 256,
106
+ position_bias_num_segment_buckets: int = 32,
107
+ position_bias_max_distance: int = 2048,
108
+ eps: float = 1e-6,
109
+ init_std: float = 1.0,
110
+ use_cache: bool = True,
111
+ distance_scale: Union[int, float] = 16,
112
+ mask_modules: Optional[Union[List, Tuple]] = None,
113
+ half: bool = False,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(**kwargs)
117
+ self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
118
+ self.hidden_size = hidden_size
119
+ self.num_attention_heads = num_attention_heads
120
+ self.dim_head = dim_head
121
+ self.dim_ff = dim_ff
122
+ self.num_hidden_layers = num_hidden_layers
123
+ self.position_bias_num_buckets = position_bias_num_buckets
124
+ self.position_bias_max_distance = position_bias_max_distance
125
+ self.dropout_p = dropout_p
126
+ self.eps = eps
127
+ self.use_cache = use_cache
128
+ self.vocab_size = vocab_size
129
+ self.init_std = init_std
130
+ self.distance_scale = distance_scale
131
+ self.half = half
132
+ self.mask_modules = mask_modules
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "crop_size": 224,
3
+ "do_center_crop": true,
4
+ "do_convert_rgb": true,
5
+ "do_normalize": true,
6
+ "do_resize": true,
7
+ "feature_extractor_type": "CLIPFeatureExtractor",
8
+ "image_mean": [
9
+ 0.48145466,
10
+ 0.4578275,
11
+ 0.40821073
12
+ ],
13
+ "image_std": [
14
+ 0.26862954,
15
+ 0.26130258,
16
+ 0.27577711
17
+ ],
18
+ "resample": 3,
19
+ "size": 224
20
+ }
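
These are the standard CLIP image preprocessing statistics (224x224 center crop with CLIP's RGB mean/std). A minimal sketch of loading them on their own, assuming the checkpoint id `openbmb/VisCPM-Paint` from the README example:

```python
from transformers import CLIPImageProcessor

# Reads feature_extractor/preprocessor_config.json from the repository.
image_processor = CLIPImageProcessor.from_pretrained("openbmb/VisCPM-Paint", subfolder="feature_extractor")
print(image_processor.image_mean, image_processor.image_std)
```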
model_index.json CHANGED
@@ -1,13 +1,14 @@
 {
- "_class_name": "StableDiffusionPipeline",
+ "_class_name": "VisCPMPaintBeePipeline",
 "_diffusers_version": "0.3.0",
 "feature_extractor": [
 "transformers",
 "CLIPImageProcessor"
 ],
+ "requires_safety_checker": false,
 "safety_checker": [
- "stable_diffusion",
- "StableDiffusionSafetyChecker"
+ null,
+ null
 ],
 "scheduler": [
 "diffusers",
@@ -15,7 +16,11 @@
 ],
 "text_encoder": [
 "transformers",
- "openbmb/cpm-bee-10b"
+ "PreTrainedModel"
+ ],
+ "tokenizer": [
+ "transformers",
+ "PreTrainedTokenizer"
 ],
 "unet": [
 "diffusers",
@@ -24,9 +29,5 @@
 "vae": [
 "diffusers",
 "AutoencoderKL"
- ],
- "text_safety_checker": [
- "transformers",
- "BertForSequenceClassification"
 ]
- }
+ }
modeling_cpmbee.py ADDED
@@ -0,0 +1,943 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch CpmBee model."""
16
+ import copy
17
+ import math
18
+ from collections import UserDict
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from transformers.generation.beam_search import BeamHypotheses, BeamSearchScorer
25
+ from transformers.generation.streamers import BaseStreamer
26
+ from transformers.generation.utils import (
27
+ GenerationConfig,
28
+ LogitsProcessorList,
29
+ StoppingCriteriaList,
30
+ dist,
31
+ inspect,
32
+ is_deepspeed_zero3_enabled,
33
+ warnings,
34
+ )
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
38
+ from .configuration_cpmbee import CpmBeeConfig
39
+ from .tokenization_viscpmbee import VisCpmBeeTokenizer
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ _CHECKPOINT_FOR_DOC = "openbmb/cpm-bee-10b"
45
+ _CONFIG_FOR_DOC = "CpmBeeConfig"
46
+
47
+ CPMBEE_PRETRAINED_MODEL_ARCHIVE_LIST = [
48
+ "openbmb/cpm-bee-10b",
49
+ "openbmb/cpm-bee-5b",
50
+ "openbmb/cpm-bee-2b",
51
+ "openbmb/cpm-bee-1b",
52
+ # See all CPMBee models at https://huggingface.co/models?filter=cpmbee
53
+ ]
54
+
55
+
56
+ class CpmBeeLinear(nn.Linear):
57
+ def __init__(self, dim_in, dim_out, dtype):
58
+ """
59
+ Construct a linear for CPMBee. It contains a scale operation.
60
+ """
61
+ super().__init__(dim_in, dim_out, bias=False)
62
+ self.dim_in = self.in_features = dim_in
63
+ self.dim_out = self.out_features = dim_out
64
+
65
+ self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
66
+
67
+ def forward(self, x: torch.Tensor):
68
+ """
69
+ Args:
70
+ x (`torch.Tensor` of shape `(batch, seq_len, dim_in)`): The input of linear layer
71
+ Returns:
72
+ `torch.Tensor` of shape `(batch, seq_len, dim_out)`: The output of the linear transform y.
73
+ """
74
+ x = nn.functional.linear(x, self.weight)
75
+ x = x / math.sqrt(self.dim_in)
76
+ return x
77
+
78
+
79
+ class CpmBeeLayerNorm(nn.Module):
80
+ """
81
+ We use Root Mean Square (RMS) Layer Normalization; please see https://arxiv.org/abs/1910.07467 for details.
82
+ """
83
+
84
+ def __init__(self, config: CpmBeeConfig):
85
+ super().__init__()
86
+
87
+ self.eps = config.eps
88
+ self.dim_norm = config.hidden_size
89
+ self.weight = nn.Parameter(torch.empty(config.hidden_size, dtype=config.torch_dtype))
90
+
91
+ def forward(self, hidden_states: torch.Tensor):
92
+ """
93
+ Args:
94
+ hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
95
+ """
96
+ if hidden_states.size(-1) != self.dim_norm:
97
+ raise AssertionError("hidden_states.size(-1) != self.dim_norm")
98
+ old_dtype = hidden_states.dtype
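+ # RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight; no mean subtraction and no bias term.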
99
+ variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
100
+ hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
101
+ return hidden_states
102
+
103
+
104
+ class CpmBeeAttention(nn.Module):
105
+ def __init__(self, config: CpmBeeConfig):
106
+ super().__init__()
107
+ self.dim_model = config.hidden_size
108
+ self.num_heads = config.num_attention_heads
109
+ self.dim_head = config.dim_head
110
+
111
+ self.project_q = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
112
+ self.project_k = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
113
+ self.project_v = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
114
+
115
+ self.attention_out = CpmBeeLinear(self.num_heads * self.dim_head, self.dim_model, dtype=config.torch_dtype)
116
+
117
+ self.softmax = torch.nn.Softmax(dim=-1)
118
+
119
+ if config.dropout_p is not None:
120
+ self.dropout = torch.nn.Dropout(p=config.dropout_p)
121
+ else:
122
+ self.dropout = None
123
+
124
+ def forward(
125
+ self,
126
+ hidden_q: torch.Tensor,
127
+ hidden_kv: torch.Tensor,
128
+ attention_mask: torch.BoolTensor,
129
+ position_bias: torch.Tensor,
130
+ output_attentions: Optional[bool] = False,
131
+ past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
132
+ use_cache: Optional[bool] = None,
133
+ ):
134
+ """
135
+ Args:
136
+ hidden_q (`torch.Tensor`):
137
+ Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
138
+ hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
139
+ Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
140
+ attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
141
+ Avoid invalid areas to participate in the calculation of self-attention.
142
+ position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
143
+ Provide positional information to self-attention block.
144
+ output_attentions (`bool`, *optional*):
145
+ Whether or not to return the attentions tensors of all attention layers.
146
+ past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
147
+ Cached past key and value projection states.
148
+ use_cache (`bool`, *optional*):
149
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
150
+ (see `past_key_values`).
151
+ """
152
+ batch_size = hidden_q.size(0)
153
+ len_q = hidden_q.size(1)
154
+ len_k = hidden_kv.size(1)
155
+
156
+ query = self.project_q(hidden_q)
157
+ key = self.project_k(hidden_kv)
158
+ value = self.project_v(hidden_kv)
159
+
160
+ query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
161
+ key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
162
+ value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
163
+
164
+ if past_key_values is not None:
165
+ key = torch.cat([past_key_values[0], key], dim=-2)
166
+ value = torch.cat([past_key_values[1], value], dim=-2)
167
+ len_k = key.size(-2)
168
+
169
+ # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
170
+ score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
171
+ score = score + position_bias
172
+
173
+ score = torch.masked_fill(
174
+ score,
175
+ attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
176
+ torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
177
+ )
178
+ score = self.softmax(score)
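+ # Re-apply the mask after softmax so that fully-masked query rows become all zeros instead of NaN.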
179
+
180
+ score = torch.masked_fill(
181
+ score,
182
+ attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
183
+ torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
184
+ )
185
+ if output_attentions:
186
+ attn_weights = score
187
+ else:
188
+ attn_weights = None
189
+
190
+ if self.dropout is not None:
191
+ score = self.dropout(score)
192
+
193
+ # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
194
+ score = torch.matmul(score, value)
195
+
196
+ score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
197
+ score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
198
+
199
+ score = self.attention_out(score)
200
+
201
+ past_key_values = None
202
+ if use_cache:
203
+ past_key_values = (key, value)
204
+
205
+ return score, attn_weights, past_key_values
206
+
207
+
208
+ class CpmBeeSelfAttentionBlock(nn.Module):
209
+ def __init__(self, config: CpmBeeConfig):
210
+ super().__init__()
211
+ self.layernorm_before_attention = CpmBeeLayerNorm(config)
212
+ self.self_attention = CpmBeeAttention(config)
213
+ if config.dropout_p:
214
+ self.dropout = torch.nn.Dropout(config.dropout_p)
215
+ else:
216
+ self.dropout = None
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states: torch.Tensor,
221
+ attention_mask: torch.Tensor,
222
+ position_bias: Optional[torch.Tensor] = None,
223
+ output_attentions: Optional[bool] = False,
224
+ past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
225
+ use_cache: Optional[bool] = None,
226
+ ):
227
+ """
228
+ Args:
229
+ hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
230
+ Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
231
+ attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
232
+ Avoid invalid areas to participate in the calculation of self-attention.
233
+ position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
234
+ Provide positional information to self-attention block.
235
+ output_attentions (`bool`, *optional*):
236
+ Whether or not to return the attentions tensors of all attention layers.
237
+ past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
238
+ Cached past key and value projection states.
239
+ use_cache (`bool`, *optional*):
240
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
241
+ (see `past_key_values`).
242
+ """
243
+ outputs = self.layernorm_before_attention(hidden_states)
244
+ outputs = self.self_attention(
245
+ outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache
246
+ )
247
+
248
+ outputs, attn_weights, current_key_value = outputs
249
+
250
+ if self.dropout is not None:
251
+ outputs = self.dropout(outputs)
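+ # Scaled residual connection; CPM-Bee divides the residual sum by 1.05 in both the attention and FFN blocks.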
252
+ hidden_states = (hidden_states + outputs) / 1.05
253
+
254
+ return hidden_states, attn_weights, current_key_value
255
+
256
+
257
+ class CpmBeeDenseGatedACT(nn.Module):
258
+ def __init__(self, config: CpmBeeConfig):
259
+ super().__init__()
260
+ self.w_0 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
261
+ self.w_1 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
262
+ self.act = torch.nn.GELU()
263
+
264
+ def forward(self, hidden_states: torch.Tensor):
265
+ """Transform an input tensor from one feature space to another via a nonlinear operation
266
+
267
+ Args:
268
+ hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
269
+ """
270
+ gate_score = self.act(self.w_0(hidden_states))
271
+ hidden_states = self.w_1(hidden_states)
272
+
273
+ hidden_states = gate_score * hidden_states
274
+ return hidden_states
275
+
276
+
277
+ class CpmBeeFeedForward(nn.Module):
278
+ def __init__(self, config: CpmBeeConfig):
279
+ super().__init__()
280
+ self.w_in = CpmBeeDenseGatedACT(config)
281
+ if config.dropout_p is not None:
282
+ self.dropout = torch.nn.Dropout(config.dropout_p)
283
+ else:
284
+ self.dropout = None
285
+
286
+ self.w_out = CpmBeeLinear(config.dim_ff, config.hidden_size, dtype=config.torch_dtype)
287
+
288
+ def forward(self, hidden_states: torch.Tensor):
289
+ """
290
+ Args:
291
+ hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
292
+ """
293
+ hidden_states = self.w_in(hidden_states)
294
+
295
+ if self.dropout is not None:
296
+ hidden_states = self.dropout(hidden_states)
297
+
298
+ hidden_states = self.w_out(hidden_states)
299
+
300
+ return hidden_states
301
+
302
+
303
+ class CpmBeeFFNBlock(nn.Module):
304
+ def __init__(self, config: CpmBeeConfig):
305
+ super().__init__()
306
+ self.layernorm_before_ffn = CpmBeeLayerNorm(config)
307
+ self.ffn = CpmBeeFeedForward(config)
308
+ if config.dropout_p:
309
+ self.dropout = torch.nn.Dropout(config.dropout_p)
310
+ else:
311
+ self.dropout = None
312
+
313
+ def forward(
314
+ self,
315
+ hidden_states: torch.Tensor,
316
+ ):
317
+ """
318
+ Args:
319
+ hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
320
+ Hidden states before feed forward layer.
321
+ """
322
+ ln_outputs = self.layernorm_before_ffn(hidden_states)
323
+ outputs = self.ffn(ln_outputs)
324
+ if self.dropout is not None:
325
+ outputs = self.dropout(outputs)
326
+ hidden_states = (hidden_states + outputs) / 1.05
327
+ return hidden_states
328
+
329
+
330
+ class CpmBeeTransformerBlock(nn.Module):
331
+ def __init__(self, config: CpmBeeConfig, mask_att: bool = False, mask_ffn: bool = False):
332
+ super().__init__()
333
+ self.mask_att = mask_att
334
+ self.mask_ffn = mask_ffn
335
+
336
+ if not self.mask_att:
337
+ self.self_att = CpmBeeSelfAttentionBlock(config)
338
+ if not self.mask_ffn:
339
+ self.ffn = CpmBeeFFNBlock(config)
340
+
341
+ def forward(
342
+ self,
343
+ hidden_states: torch.Tensor,
344
+ attention_mask: torch.Tensor,
345
+ position_bias: Optional[torch.Tensor] = None,
346
+ output_attentions: Optional[bool] = False,
347
+ past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
348
+ use_cache: Optional[bool] = None,
349
+ ):
350
+ """
351
+ Args:
352
+ hidden_states (`torch.Tensor`):
353
+ Input to the layer of shape `(batch, seq_len, dim_model)`
354
+ attention_mask (`torch.Tensor`):
355
+ Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
356
+ position_bias (`torch.Tensor`):
357
+ Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
358
+ output_attentions (`bool`, *optional*):
359
+ Whether or not to return the attentions tensors of all attention layers.
360
+ past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
361
+ Cached past key and value projection states
362
+ use_cache (`bool`, *optional*):
363
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
364
+ (see `past_key_values`).
365
+ """
366
+ if not self.mask_att:
367
+ hidden_states = self.self_att(
368
+ hidden_states,
369
+ attention_mask=attention_mask,
370
+ position_bias=position_bias,
371
+ output_attentions=output_attentions,
372
+ past_key_values=past_key_values,
373
+ use_cache=use_cache,
374
+ )
375
+
376
+ hidden_states, attn_weights, current_key_value = hidden_states
377
+ else:
378
+ attn_weights, current_key_value = None, (None, None)
379
+
380
+ if not self.mask_ffn:
381
+ hidden_states = self.ffn(hidden_states)
382
+
383
+ return hidden_states, attn_weights, current_key_value
384
+
385
+
386
+ class CpmBeeEncoder(nn.Module):
387
+ def __init__(self, config: CpmBeeConfig):
388
+ super().__init__()
389
+ self.num_layers = config.num_hidden_layers
390
+ if config.mask_modules is not None:
391
+ assert len(config.mask_modules) == self.num_layers, "The total number of masks should equal to num_layers"
392
+ for mask_module in config.mask_modules:
393
+ assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
394
+ else:
395
+ config.mask_modules = [(False, False)] * self.num_layers
396
+
397
+ self.layers = nn.ModuleList(
398
+ [
399
+ CpmBeeTransformerBlock(
400
+ config, mask_att=config.mask_modules[ith][0], mask_ffn=config.mask_modules[ith][1]
401
+ )
402
+ for ith in range(self.num_layers)
403
+ ]
404
+ )
405
+
406
+ self.output_layernorm = CpmBeeLayerNorm(config)
407
+
408
+ def forward(
409
+ self,
410
+ hidden_states: torch.Tensor,
411
+ attention_mask: torch.Tensor,
412
+ position_bias: torch.Tensor,
413
+ output_attentions: Optional[bool] = None,
414
+ output_hidden_states: Optional[bool] = None,
415
+ past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
416
+ use_cache: Optional[bool] = None,
417
+ ):
418
+ """
419
+ Args:
420
+ hidden_states (`torch.Tensor`):
421
+ Input to the layer of shape `(batch, seq_len, dim_model)`
422
+ attention_mask (`torch.Tensor`):
423
+ Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
424
+ position_bias (`torch.Tensor`):
425
+ Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
426
+ output_attentions (`bool`, *optional*):
427
+ Whether or not to return the attentions tensors of all attention layers.
428
+ output_hidden_states (`bool`, *optional*):
429
+ Whether or not to return the hidden states of all layers.
430
+ past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
431
+ Cached past key and value projection states
432
+ use_cache (`bool`, *optional*):
433
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
434
+ (see `past_key_values`).
435
+ """
436
+ all_hidden_states = () if output_hidden_states else None
437
+ all_self_attns = () if output_attentions else None
438
+ current_key_values = () if use_cache else None
439
+
440
+ for i, layer in enumerate(self.layers):
441
+ if output_hidden_states:
442
+ all_hidden_states += (hidden_states,)
443
+ layer_outputs = layer(
444
+ hidden_states,
445
+ attention_mask,
446
+ position_bias,
447
+ output_attentions=output_attentions,
448
+ past_key_values=past_key_values[i] if past_key_values else None,
449
+ use_cache=use_cache,
450
+ )
451
+ hidden_states, attn_weights, current_key_value = layer_outputs
452
+ if output_attentions:
453
+ all_self_attns += (attn_weights,)
454
+ if current_key_value is not None:
455
+ current_key_values = current_key_values + (current_key_value,)
456
+
457
+ hidden_states = self.output_layernorm(hidden_states)
458
+
459
+ if output_hidden_states:
460
+ all_hidden_states += (hidden_states,)
461
+
462
+ return hidden_states, current_key_values, all_hidden_states, all_self_attns
463
+
464
+
465
+ class CpmBeeBucketPositionBias(nn.Module):
466
+ def __init__(self, config: CpmBeeConfig) -> None:
467
+ super().__init__()
468
+
469
+ self.num_heads = config.num_attention_heads
470
+ self.num_buckets = config.position_bias_num_buckets
471
+ self.num_segment_bucket = config.position_bias_num_segment_buckets
472
+ self.max_distance = config.position_bias_max_distance
473
+
474
+ self.relative_attention_bias = nn.Parameter(
475
+ torch.empty(
476
+ config.position_bias_num_buckets + config.position_bias_num_segment_buckets,
477
+ config.num_attention_heads,
478
+ dtype=config.torch_dtype,
479
+ ),
480
+ )
481
+
482
+ def forward(self, query_pos: torch.Tensor, key_pos: torch.Tensor, rel_buckets: torch.Tensor):
483
+ with torch.no_grad():
484
+ batch = key_pos.size(0)
485
+ keylen = key_pos.size(1)
486
+ querylen = query_pos.size(1)
487
+
488
+ if key_pos.size(0) != query_pos.size(0):
489
+ raise AssertionError(
490
+ f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
491
+ )
492
+ if rel_buckets.size(0) != batch:
493
+ raise AssertionError(
494
+ f"rel_buckets.size(0) should be equal to batch, but got {rel_buckets.size(0)} and {batch}!"
495
+ )
496
+ if rel_buckets.size(1) != querylen:
497
+ raise AssertionError(
498
+ f"rel_buckets.size(1) should be equal to querylen, but got {rel_buckets.size(1)} and {querylen}!"
499
+ )
500
+ if rel_buckets.size(2) != keylen:
501
+ raise AssertionError(
502
+ f"rel_buckets.size(2) should be equal to keylen, but got {rel_buckets.size(2)} and {keylen}!"
503
+ )
504
+
505
+ relative_position_bucket = rel_buckets - 1 + self.num_buckets
506
+
507
+ inner_segment_bucket = self._position_bucket(
508
+ key_pos[..., None, :] - query_pos[..., :, None],
509
+ num_buckets=self.num_buckets,
510
+ max_distance=self.max_distance,
511
+ )
512
+ relative_position_bucket = torch.where(
513
+ rel_buckets == 0,
514
+ inner_segment_bucket,
515
+ relative_position_bucket,
516
+ )
517
+
518
+ embeds = nn.functional.embedding(relative_position_bucket, self.relative_attention_bias)
519
+ embeds = embeds.permute(0, 3, 1, 2).contiguous()
520
+ return embeds
521
+
522
+ def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
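+ # T5-style relative position bucketing: buckets are split by sign, small offsets are mapped exactly,
+ # and larger offsets are log-spaced up to max_distance.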
523
+ relative_buckets = 0
524
+ num_buckets //= 2
525
+ relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
526
+ relative_position = torch.abs(relative_position)
527
+ max_exact = num_buckets // 2
528
+ is_small = relative_position < max_exact
529
+ relative_postion_if_large = max_exact + (
530
+ torch.log(relative_position.float() / max_exact)
531
+ / math.log(max_distance / max_exact)
532
+ * (num_buckets - max_exact)
533
+ ).to(torch.int32)
534
+ relative_postion_if_large = torch.min(
535
+ relative_postion_if_large,
536
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
537
+ )
538
+ relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
539
+ return relative_buckets
540
+
541
+
542
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMBee
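+ # Note: CpmBeeConfig in this repo does not define intermediate_size, layer_norm_eps, or
+ # hidden_dropout_prob, and this copied block appears unused elsewhere in this file.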
543
+ class CpmBeeOutput(nn.Module):
544
+ def __init__(self, config):
545
+ super().__init__()
546
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
547
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
548
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
549
+
550
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
551
+ hidden_states = self.dense(hidden_states)
552
+ hidden_states = self.dropout(hidden_states)
553
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
554
+ return hidden_states
555
+
556
+
557
+ class CpmBeeRotaryEmbedding(nn.Module):
558
+ """
559
+ RotaryEmbedding embeds the unk and special tokens. It maps "...<mask>...<mask>...<unk>...<unk>..."
561
+ to "...<mask_0>...<mask_1>...<unk_0>...<unk_1>..." to help the model distinguish between different special and unk
562
+ tokens.
562
+ """
563
+
564
+ def __init__(self, config: CpmBeeConfig):
565
+ super().__init__()
566
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, config.hidden_size, 2, dtype=torch.float32) / config.hidden_size))
567
+ self.distance_scale = config.distance_scale
568
+ self.dtype = config.torch_dtype
569
+ self.inv_freq = inv_freq.to(config.torch_dtype)
570
+
571
+ def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
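+ # Rotary position embedding (RoPE, https://arxiv.org/abs/2104.09864): rotate each half-dimension
+ # pair of x by an angle proportional to the scaled position x_pos.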
572
+ inv_freq = self.inv_freq.to(device=x.device, dtype=self.dtype)
573
+
574
+ x_pos = x_pos * self.distance_scale
575
+ freqs = x_pos[..., None].to(self.dtype) * inv_freq[None, :] # (..., dim/2)
576
+
577
+ emb = torch.cat((freqs, freqs), dim=-1) # (..., dim)
578
+ emb_cos = emb.cos() # (..., dim)
579
+ emb_sin = emb.sin() # (..., dim)
580
+
581
+ rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1) # (..., dim)
582
+
583
+ return x * emb_cos + rotate_x * emb_sin
584
+
585
+
586
+ class CpmBeeEmbeddingExt(nn.Embedding):
587
+ """
588
+ Contains a RotaryEmbedding.
589
+ """
590
+
591
+ def __init__(self, config: CpmBeeConfig):
592
+ super().__init__(config.vocab_size, config.hidden_size, dtype=config.torch_dtype)
593
+ self.dim_model = config.hidden_size
594
+ self.rotary_emb = CpmBeeRotaryEmbedding(config)
595
+
596
+ def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
597
+ embeds = super().forward(ids) / math.sqrt(self.dim_model)
598
+ return self.rotary_emb(embeds, ids_sub)
599
+
600
+ def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
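+ # Output projection tied to the input embedding matrix; ext_table optionally appends logits
+ # for extra external tokens.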
601
+ logits = nn.functional.linear(x / math.sqrt(self.dim_model), self.weight)
602
+ if ext_table is not None:
603
+ logits_ext = nn.functional.linear(x, ext_table)
604
+ logits = torch.cat([logits, logits_ext], dim=-1)
605
+ return logits
606
+
607
+
608
+ class CpmBeePreTrainedModel(PreTrainedModel):
609
+ """
610
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
611
+ models.
612
+ """
613
+
614
+ config_class = CpmBeeConfig
615
+ base_model_prefix = "cpmbee"
616
+ supports_gradient_checkpointing = True
617
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
618
+
619
+ def _init_weights(self, module):
620
+ """Initialize the weights"""
621
+ if isinstance(module, nn.Linear):
622
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
623
+ if module.bias is not None:
624
+ module.bias.data.zero_()
625
+ # still needed
626
+ elif isinstance(module, CpmBeeEmbeddingExt):
627
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
628
+ elif isinstance(module, nn.LayerNorm):
629
+ module.bias.data.zero_()
630
+ module.weight.data.fill_(1.0)
631
+ elif isinstance(module, CpmBeeLayerNorm):
632
+ module.weight.data.fill_(1.0)
633
+ elif isinstance(module, CpmBeeBucketPositionBias):
634
+ module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
635
+
636
+ def _set_gradient_checkpointing(self, module, value=False):
637
+ if isinstance(module, CpmBeeEncoder):
638
+ module.gradient_checkpointing = value
639
+
640
+
641
+ CPMBEE_START_DOCSTRING = r"""
642
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
643
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
644
+ behavior.
645
+
646
+ Parameters
647
+ config ([`~CpmBeeConfig`]): Model configuration class with all the parameters of the model.
648
+ Initializing with a config file does not load the weights associated with the model, only the
649
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
650
+ """
651
+
652
+ CPMBEE_INPUTS_DOCSTRING = r"""
653
+ Args:
654
+ input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
655
+ Indices of input sequence tokens in the vocabulary.
656
+
657
+ Indices can be obtained using [`CPMBeeTokenizer`]. See [`PreTrainedTokenizer.encode`] and
658
+ [`PreTrainedTokenizer.__call__`] for details.
659
+
660
+ [What are input IDs?](../glossary#input-ids)
661
+ input_id_sub (`torch.Tensor` of shape `(batch_size, seq_len)`):
662
+ Subscription of input sequence tokens in the vocabulary.
663
+
664
+ Subscription of normal text will be zero while the special tokens of each group will be the 0, 1, 2, ...
665
+ <ans_0>, <ans_1>, <ans_2> ... belongs to group <ans>. <mask_0>, <mask_1>, <mask_2> ... belongs to group
666
+ <mask>.
667
+ position (`torch.Tensor` of shape `(batch_size, seq_len)`):
668
+ The position of input sequence tokens in the vocabulary for each segment. if segment1 is 0, 1, 2 and
669
+ segment2 is 0, 1, 2, 3, the position will be 0, 1, 2, 0, 1, 2, 3
670
+ context (`torch.Tensor` of shape `(batch_size, seq_len)`):
671
+ Whether this token id is context or not. If is context, the value is 1. If not, the value is 0. If a token
672
+ id is context, it does not need to be predicted.
673
+ sample_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
674
+ Give a sample id to every token id. Token ids with the same sample id belong to the same sample.
675
+ num_segments (`torch.Tensor` of shape `(batch_size, seq_len)`):
676
+ Total number of segments in the current input.
677
+ segment (`torch.Tensor` of shape `(batch_size, seq_len)`):
678
+ Give a segment id to every token id. Token ids with the same segment id belong to the same segment.
679
+
680
+ Generally, a string key or value in input data will be a segment. For example, input {"input": "hello, ",
681
+ "<ans>": ""}, the segments includes: "input", "hello, ", "<ans>" and "".
682
+ segment_rel_offset (`torch.Tensor` of shape `(batch_size, seq_len)`):
683
+ The offset of segment rel.
684
+ segment_rel (`torch.Tensor` of shape `(batch_size, seq_len)`):
685
+ The segment relevance. A relative implementation of measuring the importance of segments.
686
+ past_states (`Dict[str, Union[torch.Tensor, List]]`):
687
+ Store the history information including position, context, sample_ids, num_segments, segment and
688
+ past_key_values.
689
+ output_attentions (`bool`, *optional*):
690
+ Whether or not to return the attentions tensors of all attention layers.
691
+ output_hidden_states (`bool`, *optional*):
692
+ Whether or not to return the hidden states of all layers.
693
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
694
+ A dummy argument for CPMBee. The `past_states` contains pre-computed hidden-states (keys and values in the
695
+ self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) and
696
+ other history arguments to speed up sequential decoding.
697
+ use_cache (`bool`, *optional*):
698
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
699
+ `past_key_values`).
700
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
701
+ Labels for computing the masked language modeling loss.
702
+ return_dict (`bool`, *optional*):
703
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
704
+ """
705
+
706
+
707
+ @add_start_docstrings(
708
+ "The bare CPMBee Model outputting raw hidden-states without any specific head on top.",
709
+ CPMBEE_START_DOCSTRING,
710
+ )
711
+ class CpmBeeModel(CpmBeePreTrainedModel):
712
+ def __init__(self, config: CpmBeeConfig):
713
+ super().__init__(config)
714
+ if config.half:
715
+ config.torch_dtype = torch.half
716
+ else:
717
+ config.torch_dtype = torch.float
718
+ self.encoder = CpmBeeEncoder(config)
719
+ self.input_embedding = CpmBeeEmbeddingExt(config)
720
+ self.position_bias = CpmBeeBucketPositionBias(config)
721
+ self.vocab_size = config.vocab_size
722
+ self.post_init()
723
+
724
+ def get_input_embeddings(self):
725
+ return self.input_embedding
726
+
727
+ def set_input_embeddings(self, embeddings, **kwargs):
728
+ self.input_embedding = embeddings
729
+
730
+ @add_start_docstrings_to_model_forward(CPMBEE_INPUTS_DOCSTRING)
731
+ @add_code_sample_docstrings(
732
+ checkpoint=_CHECKPOINT_FOR_DOC,
733
+ output_type=BaseModelOutputWithPast,
734
+ config_class=_CONFIG_FOR_DOC,
735
+ )
736
+ def forward(
737
+ self,
738
+ input_ids: torch.Tensor,
739
+ input_id_sub: Optional[torch.Tensor] = None,
740
+ position: Optional[torch.Tensor] = None,
741
+ context: Optional[torch.Tensor] = None,
742
+ sample_ids: Optional[torch.Tensor] = None,
743
+ num_segments: Optional[torch.Tensor] = None,
744
+ segment: Optional[torch.Tensor] = None,
745
+ segment_rel_offset: Optional[torch.Tensor] = None,
746
+ segment_rel: Optional[torch.Tensor] = None,
747
+ past_states: Optional[Dict] = None,
748
+ output_attentions: Optional[bool] = None,
749
+ output_hidden_states: Optional[bool] = None,
750
+ past_key_values: Optional[List] = None,
751
+ use_cache: Optional[bool] = None,
752
+ return_dict: Optional[bool] = None,
753
+ **kwargs,
754
+ ):
755
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
756
+ output_hidden_states = (
757
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
758
+ )
759
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
760
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
761
+
762
+ # dummy setting for common tests
763
+ if input_id_sub is None:
764
+ dtype, device = input_ids.dtype, input_ids.device
765
+ batch, seq_length = input_ids.size()
766
+ segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
767
+ context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
768
+ position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
769
+ input_id_sub = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
770
+ segment_rel_offset = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
771
+ segment_rel = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
772
+ num_segments = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
773
+ sample_ids = torch.zeros_like(input_ids)
774
+
775
+ with torch.no_grad():
776
+ if past_states is None:
777
+ present_position = position
778
+ present_context = context
779
+ present_sample_ids = sample_ids
780
+ present_num_segments = num_segments
781
+ present_segments = segment
782
+ present_buffer = None
783
+ else:
784
+ present_position = torch.cat([past_states["buffer_position"], position], dim=-1)
785
+ present_context = torch.cat([past_states["buffer_context"], context], dim=-1)
786
+ present_sample_ids = torch.cat([past_states["buffer_sample_ids"], sample_ids], dim=-1)
787
+ present_num_segments = torch.cat([past_states["buffer_num_segments"], num_segments], dim=-1)
788
+ present_segments = torch.cat([past_states["buffer_segments"], segment], dim=-1)
789
+ present_buffer = past_states["buffer"]
790
+
791
+ batch = input_ids.size(0)
792
+ len_q = input_ids.size(1)
793
+ len_buffer = present_position.size(1)
794
+
795
+ segment_rel_2d = torch.masked_fill(
796
+ segment[:, :, None] * num_segments[:, :, None]
797
+ + present_segments[:, None, :]
798
+ + segment_rel_offset[:, :, None],
799
+ ~((sample_ids[:, :, None] == present_sample_ids[:, None, :])), # not in the same sample
800
+ 0, # avoid torch.gather overflow
801
+ ).view(batch, len_q * len_buffer)
802
+
803
+ segment_bucket = torch.gather(
804
+ input=segment_rel,
805
+ dim=1,
806
+ index=segment_rel_2d.long(),
807
+ ).view(batch, len_q, len_buffer)
808
+
809
+ segment_bucket.masked_fill_(
810
+ ~((sample_ids[:, :, None] == present_sample_ids[:, None, :])), # not in the same span or sample
811
+ 1, # bucket is used for in-context samples
812
+ )
813
+
814
+ # directional mask
815
+ directional_mask_2d = present_position[:, None, :] <= position[:, :, None]
816
+ # sample mask
817
+ sample_mask_2d = (sample_ids[:, :, None] == 0) | (sample_ids[:, :, None] == present_sample_ids[:, None, :])
818
+ # context mask
819
+ attention_mask = present_context[:, None, :] | (
820
+ context[:, :, None].logical_not() & directional_mask_2d.view(batch, len_q, len_buffer)
821
+ )
822
+ # span mask
823
+ attention_mask = attention_mask & sample_mask_2d
824
+ # length mask
825
+ mask_1d = present_num_segments != 0
826
+ attention_mask = mask_1d.view(batch, 1, len_buffer) & attention_mask
827
+
828
+ hidden_states = self.input_embedding(input_ids, input_id_sub)
829
+ position_bias = self.position_bias(position, present_position, segment_bucket)
830
+ hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
831
+ hidden_states,
832
+ attention_mask,
833
+ position_bias,
834
+ output_attentions,
835
+ output_hidden_states,
836
+ present_buffer,
837
+ use_cache,
838
+ )
839
+
840
+ if not return_dict:
841
+ return tuple(
842
+ v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
843
+ )
844
+
845
+ return BaseModelOutputWithPast(
846
+ last_hidden_state=hidden_states,
847
+ past_key_values=present_key_values,
848
+ hidden_states=all_hidden_states,
849
+ attentions=all_attentions,
850
+ )
851
+
852
+
853
+ class CpmBeeBeamHypotheses(BeamHypotheses):
854
+ def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
855
+ """
856
+ Override BeamHypotheses for CpmBee. The hyp to add is list but not tensor.
857
+ """
858
+ super().__init__(num_beams, length_penalty, early_stopping, max_length)
859
+
860
+ def add(self, hyp: List, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
861
+ """
862
+ Add a new hypothesis to the list.
863
+ """
864
+ score = sum_logprobs / (len(hyp) ** self.length_penalty)
865
+ if len(self) < self.num_beams or score > self.worst_score:
866
+ self.beams.append((score, hyp, beam_indices))
867
+ if len(self) > self.num_beams:
868
+ sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
869
+ del self.beams[sorted_next_scores[0][1]]
870
+ self.worst_score = sorted_next_scores[1][0]
871
+ else:
872
+ self.worst_score = min(score, self.worst_score)
873
+
874
+
875
+ class CPMBeeTransBlock(torch.nn.Module):
876
+ def __init__(
877
+ self,
878
+ dim_model=4096,
879
+ dim_ff=1024,
880
+ dim_out=768,
881
+ dtype=torch.float,
882
+ eps=1e-6,
883
+ dropout_p=0,
884
+ ):
885
+ super().__init__()
886
+ if dropout_p is not None:
887
+ self.dropout = torch.nn.Dropout(dropout_p)
888
+ else:
889
+ self.dropout = None
890
+ self.w_out_res = torch.nn.Linear(dim_model, dim_out, bias=False)
891
+ self.layernorm = torch.nn.LayerNorm(
892
+ dim_out,
893
+ dtype=dtype,
894
+ eps=eps,
895
+ )
896
+
897
+ def forward(self, hidden_states: torch.Tensor):
898
+ x_res = self.w_out_res(hidden_states)
899
+ if self.dropout is not None:
900
+ x_res = self.dropout(x_res)
901
+ hidden_states = self.layernorm(x_res)
902
+ return hidden_states
903
+
904
+
905
+ class CpmBeeWithTransform(CpmBeePreTrainedModel):
906
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
907
+
908
+ def __init__(self, config: CpmBeeConfig):
909
+ super().__init__(config)
910
+ self.llm = CpmBeeModel(config)
911
+
912
+ self.trans_block = CPMBeeTransBlock(config.hidden_size, config.hidden_size // 4, config.unet_cross_attention_dim)
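+ # trans_block projects LLM hidden states (hidden_size) to the UNet cross-attention dimension
+ # (unet_cross_attention_dim); note that forward() below returns the hidden states without
+ # applying it, since the call is commented out.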
913
+
914
+ def forward(
915
+ self,
916
+ input_ids: torch.Tensor,
917
+ input_id_sub: Optional[torch.Tensor] = None,
918
+ position: Optional[torch.Tensor] = None,
919
+ context: Optional[torch.Tensor] = None,
920
+ sample_ids: Optional[torch.Tensor] = None,
921
+ num_segments: Optional[torch.Tensor] = None,
922
+ segment: Optional[torch.Tensor] = None,
923
+ segment_rel_offset: Optional[torch.Tensor] = None,
924
+ segment_rel: Optional[torch.Tensor] = None,
925
+ past_states: Optional[Dict] = None,
926
+ output_attentions: Optional[bool] = None,
927
+ output_hidden_states: Optional[bool] = None,
928
+ past_key_values: Optional[List] = None,
929
+ use_cache: Optional[bool] = None,
930
+ return_dict: Optional[bool] = None,
931
+ **kwargs,
932
+ ):
933
+ outputs = self.llm(input_ids, input_id_sub, position, context,
934
+ sample_ids, num_segments, segment, segment_rel_offset,
935
+ segment_rel, past_states, output_attentions, output_hidden_states,
936
+ past_key_values, use_cache, return_dict, **kwargs,)
937
+ if return_dict:
938
+ hidden_states = outputs.last_hidden_state
939
+ else:
940
+ hidden_states = outputs[0]
941
+ #if self.trans_block is not None:
942
+ # hidden_states = self.trans_block(hidden_states)
943
+ return outputs, hidden_states
pipeline_stable_diffusion.py ADDED
@@ -0,0 +1,723 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import warnings
16
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
17
+ import numpy as np
18
+
19
+ import torch
20
+ from torch.utils.data.dataloader import default_collate
21
+ from packaging import version
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
23
+
24
+ from diffusers.configuration_utils import FrozenDict
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils import (
30
+ deprecate,
31
+ is_accelerate_available,
32
+ is_accelerate_version,
33
+ logging,
34
+ randn_tensor,
35
+ replace_example_docstring,
36
+ )
37
+ from diffusers.pipeline_utils import DiffusionPipeline
38
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
39
+
40
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
41
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, StableDiffusionPipeline
42
+ from .modeling_cpmbee import CpmBeeModel
43
+ from .tokenization_viscpmbee import VisCpmBeeTokenizer
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+ def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
48
+ items = []
49
+ if isinstance(orig_items[0][key], list):
50
+ assert isinstance(orig_items[0][key][0], torch.Tensor)
51
+ for it in orig_items:
52
+ for tr in it[key]:
53
+ items.append({key: tr})
54
+ else:
55
+ assert isinstance(orig_items[0][key], torch.Tensor)
56
+ items = orig_items
57
+
58
+ batch_size = len(items)
59
+ shape = items[0][key].shape
60
+ dim = len(shape)
61
+ assert dim <= 3
62
+ if max_length is None:
63
+ max_length = 0
64
+ max_length = max(max_length, max(item[key].shape[-1] for item in items))
65
+ min_length = min(item[key].shape[-1] for item in items)
66
+ dtype = items[0][key].dtype
67
+
68
+ if dim == 1:
69
+ return torch.cat([item[key] for item in items], dim=0)
70
+ elif dim == 2:
71
+ if max_length == min_length:
72
+ return torch.cat([item[key] for item in items], dim=0)
73
+ tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
74
+ else:
75
+ tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
76
+
77
+ for i, item in enumerate(items):
78
+ if dim == 2:
79
+ if padding_side == "left":
80
+ tensor[i, -len(item[key][0]):] = item[key][0].clone()
81
+ else:
82
+ tensor[i, : len(item[key][0])] = item[key][0].clone()
83
+ elif dim == 3:
84
+ if padding_side == "left":
85
+ tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
86
+ else:
87
+ tensor[i, : len(item[key][0]), :] = item[key][0].clone()
88
+
89
+ return tensor
90
+
91
+
92
+ class CPMBeeCollater:
+     """
+     Collate function for CPM-Bee input data, corresponding to cpm-live's _MixedDatasetBatchPacker.
+     The native torch DataLoader is currently not well suited to being adapted for in-context learning,
+     and the original implementation's best_fit step (used to maximize the ratio of effective tokens)
+     is not supported here either.
+     todo: rewrite the Dataloader or BatchPacker
+     """
99
+
100
+ def __init__(self, tokenizer: VisCpmBeeTokenizer, max_len):
101
+ self.tokenizer = tokenizer
102
+ self._max_length = max_len
103
+ self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
104
+ 'segment_rel', 'sample_ids', 'num_segments']
105
+
106
+ def __call__(self, batch):
107
+ batch_size = len(batch)
108
+
109
+ tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
110
+ # no best_fit for now, so span is all zeros
111
+ span = np.zeros((batch_size, self._max_length), dtype=np.int32)
112
+ length = np.zeros((batch_size,), dtype=np.int32)
113
+
114
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
115
+ batch_ext_table_ids: List[int] = []
116
+ batch_ext_table_sub: List[int] = []
117
+ raw_data_list: List[Any] = []
118
+
119
+ for i in range(batch_size):
120
+ instance_length = batch[i]['input_ids'][0].shape[0]
121
+ length[i] = instance_length
122
+ raw_data_list.extend(batch[i]['raw_data'])
123
+
124
+ for j in range(instance_length):
125
+ idx, idx_sub = batch[i]['input_ids'][0, j], batch[i]['input_id_subs'][0, j]
126
+ tgt_idx = idx
127
+ if idx_sub > 0:
128
+ # need to be in ext table
129
+ if (idx, idx_sub) not in batch_ext_table_map:
130
+ batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
131
+ batch_ext_table_ids.append(idx)
132
+ batch_ext_table_sub.append(idx_sub)
133
+ tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
134
+ if j > 1 and batch[i]['context'][0, j - 1] == 0:
135
+ if idx != self.tokenizer.bos_id:
136
+ tgt[i, j - 1] = tgt_idx
137
+ else:
138
+ tgt[i, j - 1] = self.tokenizer.eos_id
139
+ if batch[i]['context'][0, instance_length - 1] == 0:
140
+ tgt[i, instance_length - 1] = self.tokenizer.eos_id
141
+
142
+ if len(batch_ext_table_map) == 0:
143
+ # placeholder
144
+ batch_ext_table_ids.append(0)
145
+ batch_ext_table_sub.append(1)
146
+
147
+ # image
148
+ if 'pixel_values' in batch[0]:
149
+ data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
150
+ else:
151
+ data = {}
152
+
153
+ # image_bound
154
+ if 'image_bound' in batch[0]:
155
+ data['image_bound'] = default_collate([i['image_bound'] for i in batch])
156
+
157
+ # bee inp
158
+ for key in self.pad_keys:
159
+ data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')
160
+
161
+ data['context'] = data['context'] > 0
162
+ data['length'] = torch.from_numpy(length)
163
+ data['span'] = torch.from_numpy(span)
164
+ data['target'] = torch.from_numpy(tgt)
165
+ data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids))
166
+ data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub))
167
+ data['raw_data'] = raw_data_list
168
+
169
+ return data
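+
+ # Illustrative usage (not part of the original file; the vocab path is a placeholder):
+ #   tokenizer = VisCpmBeeTokenizer("vocab.txt")
+ #   packer = CPMBeeCollater(tokenizer=tokenizer, max_len=64)
+ #   batch = packer([sample])        # `sample` is the dict built by VisCPMPaintBeePipeline.build_input
+ #   batch["input_ids"].shape        # -> (1, 64) when the prompt fits in 64 tokens, right-padded with 0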
170
+
171
+
172
+ class VisCPMPaintBeePipeline(StableDiffusionPipeline):
173
+ _optional_components = ["safety_checker", "feature_extractor"]
174
+
175
+ def __init__(
176
+ self,
177
+ vae: AutoencoderKL,
178
+ text_encoder: CpmBeeModel,
179
+ tokenizer: VisCpmBeeTokenizer,
180
+ unet: UNet2DConditionModel,
181
+ scheduler: KarrasDiffusionSchedulers,
182
+ safety_checker: StableDiffusionSafetyChecker,
183
+ feature_extractor: CLIPImageProcessor,
184
+ requires_safety_checker: bool = True,
185
+ ):
186
+ super().__init__(
187
+ vae=vae,
188
+ text_encoder=text_encoder,
189
+ tokenizer=tokenizer,
190
+ unet=unet,
191
+ scheduler=scheduler,
192
+ safety_checker=safety_checker,
193
+ feature_extractor=feature_extractor,
194
+ requires_safety_checker=requires_safety_checker
195
+ )
196
+
197
+ def build_input(
198
+ self,
199
+ prompt: str,
200
+ negative_prompt: Optional[str] = None,
201
+ image_size: int = 512
202
+ ):
203
+ data_input = {'caption': prompt, 'objects': ''}
204
+ (
205
+ input_ids,
206
+ input_id_subs,
207
+ context,
208
+ segment_ids,
209
+ segment_rel,
210
+ n_segments,
211
+ table_states,
212
+ image_bound
213
+ ) = self.tokenizer.convert_data_to_id(data=data_input, shuffle_answer=False, max_depth=8)
214
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
215
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
216
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
217
+ data = {
218
+ 'pixel_values': torch.zeros(3, image_size, image_size).unsqueeze(0),
219
+ 'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
220
+ 'input_id_subs': torch.from_numpy(input_id_subs).unsqueeze(0),
221
+ 'context': torch.from_numpy(context).unsqueeze(0),
222
+ 'segment_ids': torch.from_numpy(segment_ids).unsqueeze(0),
223
+ 'segment_rel_offset': torch.from_numpy(segment_rel_offset).unsqueeze(0),
224
+ 'segment_rel': torch.from_numpy(segment_rel).unsqueeze(0),
225
+ 'sample_ids': torch.from_numpy(sample_ids).unsqueeze(0),
226
+ 'num_segments': torch.from_numpy(num_segments).unsqueeze(0),
227
+ 'image_bound': image_bound,
228
+ 'raw_data': prompt,
229
+ }
230
+
231
+ uncond_data_input = {
232
+ 'caption': "" if negative_prompt is None else negative_prompt,
233
+ 'objects': ''
234
+ }
235
+ (
236
+ input_ids,
237
+ input_id_subs,
238
+ context,
239
+ segment_ids,
240
+ segment_rel,
241
+ n_segments,
242
+ table_states,
243
+ image_bound
244
+ ) = self.tokenizer.convert_data_to_id(data=uncond_data_input, shuffle_answer=False, max_depth=8)
245
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
246
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
247
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
248
+ uncond_data = {
249
+ 'pixel_values': torch.zeros(3, image_size, image_size).unsqueeze(0),
250
+ 'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
251
+ 'input_id_subs': torch.from_numpy(input_id_subs).unsqueeze(0),
252
+ 'context': torch.from_numpy(context).unsqueeze(0),
253
+ 'segment_ids': torch.from_numpy(segment_ids).unsqueeze(0),
254
+ 'segment_rel_offset': torch.from_numpy(segment_rel_offset).unsqueeze(0),
255
+ 'segment_rel': torch.from_numpy(segment_rel).unsqueeze(0),
256
+ 'sample_ids': torch.from_numpy(sample_ids).unsqueeze(0),
257
+ 'num_segments': torch.from_numpy(num_segments).unsqueeze(0),
258
+ 'image_bound': image_bound,
259
+ 'raw_data': "" if negative_prompt is None else negative_prompt,
260
+ }
261
+ packer = CPMBeeCollater(
262
+ tokenizer=self.tokenizer,
263
+ max_len=max(data['input_ids'].size(-1), uncond_data['input_ids'].size(-1))
264
+ )
265
+ data = packer([data])
266
+ uncond_data = packer([uncond_data])
267
+ return data, uncond_data
268
+
269
+ def _encode_prompt(
270
+ self,
271
+ prompt,
272
+ device,
273
+ num_images_per_prompt,
274
+ do_classifier_free_guidance,
275
+ negative_prompt=None,
276
+ prompt_embeds: Optional[torch.FloatTensor] = None,
277
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
278
+ lora_scale: Optional[float] = None,
279
+ ):
280
+ r"""
281
+ Encodes the prompt into text encoder hidden states.
282
+
283
+ Args:
284
+ prompt (`str` or `List[str]`, *optional*):
285
+ prompt to be encoded
286
+ device: (`torch.device`):
287
+ torch device
288
+ num_images_per_prompt (`int`):
289
+ number of images that should be generated per prompt
290
+ do_classifier_free_guidance (`bool`):
291
+ whether to use classifier free guidance or not
292
+ negative_prompt (`str` or `List[str]`, *optional*):
293
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
294
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
295
+ less than `1`).
296
+ prompt_embeds (`torch.FloatTensor`, *optional*):
297
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
298
+ provided, text embeddings will be generated from `prompt` input argument.
299
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
300
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
301
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
302
+ argument.
303
+ lora_scale (`float`, *optional*):
304
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
305
+ """
306
+ # set lora scale so that monkey patched LoRA
307
+ # function of text encoder can correctly access it
308
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
309
+ self._lora_scale = lora_scale
310
+
311
+ data, uncond_data = self.build_input(prompt, negative_prompt, image_size=512)
312
+ for key, value in data.items():
313
+ if isinstance(value, torch.Tensor):
314
+ data[key] = value.to(self.device)
315
+ for key, value in uncond_data.items():
316
+ if isinstance(value, torch.Tensor):
317
+ uncond_data[key] = value.to(self.device)
318
+
319
+ batch, seq_length = data['input_ids'].size()
320
+ dtype, device = data['input_ids'].dtype, data['input_ids'].device
321
+ data['position'] = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
322
+
323
+ batch, seq_length = uncond_data['input_ids'].size()
324
+ dtype, device = uncond_data['input_ids'].dtype, uncond_data['input_ids'].device
325
+ uncond_data['position'] = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
326
+
327
+ with torch.no_grad():
328
+ # llm_hidden_state = self.text_encoder.llm.input_embedding(data['input_ids'], data['input_id_subs'])
329
+ _, hidden_states = self.text_encoder(
330
+ input_ids=data['input_ids'],
331
+ input_id_sub=data['input_id_subs'],
332
+ position=data['position'],
333
+ #length=data['length'],
334
+ context=data['context'],
335
+ sample_ids=data['sample_ids'],
336
+ num_segments=data['num_segments'],
337
+ segment=data['segment_ids'],
338
+ segment_rel_offset=data['segment_rel_offset'],
339
+ segment_rel=data['segment_rel'],
340
+ #span=data['span'],
341
+ #ext_table_ids=data['ext_table_ids'],
342
+ #ext_table_sub=data['ext_table_sub'],
343
+ #hidden_states=llm_hidden_state
344
+ )
345
+
346
+ with torch.no_grad():
347
+ # uncond_llm_hidden_state = self.text_encoder.llm.input_embedding(uncond_data['input_ids'], uncond_data['input_id_subs'])
348
+ _, uncond_hidden_states = self.text_encoder(
349
+ input_ids=uncond_data['input_ids'],
350
+ input_id_sub=uncond_data['input_id_subs'],
351
+ position=uncond_data['position'],
352
+ #length=uncond_data['length'],
353
+ context=uncond_data['context'],
354
+ sample_ids=uncond_data['sample_ids'],
355
+ num_segments=uncond_data['num_segments'],
356
+ segment=uncond_data['segment_ids'],
357
+ segment_rel_offset=uncond_data['segment_rel_offset'],
358
+ segment_rel=uncond_data['segment_rel'],
359
+ #span=uncond_data['span'],
360
+ #ext_table_ids=uncond_data['ext_table_ids'],
361
+ #ext_table_sub=uncond_data['ext_table_sub'],
362
+ #hidden_states=uncond_llm_hidden_state
363
+ )
364
+
365
+ text_hidden_states, uncond_text_hidden_states = hidden_states, uncond_hidden_states
366
+ if self.text_encoder.trans_block is not None:
367
+ text_hidden_states = self.text_encoder.trans_block(text_hidden_states)
368
+ uncond_text_hidden_states = self.text_encoder.trans_block(uncond_text_hidden_states)
369
+ bs_embed, seq_len, _ = text_hidden_states.shape
370
+ text_hidden_states = text_hidden_states.repeat(1, num_images_per_prompt, 1)
371
+ text_hidden_states = text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
372
+
373
+ bs_embed, seq_len, _ = uncond_text_hidden_states.shape
374
+ uncond_text_hidden_states = uncond_text_hidden_states.repeat(1, num_images_per_prompt, 1)
375
+ uncond_text_hidden_states = uncond_text_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
376
+
377
+ prompt_embeds = torch.cat([uncond_text_hidden_states, text_hidden_states])
378
+ return prompt_embeds
379
+
380
+ # if prompt is not None and isinstance(prompt, str):
381
+ # batch_size = 1
382
+ # elif prompt is not None and isinstance(prompt, list):
383
+ # batch_size = len(prompt)
384
+ # else:
385
+ # batch_size = prompt_embeds.shape[0]
386
+
387
+ # if prompt_embeds is None:
388
+ # # textual inversion: process multi-vector tokens if necessary
389
+ # if isinstance(self, TextualInversionLoaderMixin):
390
+ # prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
391
+
392
+ # text_inputs = self.tokenizer(
393
+ # prompt,
394
+ # padding="max_length",
395
+ # max_length=self.tokenizer.model_max_length,
396
+ # truncation=True,
397
+ # return_tensors="pt",
398
+ # )
399
+ # text_input_ids = text_inputs.input_ids
400
+ # untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
401
+
402
+ # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
403
+ # text_input_ids, untruncated_ids
404
+ # ):
405
+ # removed_text = self.tokenizer.batch_decode(
406
+ # untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
407
+ # )
408
+ # logger.warning(
409
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
410
+ # f" {self.tokenizer.model_max_length} tokens: {removed_text}"
411
+ # )
412
+
413
+ # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
414
+ # attention_mask = text_inputs.attention_mask.to(device)
415
+ # else:
416
+ # attention_mask = None
417
+
418
+ # prompt_embeds = self.text_encoder(
419
+ # text_input_ids.to(device),
420
+ # attention_mask=attention_mask,
421
+ # )
422
+ # prompt_embeds = prompt_embeds[0]
423
+
424
+ # prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
425
+
426
+ # bs_embed, seq_len, _ = prompt_embeds.shape
427
+ # # duplicate text embeddings for each generation per prompt, using mps friendly method
428
+ # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
429
+ # prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
430
+
431
+ # # get unconditional embeddings for classifier free guidance
432
+ # if do_classifier_free_guidance and negative_prompt_embeds is None:
433
+ # uncond_tokens: List[str]
434
+ # if negative_prompt is None:
435
+ # uncond_tokens = [""] * batch_size
436
+ # elif prompt is not None and type(prompt) is not type(negative_prompt):
437
+ # raise TypeError(
438
+ # f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
439
+ # f" {type(prompt)}."
440
+ # )
441
+ # elif isinstance(negative_prompt, str):
442
+ # uncond_tokens = [negative_prompt]
443
+ # elif batch_size != len(negative_prompt):
444
+ # raise ValueError(
445
+ # f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
446
+ # f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
447
+ # " the batch size of `prompt`."
448
+ # )
449
+ # else:
450
+ # uncond_tokens = negative_prompt
451
+
452
+ # # textual inversion: process multi-vector tokens if necessary
453
+ # if isinstance(self, TextualInversionLoaderMixin):
454
+ # uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
455
+
456
+ # max_length = prompt_embeds.shape[1]
457
+ # uncond_input = self.tokenizer(
458
+ # uncond_tokens,
459
+ # padding="max_length",
460
+ # max_length=max_length,
461
+ # truncation=True,
462
+ # return_tensors="pt",
463
+ # )
464
+
465
+ # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
466
+ # attention_mask = uncond_input.attention_mask.to(device)
467
+ # else:
468
+ # attention_mask = None
469
+
470
+ # negative_prompt_embeds = self.text_encoder(
471
+ # uncond_input.input_ids.to(device),
472
+ # attention_mask=attention_mask,
473
+ # )
474
+ # negative_prompt_embeds = negative_prompt_embeds[0]
475
+
476
+ # if do_classifier_free_guidance:
477
+ # # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
478
+ # seq_len = negative_prompt_embeds.shape[1]
479
+
480
+ # negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
481
+
482
+ # negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
483
+ # negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
484
+
485
+ # # For classifier free guidance, we need to do two forward passes.
486
+ # # Here we concatenate the unconditional and text embeddings into a single batch
487
+ # # to avoid doing two forward passes
488
+ # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
489
+
490
+ # return prompt_embeds
491
+
492
+ def decode_latents(self, latents):
493
+ warnings.warn(
494
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
495
+ " use VaeImageProcessor instead",
496
+ FutureWarning,
497
+ )
498
+ latents = 1 / self.vae.config.scaling_factor * latents
499
+ image = self.vae.decode(latents, return_dict=False)[0]
500
+ image = (image / 2 + 0.5).clamp(0, 1)
501
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
502
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
503
+ return image
504
+
505
+ def prepare_extra_step_kwargs(self, generator, eta):
506
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
507
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
508
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
509
+ # and should be between [0, 1]
510
+
511
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
512
+ extra_step_kwargs = {}
513
+ if accepts_eta:
514
+ extra_step_kwargs["eta"] = eta
515
+
516
+ # check if the scheduler accepts generator
517
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
518
+ if accepts_generator:
519
+ extra_step_kwargs["generator"] = generator
520
+ return extra_step_kwargs
521
+
522
+ def check_inputs(
523
+ self,
524
+ prompt,
525
+ height,
526
+ width,
527
+ callback_steps,
528
+ negative_prompt=None,
529
+ prompt_embeds=None,
530
+ negative_prompt_embeds=None,
531
+ ):
532
+ if height % 8 != 0 or width % 8 != 0:
533
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
534
+
535
+ if (callback_steps is None) or (
536
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
537
+ ):
538
+ raise ValueError(
539
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
540
+ f" {type(callback_steps)}."
541
+ )
542
+
543
+ if prompt is not None and prompt_embeds is not None:
544
+ raise ValueError(
545
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
546
+ " only forward one of the two."
547
+ )
548
+ elif prompt is None and prompt_embeds is None:
549
+ raise ValueError(
550
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
551
+ )
552
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
553
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
554
+
555
+ if negative_prompt is not None and negative_prompt_embeds is not None:
556
+ raise ValueError(
557
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
558
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
559
+ )
560
+
561
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
562
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
563
+ raise ValueError(
564
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
565
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
566
+ f" {negative_prompt_embeds.shape}."
567
+ )
568
+
569
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
570
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
571
+ if isinstance(generator, list) and len(generator) != batch_size:
572
+ raise ValueError(
573
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
574
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
575
+ )
576
+
577
+ if latents is None:
578
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
579
+ else:
580
+ latents = latents.to(device)
581
+
582
+ # scale the initial noise by the standard deviation required by the scheduler
583
+ latents = latents * self.scheduler.init_noise_sigma
584
+ return latents
585
+
586
+ @torch.no_grad()
587
+ def __call__(
588
+ self,
589
+ prompt: Union[str, List[str]] = None,
590
+ height: Optional[int] = None,
591
+ width: Optional[int] = None,
592
+ num_inference_steps: int = 50,
593
+ guidance_scale: float = 7.5,
594
+ negative_prompt: Optional[Union[str, List[str]]] = None,
595
+ num_images_per_prompt: Optional[int] = 1,
596
+ eta: float = 0.0,
597
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
598
+ latents: Optional[torch.FloatTensor] = None,
599
+ prompt_embeds: Optional[torch.FloatTensor] = None,
600
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
601
+ output_type: Optional[str] = "pil",
602
+ return_dict: bool = True,
603
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
604
+ callback_steps: int = 1,
605
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
606
+ guidance_rescale: float = 0.0,
607
+ ):
608
+ # 0. Default height and width to unet
609
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
610
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
611
+
612
+ # 1. Check inputs. Raise error if not correct
613
+ self.check_inputs(
614
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
615
+ )
616
+
617
+ # 2. Define call parameters
618
+ if prompt is not None and isinstance(prompt, str):
619
+ batch_size = 1
620
+ elif prompt is not None and isinstance(prompt, list):
621
+ batch_size = len(prompt)
622
+ else:
623
+ batch_size = prompt_embeds.shape[0]
624
+
625
+ device = self._execution_device
626
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
627
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
628
+ # corresponds to doing no classifier free guidance.
629
+ do_classifier_free_guidance = guidance_scale > 1.0
630
+
631
+ # 3. Encode input prompt
632
+ text_encoder_lora_scale = (
633
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
634
+ )
635
+
636
+ prompt_embeds = self._encode_prompt(
637
+ prompt,
638
+ device,
639
+ num_images_per_prompt,
640
+ do_classifier_free_guidance,
641
+ negative_prompt,
642
+ prompt_embeds=prompt_embeds,
643
+ negative_prompt_embeds=negative_prompt_embeds,
644
+ lora_scale=text_encoder_lora_scale,
645
+ )
646
+
647
+ # 4. Prepare timesteps
648
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
649
+ timesteps = self.scheduler.timesteps
650
+
651
+ # 5. Prepare latent variables
652
+ num_channels_latents = self.unet.config.in_channels
653
+ latents = self.prepare_latents(
654
+ batch_size * num_images_per_prompt,
655
+ num_channels_latents,
656
+ height,
657
+ width,
658
+ prompt_embeds.dtype,
659
+ device,
660
+ generator,
661
+ latents,
662
+ )
663
+
664
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
665
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
666
+
667
+ # 7. Denoising loop
668
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
669
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
670
+ for i, t in enumerate(timesteps):
671
+ # expand the latents if we are doing classifier free guidance
672
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
673
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
674
+
675
+ # predict the noise residual
676
+ noise_pred = self.unet(
677
+ latent_model_input,
678
+ t,
679
+ encoder_hidden_states=prompt_embeds,
680
+ cross_attention_kwargs=cross_attention_kwargs,
681
+ return_dict=False,
682
+ )[0]
683
+
684
+ # perform guidance
685
+ if do_classifier_free_guidance:
686
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
687
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
688
+
689
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
690
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
691
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
692
+
693
+ # compute the previous noisy sample x_t -> x_t-1
694
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
695
+
696
+ # call the callback, if provided
697
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
698
+ progress_bar.update()
699
+ if callback is not None and i % callback_steps == 0:
700
+ callback(i, t, latents)
701
+
702
+ if not output_type == "latent":
703
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
704
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
705
+ else:
706
+ image = latents
707
+ has_nsfw_concept = None
708
+
709
+ if has_nsfw_concept is None:
710
+ do_denormalize = [True] * image.shape[0]
711
+ else:
712
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
713
+
714
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
715
+
716
+ # Offload last model to CPU
717
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
718
+ self.final_offload_hook.offload()
719
+
720
+ if not return_dict:
721
+ return (image, has_nsfw_concept)
722
+
723
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
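Inside the denoising loop above, the unconditional and conditional prompt embeddings produced by `_encode_prompt` are stacked along the batch dimension, so a single UNet call returns both noise predictions. A standalone sketch of how they are recombined (tensor values and shapes here are illustrative only):

```python
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, text] stacked along dim 0, as in the loop above

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# With guidance_rescale > 0.0 the pipeline additionally calls diffusers' rescale_noise_cfg here.
```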
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "_class_name": "DDIMScheduler",
+   "_diffusers_version": "0.8.0",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "set_alpha_to_one": false,
+   "skip_prk_steps": true,
+   "steps_offset": 1,
+   "trained_betas": null
+ }
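This is a standard DDIM configuration; a minimal sketch of loading it with diffusers, assuming the file lives under `scheduler/` next to the model weights:

```python
from diffusers import DDIMScheduler

# "." stands in for the repository root (a local clone or a hub repo id).
scheduler = DDIMScheduler.from_pretrained(".", subfolder="scheduler")
print(scheduler.config.beta_schedule)  # "scaled_linear"
```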
tokenization_viscpmbee.py ADDED
@@ -0,0 +1,1008 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for CpmBee."""
16
+ import json
17
+ import os
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ from numpy.typing import NDArray
22
+ from typing_extensions import TypedDict
23
+
24
+ from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
25
+ from transformers.tokenization_utils_base import AddedToken, BatchEncoding, TextInput, TruncationStrategy
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
32
+
33
+ PRETRAINED_VOCAB_FILES_MAP = {
34
+ "vocab_file": {
35
+ "openbmb/viscpmchat-bee-10b": "https://huggingface.co/openbmb/VisCPM-Chat/blob/main/vocab.txt",
36
+ },
37
+ }
38
+
39
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
40
+ "openbmb/viscpmchat-bee-10b": 4096,
41
+ }
42
+
43
+
44
+ class _PrevExtTableStates(TypedDict):
45
+ ext_table: Dict[int, str]
46
+ token_id_table: Dict[str, Dict[int, int]]
47
+
48
+
49
+ CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
50
+
51
+
52
+ def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
53
+ ret = n_up * max_depth + n_down
54
+ if ret == 0:
55
+ return ret
56
+ else:
57
+ # bucket 1 is reserved for incontext samples
58
+ return ret + 1
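+
+ # Worked examples with the default max_depth of 8:
+ #   rel_to_bucket(0, 0) -> 0                     (same segment)
+ #   rel_to_bucket(1, 2) -> 1 * 8 + 2 + 1 = 11    (+1 because bucket 1 is reserved)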
59
+
60
+
61
+ class _DictTree(TypedDict):
62
+ value: str
63
+ children: List["_DictTree"]
64
+ depth: int
65
+ segment_id: int
66
+ need_predict: bool
67
+ is_image: bool
68
+
69
+
70
+ class VisCpmBeeTokenizer(PreTrainedTokenizer):
71
+ """
72
+ Construct a CPMBee tokenizer.
73
+
74
+ Args:
75
+ vocab_file (`str`):
76
+ Path to the vocabulary file.
77
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
78
+ The beginning of sequence token.
79
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
80
+ The end of sequence token.
81
+ line_token (`str`, *optional*, defaults to `"\n"`):
82
+ The line token.
83
+ space_token (`str`, *optional*, defaults to `" "`):
84
+ The space token.
85
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
86
+ The unknown token.
87
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
88
+ The mask token.
89
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
90
+ The token used for padding.
91
+ padding_side (`str`, *optional*, defaults to `"left"`):
92
+ The padding side. CPM-Bee will use left padding by default.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ model_input_names: List[str] = [
99
+ "input_ids",
100
+ "attention_mask",
101
+ "input_id_sub",
102
+ "position",
103
+ "context",
104
+ "sample_ids",
105
+ "num_segments",
106
+ "segment",
107
+ "segment_rel_offset",
108
+ "segment_rel",
109
+ ]
110
+ add_prefix_space = False
111
+
112
+ def __init__(
113
+ self,
114
+ vocab_file,
115
+ bos_token="<s>",
116
+ eos_token="</s>",
117
+ line_token="\n",
118
+ space_token=" ",
119
+ unk_token="<unk>",
120
+ mask_token="<mask>",
121
+ pad_token="<pad>",
122
+ padding_side="left",
123
+ **kwargs,
124
+ ):
125
+ super().__init__(
126
+ bos_token=bos_token,
127
+ eos_token=eos_token,
128
+ line_token=line_token,
129
+ space_token=space_token,
130
+ unk_token=unk_token,
131
+ mask_token=mask_token,
132
+ pad_token=pad_token,
133
+ padding_side=padding_side,
134
+ **kwargs,
135
+ )
136
+
137
+ self.encoder: Dict[str, int] = {}
138
+
139
+ with open(vocab_file, "r", encoding="utf-8") as reader:
140
+ for token in reader.readlines():
141
+ token = token.rstrip("\n")
142
+ if len(token) == 0:
143
+ continue
144
+ self.encoder[token] = len(self.encoder)
145
+
146
+ self.encoder[" "] = self.encoder["</_>"]
147
+ self.encoder["\n"] = self.encoder["</n>"]
148
+ del self.encoder["</_>"]
149
+ del self.encoder["</n>"]
150
+
151
+ self.decoder = {v: k for k, v in self.encoder.items()}
152
+
153
+ self._max_word_len = max([len(x) for x in self.encoder.keys()])
154
+ self.cpmbee_special_tokens = {k: v for k, v in self.encoder.items() if k.startswith("<") and k.endswith(">")}
155
+
156
+ self.ext_table: Dict[int, str] = {}
157
+ self.ext_table_rev: Dict[str, int] = {}
158
+
159
+ self.token_id_table: Dict[str, Dict[int, int]] = {}
160
+ self.ext_special_tokens = []
161
+
162
+ self.ext_args_for_model = [
163
+ "input_id_subs",
164
+ "input_pos",
165
+ "context",
166
+ "segment_ids",
167
+ "segment_rel_offset",
168
+ "segment_rel",
169
+ "sample_ids",
170
+ "num_segments",
171
+ "predict_segments",
172
+ "answer_placeholders",
173
+ "ext_table",
174
+ "token_id_table",
175
+ "image_bound"
176
+ ]
177
+
178
+ @property
179
+ def bod_token_id(self):
180
+ return self.encoder[self.bod_token]
181
+
182
+ @property
183
+ def eod_token_id(self):
184
+ return self.encoder[self.eod_token]
185
+
186
+ @property
187
+ def newline_id(self):
188
+ return self.encoder[self.line_token]
189
+
190
+ @property
191
+ def vocab_size(self) -> int:
192
+ return len(self.encoder)
193
+
194
+ def __len__(self):
195
+ """
196
+ Size of the full vocabulary with the added tokens.
197
+ """
198
+ return self.vocab_size + len(self.added_tokens_encoder)
199
+
200
+ def get_vocab(self):
201
+ return dict(self.encoder, **self.added_tokens_encoder)
202
+
203
+ def get_piece(self, text: str) -> str:
204
+ """
205
+ Match with maximum length.
206
+ """
207
+ len_text = len(text)
208
+ for i in range(len(text)):
209
+ sub = text[: len_text - i]
210
+ if (sub in self.encoder) or (sub in self.added_tokens_encoder):
211
+ return sub
212
+ return text[0]
213
+
214
+ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
215
+ r"""
216
+ Override the `tokenize` to meet the needs of CPMBee:
217
+ 1. Mark the special token with `<` and `>`. The `<>` will be ignored.
218
+ 2. Split sentences by the marked special tokens.
219
+ 3. Record the marked special token by `ext_table` and `ext_table_rev`.
220
+ 4. Tokenize the sentence without special tokens.
221
+ """
222
+ for_cpmbee = kwargs.get("for_cpmbee", False)
223
+ all_special_tokens_extended = {
224
+ str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
225
+ }
226
+
227
+ sentence_split = [""]
228
+ is_special_token = False
229
+ for i, c in enumerate(text):
230
+ if is_special_token:
231
+ if c == "<":
232
+ tail = sentence_split.pop(-1)
233
+ sentence_split[-1] += tail
234
+ sentence_split.append(c)
235
+ elif c == ">":
236
+ # end of special token
237
+ sentence_split[-1] += c
238
+ if sentence_split[-1] == "<>":
239
+ continue
240
+ is_special_token = False
241
+ sentence_split.append("")
242
+ else:
243
+ sentence_split[-1] += c
244
+ else:
245
+ if c == "<":
246
+ is_special_token = True
247
+ sentence_split.append(c)
248
+ else:
249
+ sentence_split[-1] += c
250
+ if is_special_token:
251
+ tail = sentence_split.pop(-1)
252
+ sentence_split[-1] += tail
253
+
254
+ output_tokens = []
255
+ for i, part in enumerate(sentence_split):
256
+ if (i & 1) == 1:
257
+ # special token
258
+ output_tokens.append(part)
259
+ if for_cpmbee and (part not in self.encoder) and (part not in self.ext_table_rev):
260
+ self.ext_table_rev[part] = len(self.ext_table_rev) + self.vocab_size
261
+ self.ext_table[self.ext_table_rev[part]] = part
262
+ else:
263
+ output_tokens.extend(self._tokenize(part, for_cpmbee=for_cpmbee))
264
+
265
+ # drop spaces
266
+ for i, token in enumerate(output_tokens):
267
+ if token in self.added_tokens_encoder:
268
+ token = all_special_tokens_extended.get(token, None)
269
+ left = output_tokens[i - 1] if i > 0 else None
270
+ right = output_tokens[i + 1] if i < len(output_tokens) - 1 else None
271
+ if isinstance(token, AddedToken):
272
+ if token.rstrip and right:
273
+ # A bit counter-intuitive but we strip the left of the string
274
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
275
+ output_tokens[i + 1] = right.lstrip()
276
+ # Strip white spaces on the left
277
+ if token.lstrip and left:
278
+ output_tokens[i - 1] = left.rstrip() # Opposite here
279
+ else:
280
+ if right:
281
+ output_tokens[i + 1] = right.lstrip()
282
+ if left:
283
+ output_tokens[i - 1] = left.rstrip()
284
+
285
+ skipped_tokens = []
286
+ for token in output_tokens:
287
+ if not token:
288
+ continue
289
+ else:
290
+ skipped_tokens.append(token)
291
+
292
+ return skipped_tokens
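+
+ # Illustrative behaviour (exact subword splits depend on the vocabulary): with
+ # for_cpmbee=True, a marked span such as "<img_placeholder>" is kept as a single
+ # token and, if it is not in the base vocab, is registered in ext_table_rev with
+ # an id >= vocab_size:
+ #   tokenizer.tokenize("a photo of <img_placeholder>", for_cpmbee=True)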
293
+
294
+ def _tokenize(self, text, **kwargs):
295
+ """
296
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
297
+ vocabulary.
298
+
299
+ Do NOT take care of added tokens. Record the unk tokens and special tokens in `ext_table` and `ext_table_rev`.
300
+ """
301
+ for_cpmbee = kwargs.get("for_cpmbee", False)
302
+ output_tokens = []
303
+
304
+ part_st = 0
305
+ last_unk = None
306
+ while part_st < len(text):
307
+ piece = self.get_piece(text[part_st:])
308
+ if piece in self.encoder or self.added_tokens_encoder:
309
+ if last_unk is None:
310
+ output_tokens.append(piece)
311
+ else:
312
+ if for_cpmbee and (last_unk not in self.ext_table_rev):
313
+ self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
314
+ self.ext_table[self.ext_table_rev[last_unk]] = last_unk
315
+ output_tokens.append(last_unk)
316
+ output_tokens.append(piece)
317
+ last_unk = None
318
+ else:
319
+ if last_unk is None:
320
+ last_unk = piece
321
+ else:
322
+ last_unk += piece
323
+ part_st += len(piece)
324
+ if last_unk is not None:
325
+ # part end with UNK
326
+ if for_cpmbee and (last_unk not in self.ext_table_rev):
327
+ self.ext_table_rev[last_unk] = len(self.ext_table_rev) + self.vocab_size
328
+ self.ext_table[self.ext_table_rev[last_unk]] = last_unk
329
+ output_tokens.append(last_unk)
330
+
331
+ return output_tokens
332
+
333
+ def check(self, token):
334
+ return token in self.encoder
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ return "".join(tokens)
338
+
339
+ def _convert_token_to_id(self, token: str):
340
+ """Converts a token (str) in an id using the vocab and ext_table."""
341
+ if token in self.encoder:
342
+ return self.encoder.get(token)
343
+ elif token in self.ext_table_rev:
344
+ return self.ext_table_rev[token]
345
+ elif token in self.added_tokens_encoder:
346
+ return self.added_tokens_encoder[token]
347
+ else:
348
+ return self.unk_token_id
349
+
350
+ def _convert_id_to_token(self, index):
351
+ """Converts an index (integer) in a token (str) using the vocab and ext_table."""
352
+ if index in self.ext_table:
353
+ return self.ext_table[index]
354
+ elif index in self.added_tokens_decoder:
355
+ return self.added_tokens_decoder[index]
356
+ else:
357
+ if index >= 0:
358
+ return self.decoder[index]
359
+
360
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
361
+ if os.path.isdir(save_directory):
362
+ vocab_file = os.path.join(
363
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
364
+ )
365
+ else:
366
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
367
+ index = 0
368
+ self.encoder["</n>"] = self.encoder["\n"]
369
+ del self.encoder["\n"]
370
+ self.encoder["</_>"] = self.encoder[" "]
371
+ del self.encoder[" "]
372
+ with open(vocab_file, "w", encoding="utf-8") as writer:
373
+ for token, token_index in sorted(self.encoder.items(), key=lambda x: x[1]):
374
+ if index != token_index:
375
+ logger.warning(
376
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
377
+ " Please check that the vocabulary is not corrupted!"
378
+ )
379
+ index = token_index
380
+ writer.write(token + "\n")
381
+ index += 1
382
+ return (vocab_file,)
383
+
384
+ def __call__(self, text, *args, **kwargs):
385
+ r"""
386
+ CPMBee `call` method will use `_tokenize_cpmbee` when the input type is dict.
387
+ """
388
+ if isinstance(text, dict):
389
+ return self._batch_tokenize_cpmbee([text], *args, **kwargs)
390
+ elif isinstance(text, (list, tuple)):
391
+ if isinstance(text[0], dict):
392
+ return self._batch_tokenize_cpmbee(text, *args, **kwargs)
393
+ else:
394
+ return super().__call__(text, *args, **kwargs)
395
+ else:
396
+ return super().__call__(text, *args, **kwargs)
397
+
398
+ # 分词
399
+ def _tokenize_cpmbee(self, data: TextInput, *args, **kwargs) -> List[str]:
400
+ """
401
+ A tokenize method to process dict data. Exclusive for CPMBee.
402
+ """
403
+ if isinstance(data, str):
404
+ data = json.loads(data)
405
+ if not isinstance(data, Dict):
406
+ raise TypeError(
407
+ "CpmBeeTokenizer input data should be dict or str in dict format, but got {}".format(type(data))
408
+ )
409
+
410
+ # 1. prepare answer placeholder
411
+ answer_placeholders = []
412
+
413
+ def _put_placeholder(data: Any, path: List[str] = []):
414
+ if isinstance(data, dict):
415
+ ret = {}
416
+ for k, v in data.items():
417
+ ret[k] = _put_placeholder(v, path + [k])
418
+ return ret
419
+ else:
420
+ answer_placeholders.append(path)
421
+ return "<ans_{}>".format(len(answer_placeholders))
422
+
423
+ data["<ans>"] = _put_placeholder(data["<ans>"])
424
+
425
+ (
426
+ input_ids,
427
+ input_id_subs,
428
+ context,
429
+ segment_ids,
430
+ segment_rel,
431
+ n_segments,
432
+ table_states,
433
+ image_bound
434
+ ) = self.convert_data_to_id(data, shuffle_answer=False, max_depth=8)
435
+
436
+ # <ans> mapping from sub to id
437
+ sub_ans_map: Dict[int, int] = {}
438
+ for fake_id, token_sub in table_states["token_id_table"]["<ans>"].items():
439
+ token = table_states["ext_table"][fake_id]
440
+ if token.startswith("<ans_") and token.endswith(">"):
441
+ ans_id = int(token[5:-1])
442
+ sub_ans_map[token_sub] = ans_id
443
+
444
+ tmp_input_ids = []
445
+ tmp_input_sub = []
446
+ tmp_input_seg = []
447
+
448
+ # get predict segments
449
+ predict_segments: List[Tuple[int, int]] = []
450
+ for i in range(input_ids.shape[0]):
451
+ if context[i] == 0:
452
+ if input_ids[i] == self.encoder["<ans>"]:
453
+ # is ans
454
+ # (segment_id, ans_id)
455
+ predict_segments.append((segment_ids[i], sub_ans_map[input_id_subs[i]]))
456
+ else:
457
+ tmp_input_ids.append(input_ids[i])
458
+ tmp_input_sub.append(input_id_subs[i])
459
+ tmp_input_seg.append(segment_ids[i])
460
+
461
+ if len(predict_segments) == 0:
462
+ raise ValueError("No answer to predict")
463
+
464
+ input_ids = np.array(tmp_input_ids, dtype=np.int32) # all context
465
+ input_id_subs = np.array(tmp_input_sub, dtype=np.int32) # [0, 0, 0, 0, 1, 0, 0, 2, 0, ...]
466
+ context = np.full_like(tmp_input_ids, 1, dtype=np.int8) # [1, 1, 1, ...]
467
+ segment_ids = np.array(tmp_input_seg, dtype=np.int32) # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, ...]
468
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, 0, ...]
469
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32) # [0, 0, 0, ...]
470
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32) # [n_seg, n_seg, n_seg, ...]
471
+ input_pos = np.arange(input_ids.shape[0], dtype=np.int32) # [0, 1, 2, 3, 4, ...]
472
+ image_bound = np.array(image_bound)
473
+
474
+ return (
475
+ self.prepare_for_model(
476
+ input_ids.tolist(),
477
+ input_id_subs=input_id_subs.tolist(),
478
+ input_pos=input_pos.tolist(),
479
+ context=context.tolist(),
480
+ segment_ids=segment_ids.tolist(),
481
+ segment_rel_offset=segment_rel_offset.tolist(),
482
+ segment_rel=segment_rel.tolist(),
483
+ sample_ids=sample_ids.tolist(),
484
+ num_segments=num_segments.tolist(),
485
+ image_bound=image_bound,
486
+ **kwargs,
487
+ ),
488
+ predict_segments,
489
+ answer_placeholders,
490
+ table_states["ext_table"],
491
+ table_states["token_id_table"],
492
+ )
493
+
494
+ def _batch_tokenize_cpmbee(self, data_lst, *args, **kwargs):
495
+ """
496
+ Batched _token_cpmbee.
497
+ """
498
+ device = kwargs.get("device", "cpu")
499
+ return_tensors = kwargs.get("return_tensors", None)
500
+ batch_outputs = {}
501
+ segment_rel_pack = []
502
+ other_info = []
503
+
504
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
505
+ batch_ext_table_ids: List[int] = []
506
+ batch_ext_table_sub: List[int] = []
507
+
508
+ for data in data_lst:
509
+ self.ext_table = {}
510
+ self.ext_table_rev = {}
511
+ self.token_id_table = {}
512
+ (outputs, predict_segments, answer_placeholders, ext_table, token_id_table) = self._tokenize_cpmbee(
513
+ data,
514
+ truncation=None,
515
+ padding=PaddingStrategy.DO_NOT_PAD.value,
516
+ max_length=None,
517
+ pad_to_multiple_of=None,
518
+ return_attention_mask=False,
519
+ return_tensors=None,
520
+ )
521
+ rev_ext_table = {}
522
+ for token, mp in token_id_table.items():
523
+ if token == "<ans>":
524
+ continue
525
+ token_id = self.encoder[token]
526
+ for fake_id, token_sub in mp.items():
527
+ if token_sub > 0:
528
+ if (token_id, token_sub) not in batch_ext_table_map:
529
+ batch_ext_table_map[(token_id, token_sub)] = len(batch_ext_table_ids) + self.vocab_size
530
+ batch_ext_table_ids.append(token_id)
531
+ batch_ext_table_sub.append(token_sub)
532
+ rev_ext_table[batch_ext_table_map[(token_id, token_sub)]] = ext_table[fake_id]
533
+ else:
534
+ rev_ext_table[token_id] = ext_table[fake_id]
535
+
536
+ segment_rel_pack.append(np.array(outputs.pop("segment_rel")))
537
+ other_info.append(
538
+ {
539
+ "predict_segments": predict_segments,
540
+ "answer_placeholders": answer_placeholders,
541
+ "ext_table": rev_ext_table,
542
+ }
543
+ )
544
+
545
+ for key, value in outputs.items():
546
+ if key not in batch_outputs:
547
+ batch_outputs[key] = []
548
+ batch_outputs[key].append(value)
549
+
550
+ max_length = max([len(item) for item in batch_outputs[self.model_input_names[0]]])
551
+ batch_size = len(batch_outputs[self.model_input_names[0]])
552
+ for i in range(batch_size):
553
+ inputs = {k: v[i] for k, v in batch_outputs.items()}
554
+
555
+ for k, v in inputs.items():
556
+ required_input = v
557
+
558
+ needs_to_be_padded = len(required_input) != max_length and k != 'image_bound'
559
+
560
+ if needs_to_be_padded:
561
+ difference = max_length - len(required_input)
562
+ batch_outputs[k][i] = [self.pad_token_id] * difference + required_input
563
+
564
+ max_num_rels = 0
565
+ for rel in segment_rel_pack:
566
+ max_num_rels = max(max_num_rels, rel.shape[0])
567
+ padded_rels = np.zeros((len(segment_rel_pack), max_num_rels), dtype=np.int32)
568
+ for i, rel in enumerate(segment_rel_pack):
569
+ padded_rels[i, : rel.shape[0]] = rel
570
+ batch_outputs["segment_rel"] = padded_rels
571
+ batch_outputs["batch_ext_table_ids"] = np.array(batch_ext_table_ids, dtype=np.int32)
572
+ batch_outputs["batch_ext_table_sub"] = np.array(batch_ext_table_sub, dtype=np.int32)
573
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
574
+ if return_tensors == "pt":
575
+ batch_outputs = batch_outputs.to(device=device)
576
+ batch_outputs["other_info"] = other_info
577
+
578
+ return batch_outputs
579
+
580
+ def convert_data_to_id(
581
+ self,
582
+ data: Any,
583
+ prev_ext_states: Optional[_PrevExtTableStates] = None,
584
+ shuffle_answer: bool = True,
585
+ max_depth: int = 8,
586
+ ):
587
+ """
588
+ Parse a dict to data ids. Exclusive for CPMBee. It will
589
+ 1. parse the dict to segments and get segment_rel, which for calculating of position_bias.
590
+ 2. tokenize every segment.
591
+ """
592
+ root: _DictTree = {
593
+ "value": "<root>",
594
+ "children": [],
595
+ "depth": 0,
596
+ "segment_id": 0,
597
+ "need_predict": False,
598
+ "is_image": False
599
+ }
600
+
601
+ segments = [root]
602
+
603
+ def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
604
+ if isinstance(data, dict):
605
+ ret_list: List[_DictTree] = []
606
+ curr_items = list(data.items())
607
+ if need_predict and shuffle_answer:
608
+ access_idx = np.arange(len(curr_items))
609
+ np.random.shuffle(access_idx)
610
+ curr_items = [curr_items[idx] for idx in access_idx]
611
+ for k, v in curr_items:
612
+ child_info: _DictTree = {
613
+ "value": k,
614
+ "children": [],
615
+ "depth": depth,
616
+ "segment_id": len(segments),
617
+ "need_predict": False, # only leaves are contexts
618
+ "is_image": False,
619
+ }
620
+ segments.append(child_info)
621
+ child_info["children"] = _build_dict_tree(
622
+ v, depth + 1,
623
+ need_predict=need_predict or (depth == 1 and k == "<ans>"),
624
+ is_image=is_image or (depth == 1 and k == "image")
625
+ ) # elements in <root>.<ans>
626
+
627
+ ret_list.append(child_info)
628
+ return ret_list
629
+ else:
630
+ assert isinstance(data, str), "Invalid data {}".format(data)
631
+ ret: _DictTree = {
632
+ "value": data,
633
+ "children": [],
634
+ "depth": depth,
635
+ "segment_id": len(segments),
636
+ "need_predict": need_predict,
637
+ "is_image": is_image,
638
+ }
639
+ segments.append(ret)
640
+ return [ret]
641
+
642
+ root["children"] = _build_dict_tree(data, 1, False, False)
643
+
644
+ num_segments = len(segments)
645
+ segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)
646
+
647
+ def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
648
+ ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
649
+ for child in node["children"]:
650
+ sub = _build_segment_rel(child)
651
+ for seg_id_1, depth_1 in sub:
652
+ for seg_id_2, depth_2 in ret:
653
+ n_up = min(depth_1 - node["depth"], max_depth - 1)
654
+ n_down = min(depth_2 - node["depth"], max_depth - 1)
655
+ segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
656
+ n_up, n_down, max_depth=max_depth
657
+ )
658
+ segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
659
+ n_down, n_up, max_depth=max_depth
660
+ )
661
+ ret.extend(sub)
662
+ return ret
663
+
664
+ _build_segment_rel(root)
665
+
666
+ input_ids: List[int] = []
667
+ input_id_subs: List[int] = []
668
+ segment_bound: List[Tuple[int, int]] = []
669
+ image_bound: List[Tuple[int, int]] = []
670
+
671
+
672
+ if prev_ext_states is not None:
673
+ self.ext_table = prev_ext_states["ext_table"]
674
+ self.token_id_table = prev_ext_states["token_id_table"]
675
+
676
+ for seg in segments:
677
+ # tokenize
678
+ tokens = self.convert_tokens_to_ids(self.tokenize(seg["value"], for_cpmbee=True))
679
+
680
+ token_id_subs = []
681
+ reid_token_ids = []
682
+ for idx in tokens:
683
+ if idx in self.ext_table:
684
+ # unk or special token
685
+ token = self.ext_table[idx]
686
+ if token.startswith("<") and token.endswith(">"):
687
+ # special token
688
+ if "_" in token:
689
+ token_name = token[1:-1].split("_", maxsplit=1)[0]
690
+ else:
691
+ token_name = token[1:-1]
692
+ token_name = "<{}>".format(token_name)
693
+ else:
694
+ token_name = "<unk>"
695
+
696
+ if token_name not in self.token_id_table:
697
+ self.token_id_table[token_name] = {}
698
+ if idx not in self.token_id_table[token_name]:
699
+ self.token_id_table[token_name][idx] = len(self.token_id_table[token_name])
700
+ if token_name not in self.encoder:
701
+ raise ValueError("Invalid token {}".format(token))
702
+ reid_token_ids.append(self.encoder[token_name])
703
+ token_id_subs.append(self.token_id_table[token_name][idx])
704
+ else:
705
+ reid_token_ids.append(idx)
706
+ token_id_subs.append(0)
707
+ tokens = [self.bos_token_id] + reid_token_ids
708
+ token_id_subs = [0] + token_id_subs
709
+ # eos_id marks segments that do not need prediction
710
+ if not seg["need_predict"]: # eos
711
+ tokens = tokens + [self.eos_token_id]
712
+ token_id_subs = token_id_subs + [0]
713
+ else:
714
+ # no eos
715
+ pass
716
+ begin = len(input_ids)
717
+ input_ids.extend(tokens)
718
+ input_id_subs.extend(token_id_subs)
719
+ end = len(input_ids)
720
+ segment_bound.append((begin, end))
721
+
722
+ ids = np.array(input_ids, dtype=np.int32)
723
+ id_subs = np.array(input_id_subs, dtype=np.int32)
724
+         segs = np.zeros((ids.shape[0],), dtype=np.int32)  # segment index of each token, assigned from segment_bound below
725
+ context = np.zeros((ids.shape[0],), dtype=np.int8)
726
+ for i, (begin, end) in enumerate(segment_bound):
727
+ if not segments[i]["need_predict"]:
728
+ context[begin:end] = 1
729
+ if segments[i]["is_image"]:
730
+ image_bound.append((begin + 1, end - 1))
731
+ segs[begin:end] = i
732
+
733
+ curr_ext_table_states: _PrevExtTableStates = {
734
+ "ext_table": self.ext_table,
735
+ "token_id_table": self.token_id_table,
736
+ }
737
+ image_bound = np.array(image_bound, dtype=np.int32)
738
+ return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
739
+
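+     # Illustrative sketch of how this conversion is driven (the "caption" key is
+     # hypothetical; "<ans>" and "image" are the keys handled specially above, and
+     # convert_data_to_id is how prepare_for_finetune below invokes it):
+     #     data = {"caption": "a photo of a cat", "<ans>": ""}
+     #     (ids, id_subs, context, segs, segment_rel,
+     #      n_segments, ext_states, image_bound) = tokenizer.convert_data_to_id(data)
+     # Tokens with context == 1 are plain context; tokens with context == 0 belong to
+     # segments under "<ans>" and are the ones the model is asked to predict.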
740
+ def prepare_for_model(
741
+ self,
742
+ ids: List[int],
743
+ pair_ids: Optional[List[int]] = None,
744
+ add_special_tokens: bool = True,
745
+ padding: Union[bool, str, PaddingStrategy] = False,
746
+         truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
747
+ max_length: Optional[int] = None,
748
+ stride: int = 0,
749
+ pad_to_multiple_of: Optional[int] = None,
750
+ return_tensors: Optional[Union[str, TensorType]] = None,
751
+ return_token_type_ids: Optional[bool] = None,
752
+ return_attention_mask: Optional[bool] = None,
753
+ return_overflowing_tokens: bool = False,
754
+ return_special_tokens_mask: bool = False,
755
+ return_length: bool = False,
756
+ verbose: bool = True,
757
+ prepend_batch_axis: bool = False,
758
+ **kwargs,
759
+ ) -> BatchEncoding:
760
+ """
761
+         Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
762
+         adds special tokens, truncates sequences if overflowing while taking the special tokens into account, and
763
+         manages a moving window (with a user-defined stride) for overflowing tokens. Note that for *pair_ids*
764
+         different from `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
765
+         overflowing tokens; such a combination of arguments will raise an error.
766
+
767
+ Args:
768
+ ids (`List[int]`):
769
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
770
+ `convert_tokens_to_ids` methods.
771
+ pair_ids (`List[int]`, *optional*):
772
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
773
+ and `convert_tokens_to_ids` methods.
774
+ """
775
+
776
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
777
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
778
+ padding=padding,
779
+ truncation=truncation,
780
+ max_length=max_length,
781
+ pad_to_multiple_of=pad_to_multiple_of,
782
+ verbose=verbose,
783
+ **kwargs,
784
+ )
785
+
786
+ pair = bool(pair_ids is not None)
787
+ len_ids = len(ids)
788
+ len_pair_ids = len(pair_ids) if pair else 0
789
+
790
+ if return_token_type_ids and not add_special_tokens:
791
+ raise ValueError(
792
+ "Asking to return token_type_ids while setting add_special_tokens to False "
793
+ "results in an undefined behavior. Please set add_special_tokens to True or "
794
+ "set return_token_type_ids to None."
795
+ )
796
+
797
+ if (
798
+ return_overflowing_tokens
799
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
800
+ and pair_ids is not None
801
+ ):
802
+ raise ValueError(
803
+ "Not possible to return overflowing tokens for pair of sequences with the "
804
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
805
+ "for instance `only_second` or `only_first`."
806
+ )
807
+
808
+ # Load from model defaults
809
+ if return_token_type_ids is None:
810
+ return_token_type_ids = "token_type_ids" in self.model_input_names
811
+ if return_attention_mask is None:
812
+ return_attention_mask = "attention_mask" in self.model_input_names
813
+
814
+ encoded_inputs = {}
815
+
816
+ # Compute the total size of the returned encodings
817
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
818
+
819
+ # Truncation: Handle max sequence length
820
+ overflowing_tokens = []
821
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
822
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
823
+ ids,
824
+ pair_ids=pair_ids,
825
+ num_tokens_to_remove=total_len - max_length,
826
+ truncation_strategy=truncation_strategy,
827
+ stride=stride,
828
+ )
829
+
830
+ if return_overflowing_tokens:
831
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
832
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
833
+
834
+ # Add special tokens
835
+ if add_special_tokens:
836
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
837
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
838
+ else:
839
+ sequence = ids + pair_ids if pair else ids
840
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
841
+
842
+ # Build output dictionary
843
+ encoded_inputs["input_ids"] = sequence
844
+ if return_token_type_ids:
845
+ encoded_inputs["token_type_ids"] = token_type_ids
846
+ if return_special_tokens_mask:
847
+ if add_special_tokens:
848
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
849
+ else:
850
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
851
+
852
+ # Check lengths
853
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
854
+
855
+ # Padding
856
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
857
+ encoded_inputs = self.pad(
858
+ encoded_inputs,
859
+ max_length=max_length,
860
+ padding=padding_strategy.value,
861
+ pad_to_multiple_of=pad_to_multiple_of,
862
+ return_attention_mask=return_attention_mask,
863
+ )
864
+
865
+ if return_length:
866
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
867
+
868
+ # for CPMBee, encode all the model arguments
869
+ for arg in self.ext_args_for_model:
870
+ v = kwargs.get(arg, None)
871
+ if v is not None:
872
+ encoded_inputs[arg] = v
873
+
874
+ batch_outputs = BatchEncoding(
875
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
876
+ )
877
+
878
+ return batch_outputs
879
+
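+     # Illustrative usage sketch (prompt string and lengths are hypothetical):
+     #     ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("a photo of a cat"))
+     #     enc = tokenizer.prepare_for_model(ids, max_length=64, padding="max_length",
+     #                                       truncation=True, return_tensors="pt")
+     #     enc["input_ids"]  # a padded id tensor of length max_length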
880
+ def prepare_for_finetune(
881
+ self,
882
+ data_list: List[Dict],
883
+ max_length: int = 2048
884
+ ):
885
+ _inputs: List[NDArray[np.int32]] = []
886
+ _inputs_sub: List[NDArray[np.int32]] = []
887
+ _context: List[NDArray[np.int8]] = []
888
+ _sample_ids: List[NDArray[np.int32]] = []
889
+ _segments: List[NDArray[np.int32]] = []
890
+ _num_segments: List[NDArray[np.int32]] = []
891
+ _segment_rel_offset: List[NDArray[np.int32]] = []
892
+ _segment_rel: List[NDArray[np.int32]] = []
893
+ _spans: List[List[int]] = []
894
+ _raw_data: List[List[Any]] = []
895
+
897
+ for data in data_list:
898
+ (
899
+ input_ids,
900
+ input_id_subs,
901
+ context,
902
+ segment_ids,
903
+ segment_rel,
904
+ n_segments,
905
+                 _,
+                 _  # the ext-table states and image_bound returned by convert_data_to_id are unused here
906
+ ) = self.convert_data_to_id(data)
907
+
908
+             input_ids = input_ids[: max_length]
+             input_id_subs = input_id_subs[: max_length]
909
+ context = context[: max_length]
910
+ segment_ids = segment_ids[: max_length]
911
+             raw_data = {}  # re-created per sample so the entries appended to _raw_data stay independent
+             raw_data["input"] = data
912
+ raw_data["samples"] = []
913
+
914
+ sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
915
+ segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
916
+ num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
917
+
918
+ _inputs.append(input_ids)
919
+ _inputs_sub.append(input_id_subs)
920
+ _context.append(context)
921
+ _sample_ids.append(sample_ids)
922
+ _segments.append(segment_ids)
923
+ _num_segments.append(num_segments)
924
+ _segment_rel_offset.append(segment_rel_offset)
925
+ _segment_rel.append(segment_rel)
926
+ _spans.append([input_ids.shape[0]])
927
+ _raw_data.append([raw_data])
928
+
929
+ batch_size = len(_inputs)
930
+ inputs = np.zeros((batch_size, max_length), dtype=np.int32)
931
+ inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
932
+ context = np.zeros((batch_size, max_length), dtype=np.int8)
933
+ sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
934
+ segments = np.zeros((batch_size, max_length), dtype=np.int32)
935
+ num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
936
+ segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
937
+ tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
938
+
939
+ max_rel = 0
940
+ for i in range(batch_size):
941
+ max_rel = max(max_rel, _segment_rel[i].shape[0])
942
+ segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
943
+ spans = np.zeros((batch_size, max_length), dtype=np.int32)
944
+ length = np.zeros((batch_size,), dtype=np.int32)
945
+
946
+ batch_ext_table_map: Dict[Tuple[int, int], int] = {}
947
+ batch_ext_table_ids: List[int] = []
948
+ batch_ext_table_sub: List[int] = []
949
+ raw_data_list: List[Any] = []
950
+
951
+ for i in range(batch_size):
952
+ instance_length = _inputs[i].shape[0]
953
+ rel_size = _segment_rel[i].shape[0]
954
+ inputs[i, :instance_length] = _inputs[i]
955
+ inputs_sub[i, :instance_length] = _inputs_sub[i]
956
+ context[i, :instance_length] = _context[i]
957
+ sample_ids[i, :instance_length] = _sample_ids[i]
958
+ segments[i, :instance_length] = _segments[i]
959
+ num_segments[i, :instance_length] = _num_segments[i]
960
+ segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
961
+ segment_rel[i, :rel_size] = _segment_rel[i]
962
+
963
+ span_begin = 0
964
+ for span_id, span_end in enumerate(_spans[i]):
965
+ spans[i, span_begin:span_end] = span_id
966
+ span_begin = span_end
967
+ length[i] = instance_length
968
+ raw_data_list.extend(_raw_data[i])
969
+
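+             # next-token labels: where context == 0 (answer tokens), tgt[j - 1] is set to
+             # the token at position j (or eos at a segment start); tokens re-indexed into
+             # the batch ext table are shifted past vocab_size; every other position keeps
+             # -100 and is ignored by the loss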
970
+ for j in range(instance_length):
971
+ idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
972
+ tgt_idx = idx
973
+ if idx_sub > 0:
974
+ # need to be in ext table
975
+ if (idx, idx_sub) not in batch_ext_table_map:
976
+ batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
977
+ batch_ext_table_ids.append(idx)
978
+ batch_ext_table_sub.append(idx_sub)
979
+ tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
980
+ if j > 1 and context[i, j - 1] == 0:
981
+ if idx != self.bos_token_id:
982
+ tgt[i, j - 1] = tgt_idx
983
+ else:
984
+ tgt[i, j - 1] = self.eos_token_id
985
+ if context[i, instance_length - 1] == 0:
986
+ tgt[i, instance_length - 1] = self.eos_token_id
987
+
988
+ if len(batch_ext_table_map) == 0:
989
+ # placeholder
990
+ batch_ext_table_ids.append(0)
991
+ batch_ext_table_sub.append(1)
992
+
993
+ return BatchEncoding({
994
+ "input_ids": inputs,
995
+ "input_id_sub": inputs_sub,
996
+ "length": length,
997
+ "context": context > 0,
998
+ "sample_ids": sample_ids,
999
+ "num_segments": num_segments,
1000
+ "segment": segments,
1001
+ "segment_rel_offset": segment_rel_offset,
1002
+ "segment_rel": segment_rel,
1003
+ "span": spans,
1004
+ "labels": tgt,
1005
+ "ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
1006
+ "ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
1007
+ }, tensor_type="pt")
1008
+
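+     # Illustrative sketch of building a finetuning batch (the "caption" key and the
+     # answer text are hypothetical):
+     #     batch = tokenizer.prepare_for_finetune(
+     #         [{"caption": "a photo of", "<ans>": "a cat"}], max_length=2048
+     #     )
+     #     batch["input_ids"].shape  # (1, 2048); batch["labels"] uses -100 for ignored positions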
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "name_or_path": "openbmb/cpm-bee-10b",
3
+ "tokenizer_class": "CpmBeeTokenizer",
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_viscpmbee.VisCpmBeeTokenizer",
7
+ null
8
+ ]
9
+ }
10
+ }
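The `auto_map` entry above is what lets `AutoTokenizer` resolve the custom tokenizer class shipped in `tokenization_viscpmbee.py`. A minimal sketch of how it is typically consumed (the repo id is an assumption; substitute the actual model id or a local path):

```python
from transformers import AutoTokenizer

# trust_remote_code is required so that AutoTokenizer follows auto_map and loads
# tokenization_viscpmbee.VisCpmBeeTokenizer from this repository
tokenizer = AutoTokenizer.from_pretrained("openbmb/VisCPM-Paint", trust_remote_code=True)  # repo id assumed
```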
unet/config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.10.0.dev0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": [
6
+ 5,
7
+ 10,
8
+ 20,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "cross_attention_dim": 1024,
19
+ "down_block_types": [
20
+ "CrossAttnDownBlock2D",
21
+ "CrossAttnDownBlock2D",
22
+ "CrossAttnDownBlock2D",
23
+ "DownBlock2D"
24
+ ],
25
+ "downsample_padding": 1,
26
+ "dual_cross_attention": false,
27
+ "flip_sin_to_cos": true,
28
+ "freq_shift": 0,
29
+ "in_channels": 4,
30
+ "layers_per_block": 2,
31
+ "mid_block_scale_factor": 1,
32
+ "norm_eps": 1e-05,
33
+ "norm_num_groups": 32,
34
+ "num_class_embeds": null,
35
+ "only_cross_attention": false,
36
+ "out_channels": 4,
37
+ "sample_size": 64,
38
+ "up_block_types": [
39
+ "UpBlock2D",
40
+ "CrossAttnUpBlock2D",
41
+ "CrossAttnUpBlock2D",
42
+ "CrossAttnUpBlock2D"
43
+ ],
44
+ "use_linear_projection": true
45
+ }
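The configuration above describes the denoising UNet consumed by `diffusers`. A minimal sketch of inspecting it (the repo id is an assumption; if weights are present in this repo, they would instead be loaded with `UNet2DConditionModel.from_pretrained(..., subfolder="unet")`):

```python
from diffusers import UNet2DConditionModel

# fetch only unet/config.json and build a randomly initialized model with this architecture
config = UNet2DConditionModel.load_config("openbmb/VisCPM-Paint", subfolder="unet")  # repo id assumed
unet = UNet2DConditionModel.from_config(config)
print(unet.config.cross_attention_dim)  # 1024, the width of the text-encoder states fed via cross-attention
print(unet.config.sample_size)          # 64, the latent resolution the UNet is configured for
```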
vae/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.8.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 768,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
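The VAE configuration above implies an 8x spatial compression: four encoder blocks with three downsampling steps map a 768x768 RGB image to a 4 x 96 x 96 latent. A minimal sketch of checking this from the config (repo id assumed as before):

```python
from diffusers import AutoencoderKL

config = AutoencoderKL.load_config("openbmb/VisCPM-Paint", subfolder="vae")  # repo id assumed
vae = AutoencoderKL.from_config(config)  # randomly initialized; from_pretrained would load weights
scale = 2 ** (len(config["block_out_channels"]) - 1)  # one downsample between consecutive blocks
print(scale)                           # 8
print(config["sample_size"] // scale)  # 96, the latent spatial size for 768x768 inputs
```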
vocab.txt ADDED
The diff for this file is too large to render. See raw diff