liangyupu committed (verified) · Commit 064752a · Parent: bc7f189

Upload 10 files
baseline/README.md ADDED
@@ -0,0 +1,47 @@
# ICDAR 2025 Competition on End-to-end Document Image Machine Translation (OCR-free Track)

This is the official baseline code repository for the [***ICDAR 2025 Competition on End-to-end Document Image Machine Translation (OCR-free Track)***](https://cip-documentai.github.io/).

## Dataset Download
The dataset can be downloaded from this [Hugging Face link](https://huggingface.co/datasets/liangyupu/DoTA_dataset).

## Baseline Implementation
This is an implementation of a simple end-to-end document image machine translation model consisting of an image encoder and a translation decoder.
Details can be found in Section 5.3 (Base) of [***Document Image Machine Translation with Dynamic Multi-pre-trained Models Assembling (NAACL 2024 Main)***](https://aclanthology.org/2024.naacl-long.392/).

### 1. Requirements
```bash
python==3.10.13
pytorch==1.13.1
transformers==4.33.2
```
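
For reference, the environment can be created roughly as follows (a sketch: the environment name, the CUDA build of PyTorch, and the extra `accelerate` and `pillow` packages used by the scripts below are assumptions; adjust them to your machine):
```bash
# Sketch of an environment setup; package versions follow the list above,
# the environment name and extra packages are assumptions.
conda create -n dimt_baseline python=3.10.13 -y
conda activate dimt_baseline
pip install torch==1.13.1 transformers==4.33.2 accelerate pillow
```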

### 2. Download Pre-trained Models
Download the pre-trained Donut model from [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base).

Download the pre-trained Nougat model from [facebook/nougat-small](https://huggingface.co/facebook/nougat-small).
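
Both checkpoints can also be fetched from the command line by cloning the model repositories (a sketch; it assumes `git-lfs` is installed, and the destination paths are placeholders that should match `donut_dir` and `nougat_dir` in the launch scripts):
```bash
# Clone the pre-trained checkpoints locally (requires git-lfs);
# destination paths are placeholders.
git lfs install
git clone https://huggingface.co/naver-clova-ix/donut-base /path/to/donut-base
git clone https://huggingface.co/facebook/nougat-small /path/to/nougat-small
```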

### 3. Train
```bash
bash launch_train.sh
```
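
`launch_train.sh` launches training with `accelerate` on two GPUs. If only one GPU is available, a reduced launch along the following lines should work (a sketch; the flags mirror the script and the paths are placeholders):
```bash
# Single-GPU variant of the launcher; paths are placeholders.
export CUDA_VISIBLE_DEVICES=0
accelerate launch --num_processes 1 --num_machines 1 train.py \
    --base_dir /path/to/dimt_ocr_free \
    --dataset_dir /path/to/DoTA_dataset \
    --donut_dir /path/to/donut-base \
    --nougat_dir /path/to/nougat-small
```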

### 4. Inference
Before running the script, you need to replace the `~/anaconda3/envs/your_env_name/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py` file in your transformers installation with the `modeling_bert.py` file provided in this repository.
```bash
bash launch_inference.sh
```
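
The exact `site-packages` path depends on your environment. One way to locate the installed file and swap it for the patched copy is the following sketch (back up the original first; the repository-relative path is an example):
```bash
# Locate the transformers copy of modeling_bert.py and replace it with the
# patched file from this repository; keep a backup of the stock version.
BERT_FILE=$(python -c "import transformers.models.bert.modeling_bert as m; print(m.__file__)")
cp "$BERT_FILE" "$BERT_FILE.bak"
cp modeling_bert.py "$BERT_FILE"
```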

## Citation
If you use our dataset, please cite our paper with the following BibTeX entry:
```BibTex
@inproceedings{liang2024document,
  title={Document Image Machine Translation with Dynamic Multi-pre-trained Models Assembling},
  author={Liang, Yupu and Zhang, Yaping and Ma, Cong and Zhang, Zhiyang and Zhao, Yang and Xiang, Lu and Zong, Chengqing and Zhou, Yu},
  booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
  pages={xxxx--xxxx},
  year={2024}
}
```

If you have any questions, feel free to contact [[email protected]](mailto:[email protected]).
baseline/inference.py ADDED
@@ -0,0 +1,90 @@
import os
import json
import torch
from transformers import DonutProcessor, AutoTokenizer
import argparse
from transformers import VisionEncoderDecoderModel, EncoderDecoderModel, EncoderDecoderConfig, BertConfig
from my_model import MyModel, MyDataset
from transformers import GenerationConfig
from PIL import Image

def inference(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Image processor (Donut) and target-language tokenizer
    processor = DonutProcessor.from_pretrained(args.donut_dir)
    processor.image_processor.size = {'height': 896, 'width': 672}
    processor.image_processor.image_mean = [0.485, 0.456, 0.406]
    processor.image_processor.image_std = [0.229, 0.224, 0.225]
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))

    # Configuration of the small BERT-style translation decoder
    encoder_config = BertConfig()
    decoder_config = BertConfig()
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    encoder_decoder_config.decoder.bos_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.decoder_start_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.eos_token_id = tokenizer.eos_token_id
    encoder_decoder_config.decoder.hidden_size = 512
    encoder_decoder_config.decoder.intermediate_size = 2048
    encoder_decoder_config.decoder.max_length = args.max_length
    encoder_decoder_config.decoder.max_position_embeddings = args.max_length
    encoder_decoder_config.decoder.num_attention_heads = 8
    encoder_decoder_config.decoder.num_hidden_layers = 6
    encoder_decoder_config.decoder.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder.type_vocab_size = 1
    encoder_decoder_config.decoder.vocab_size = tokenizer.vocab_size

    # Assemble the full model: Nougat vision encoder + translation decoder
    trans_model = EncoderDecoderModel(config=encoder_decoder_config)
    nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)

    model = MyModel(nougat_model.config, trans_model, nougat_model)

    # Load the fine-tuned checkpoint
    checkpoint_file_path = os.path.join(args.checkpoint_dir, 'pytorch_model.bin')
    checkpoint = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)

    generation_config = GenerationConfig(
        max_length=args.max_length,
        early_stopping=True,
        num_beams=args.num_beams,
        use_cache=True,
        length_penalty=1.0,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Preprocess the input document image
    image = Image.open(args.image_file_path)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    generation_ids = model.generate(
        pixel_values=pixel_values,
        generation_config=generation_config,
    )

    zh_text = tokenizer.decode(generation_ids[0])

    # Write the translation to <base_dir>/outputs/<image_name>.txt
    result_dir = os.path.join(args.base_dir, 'outputs')
    os.makedirs(result_dir, exist_ok=True)

    result_file_path = os.path.join(result_dir, args.image_file_path.split('/')[-1][:-4]+'.txt')
    with open(result_file_path, 'w', encoding='utf-8') as f:
        f.write(zh_text)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str)
    parser.add_argument("--donut_dir", type=str)
    parser.add_argument("--nougat_dir", type=str)
    parser.add_argument("--checkpoint_dir", type=str)
    parser.add_argument("--image_file_path", type=str)

    parser.add_argument("--max_length", type=int, default=1536)
    parser.add_argument("--num_beams", type=int, default=4)

    args = parser.parse_args()

    inference(args)
baseline/launch_inference.sh ADDED
@@ -0,0 +1,19 @@
#!/bin/bash

base_dir=/path/to/dimt_ocr_free

donut_dir=/path/to/donut-base
nougat_dir=/path/to/nougat-small

checkpoint_dir=/path/to/checkpoint_dir

image_file_path=/path/to/image.png

export CUDA_VISIBLE_DEVICES=0

python inference.py \
    --base_dir $base_dir \
    --donut_dir $donut_dir \
    --nougat_dir $nougat_dir \
    --checkpoint_dir $checkpoint_dir \
    --image_file_path $image_file_path
baseline/launch_train.sh ADDED
@@ -0,0 +1,18 @@
#!/bin/bash

base_dir=/path/to/dimt_ocr_free
dataset_dir=/path/to/DoTA_dataset

donut_dir=/path/to/donut-base
nougat_dir=/path/to/nougat-small

export CUDA_VISIBLE_DEVICES=0,1

accelerate launch \
    --main_process_port 12345 \
    --num_processes 2 \
    --num_machines 1 train.py \
    --base_dir $base_dir \
    --dataset_dir $dataset_dir \
    --donut_dir $donut_dir \
    --nougat_dir $nougat_dir
baseline/modeling_bert.py ADDED
@@ -0,0 +1,1916 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+
19
+ import math
20
+ import os
21
+ import warnings
22
+ from dataclasses import dataclass
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from ...activations import ACT2FN
31
+ from ...modeling_outputs import (
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ BaseModelOutputWithPoolingAndCrossAttentions,
34
+ CausalLMOutputWithCrossAttentions,
35
+ MaskedLMOutput,
36
+ MultipleChoiceModelOutput,
37
+ NextSentencePredictorOutput,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutput,
40
+ TokenClassifierOutput,
41
+ )
42
+ from ...modeling_utils import PreTrainedModel
43
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
44
+ from ...utils import (
45
+ ModelOutput,
46
+ add_code_sample_docstrings,
47
+ add_start_docstrings,
48
+ add_start_docstrings_to_model_forward,
49
+ logging,
50
+ replace_return_docstrings,
51
+ )
52
+ from .configuration_bert import BertConfig
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
58
+ _CONFIG_FOR_DOC = "BertConfig"
59
+
60
+ # TokenClassification docstring
61
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
62
+ _TOKEN_CLASS_EXPECTED_OUTPUT = (
63
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
64
+ )
65
+ _TOKEN_CLASS_EXPECTED_LOSS = 0.01
66
+
67
+ # QuestionAnswering docstring
68
+ _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
69
+ _QA_EXPECTED_OUTPUT = "'a nice puppet'"
70
+ _QA_EXPECTED_LOSS = 7.41
71
+ _QA_TARGET_START_INDEX = 14
72
+ _QA_TARGET_END_INDEX = 15
73
+
74
+ # SequenceClassification docstring
75
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
76
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
77
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
78
+
79
+
80
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
81
+ "bert-base-uncased",
82
+ "bert-large-uncased",
83
+ "bert-base-cased",
84
+ "bert-large-cased",
85
+ "bert-base-multilingual-uncased",
86
+ "bert-base-multilingual-cased",
87
+ "bert-base-chinese",
88
+ "bert-base-german-cased",
89
+ "bert-large-uncased-whole-word-masking",
90
+ "bert-large-cased-whole-word-masking",
91
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
92
+ "bert-large-cased-whole-word-masking-finetuned-squad",
93
+ "bert-base-cased-finetuned-mrpc",
94
+ "bert-base-german-dbmdz-cased",
95
+ "bert-base-german-dbmdz-uncased",
96
+ "cl-tohoku/bert-base-japanese",
97
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
98
+ "cl-tohoku/bert-base-japanese-char",
99
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
100
+ "TurkuNLP/bert-base-finnish-cased-v1",
101
+ "TurkuNLP/bert-base-finnish-uncased-v1",
102
+ "wietsedv/bert-base-dutch-cased",
103
+ # See all BERT models at https://huggingface.co/models?filter=bert
104
+ ]
105
+
106
+
107
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
108
+ """Load tf checkpoints in a pytorch model."""
109
+ try:
110
+ import re
111
+
112
+ import numpy as np
113
+ import tensorflow as tf
114
+ except ImportError:
115
+ logger.error(
116
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
117
+ "https://www.tensorflow.org/install/ for installation instructions."
118
+ )
119
+ raise
120
+ tf_path = os.path.abspath(tf_checkpoint_path)
121
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
122
+ # Load weights from TF model
123
+ init_vars = tf.train.list_variables(tf_path)
124
+ names = []
125
+ arrays = []
126
+ for name, shape in init_vars:
127
+ logger.info(f"Loading TF weight {name} with shape {shape}")
128
+ array = tf.train.load_variable(tf_path, name)
129
+ names.append(name)
130
+ arrays.append(array)
131
+
132
+ for name, array in zip(names, arrays):
133
+ name = name.split("/")
134
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
135
+ # which are not required for using pretrained model
136
+ if any(
137
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
138
+ for n in name
139
+ ):
140
+ logger.info(f"Skipping {'/'.join(name)}")
141
+ continue
142
+ pointer = model
143
+ for m_name in name:
144
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
145
+ scope_names = re.split(r"_(\d+)", m_name)
146
+ else:
147
+ scope_names = [m_name]
148
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
149
+ pointer = getattr(pointer, "weight")
150
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
151
+ pointer = getattr(pointer, "bias")
152
+ elif scope_names[0] == "output_weights":
153
+ pointer = getattr(pointer, "weight")
154
+ elif scope_names[0] == "squad":
155
+ pointer = getattr(pointer, "classifier")
156
+ else:
157
+ try:
158
+ pointer = getattr(pointer, scope_names[0])
159
+ except AttributeError:
160
+ logger.info(f"Skipping {'/'.join(name)}")
161
+ continue
162
+ if len(scope_names) >= 2:
163
+ num = int(scope_names[1])
164
+ pointer = pointer[num]
165
+ if m_name[-11:] == "_embeddings":
166
+ pointer = getattr(pointer, "weight")
167
+ elif m_name == "kernel":
168
+ array = np.transpose(array)
169
+ try:
170
+ if pointer.shape != array.shape:
171
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
172
+ except ValueError as e:
173
+ e.args += (pointer.shape, array.shape)
174
+ raise
175
+ logger.info(f"Initialize PyTorch weight {name}")
176
+ pointer.data = torch.from_numpy(array)
177
+ return model
178
+
179
+
180
+ class BertEmbeddings(nn.Module):
181
+ """Construct the embeddings from word, position and token_type embeddings."""
182
+
183
+ def __init__(self, config):
184
+ super().__init__()
185
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
186
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
187
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
188
+
189
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
190
+ # any TensorFlow checkpoint file
191
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
192
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
193
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
194
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
195
+ self.register_buffer(
196
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
197
+ )
198
+ self.register_buffer(
199
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
200
+ )
201
+
202
+ def forward(
203
+ self,
204
+ input_ids: Optional[torch.LongTensor] = None,
205
+ token_type_ids: Optional[torch.LongTensor] = None,
206
+ position_ids: Optional[torch.LongTensor] = None,
207
+ inputs_embeds: Optional[torch.FloatTensor] = None,
208
+ past_key_values_length: int = 0,
209
+ ) -> torch.Tensor:
210
+ if input_ids is not None:
211
+ input_shape = input_ids.size()
212
+ else:
213
+ input_shape = inputs_embeds.size()[:-1]
214
+
215
+ seq_length = input_shape[1]
216
+
217
+ if position_ids is None:
218
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
219
+
220
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
221
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
222
+ # issue #5664
223
+ if token_type_ids is None:
224
+ if hasattr(self, "token_type_ids"):
225
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
226
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
227
+ token_type_ids = buffered_token_type_ids_expanded
228
+ else:
229
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
230
+
231
+ if inputs_embeds is None:
232
+ inputs_embeds = self.word_embeddings(input_ids)
233
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
234
+
235
+ embeddings = inputs_embeds + token_type_embeddings
236
+ if self.position_embedding_type == "absolute":
237
+ position_embeddings = self.position_embeddings(position_ids)
238
+ embeddings += position_embeddings
239
+ embeddings = self.LayerNorm(embeddings)
240
+ embeddings = self.dropout(embeddings)
241
+ return embeddings
242
+
243
+
244
+ class BertSelfAttention(nn.Module):
245
+ def __init__(self, config, position_embedding_type=None):
246
+ super().__init__()
247
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
248
+ raise ValueError(
249
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
250
+ f"heads ({config.num_attention_heads})"
251
+ )
252
+
253
+ self.num_attention_heads = config.num_attention_heads
254
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
255
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
256
+
257
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
258
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
259
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
260
+
261
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
262
+ self.position_embedding_type = position_embedding_type or getattr(
263
+ config, "position_embedding_type", "absolute"
264
+ )
265
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
266
+ self.max_position_embeddings = config.max_position_embeddings
267
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
268
+
269
+ self.is_decoder = config.is_decoder
270
+
271
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
272
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
273
+ x = x.view(new_x_shape)
274
+ return x.permute(0, 2, 1, 3)
275
+
276
+ def forward(
277
+ self,
278
+ hidden_states: torch.Tensor,
279
+ attention_mask: Optional[torch.FloatTensor] = None,
280
+ head_mask: Optional[torch.FloatTensor] = None,
281
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
282
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
283
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
284
+ output_attentions: Optional[bool] = False,
285
+ ) -> Tuple[torch.Tensor]:
286
+ mixed_query_layer = self.query(hidden_states)
287
+
288
+ # If this is instantiated as a cross-attention module, the keys
289
+ # and values come from an encoder; the attention mask needs to be
290
+ # such that the encoder's padding tokens are not attended to.
291
+ is_cross_attention = encoder_hidden_states is not None
292
+
293
+ if is_cross_attention and past_key_value is not None:
294
+ # reuse k,v, cross_attentions
295
+ key_layer = past_key_value[0]
296
+ value_layer = past_key_value[1]
297
+ attention_mask = encoder_attention_mask
298
+ elif is_cross_attention:
299
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
300
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
301
+ attention_mask = encoder_attention_mask
302
+ elif past_key_value is not None:
303
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
304
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
305
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
306
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
307
+ else:
308
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
309
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
310
+
311
+ query_layer = self.transpose_for_scores(mixed_query_layer)
312
+
313
+ use_cache = past_key_value is not None
314
+ if self.is_decoder:
315
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
316
+ # Further calls to cross_attention layer can then reuse all cross-attention
317
+ # key/value_states (first "if" case)
318
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
319
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
320
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
321
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
322
+ past_key_value = (key_layer, value_layer)
323
+
324
+ # Take the dot product between "query" and "key" to get the raw attention scores.
325
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
326
+
327
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
328
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
329
+ if use_cache:
330
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
331
+ -1, 1
332
+ )
333
+ else:
334
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
335
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
336
+ distance = position_ids_l - position_ids_r
337
+
338
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
339
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
340
+
341
+ if self.position_embedding_type == "relative_key":
342
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
343
+ attention_scores = attention_scores + relative_position_scores
344
+ elif self.position_embedding_type == "relative_key_query":
345
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
346
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
347
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
348
+
349
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
350
+ if attention_mask is not None:
351
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
352
+ attention_scores = attention_scores + attention_mask
353
+
354
+ # Normalize the attention scores to probabilities.
355
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
356
+
357
+ # This is actually dropping out entire tokens to attend to, which might
358
+ # seem a bit unusual, but is taken from the original Transformer paper.
359
+ attention_probs = self.dropout(attention_probs)
360
+
361
+ # Mask heads if we want to
362
+ if head_mask is not None:
363
+ attention_probs = attention_probs * head_mask
364
+
365
+ context_layer = torch.matmul(attention_probs, value_layer)
366
+
367
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
368
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
369
+ context_layer = context_layer.view(new_context_layer_shape)
370
+
371
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
372
+
373
+ if self.is_decoder:
374
+ outputs = outputs + (past_key_value,)
375
+ return outputs
376
+
377
+
378
+ class BertSelfOutput(nn.Module):
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
382
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
383
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
384
+
385
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
386
+ hidden_states = self.dense(hidden_states)
387
+ hidden_states = self.dropout(hidden_states)
388
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
389
+ return hidden_states
390
+
391
+
392
+ class BertAttention(nn.Module):
393
+ def __init__(self, config, position_embedding_type=None):
394
+ super().__init__()
395
+ self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
396
+ self.output = BertSelfOutput(config)
397
+ self.pruned_heads = set()
398
+
399
+ def prune_heads(self, heads):
400
+ if len(heads) == 0:
401
+ return
402
+ heads, index = find_pruneable_heads_and_indices(
403
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
404
+ )
405
+
406
+ # Prune linear layers
407
+ self.self.query = prune_linear_layer(self.self.query, index)
408
+ self.self.key = prune_linear_layer(self.self.key, index)
409
+ self.self.value = prune_linear_layer(self.self.value, index)
410
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
411
+
412
+ # Update hyper params and store pruned heads
413
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
414
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
415
+ self.pruned_heads = self.pruned_heads.union(heads)
416
+
417
+ def forward(
418
+ self,
419
+ hidden_states: torch.Tensor,
420
+ attention_mask: Optional[torch.FloatTensor] = None,
421
+ head_mask: Optional[torch.FloatTensor] = None,
422
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
423
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
424
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
425
+ output_attentions: Optional[bool] = False,
426
+ ) -> Tuple[torch.Tensor]:
427
+ self_outputs = self.self(
428
+ hidden_states,
429
+ attention_mask,
430
+ head_mask,
431
+ encoder_hidden_states,
432
+ encoder_attention_mask,
433
+ past_key_value,
434
+ output_attentions,
435
+ )
436
+ attention_output = self.output(self_outputs[0], hidden_states)
437
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
438
+ return outputs
439
+
440
+
441
+ class BertIntermediate(nn.Module):
442
+ def __init__(self, config):
443
+ super().__init__()
444
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
445
+ if isinstance(config.hidden_act, str):
446
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
447
+ else:
448
+ self.intermediate_act_fn = config.hidden_act
449
+
450
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
451
+ hidden_states = self.dense(hidden_states)
452
+ hidden_states = self.intermediate_act_fn(hidden_states)
453
+ return hidden_states
454
+
455
+
456
+ class BertOutput(nn.Module):
457
+ def __init__(self, config):
458
+ super().__init__()
459
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
460
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
461
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
462
+
463
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
464
+ hidden_states = self.dense(hidden_states)
465
+ hidden_states = self.dropout(hidden_states)
466
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
467
+ return hidden_states
468
+
469
+
470
+ class BertLayer(nn.Module):
471
+ def __init__(self, config):
472
+ super().__init__()
473
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
474
+ self.seq_len_dim = 1
475
+ self.attention = BertAttention(config)
476
+ self.is_decoder = config.is_decoder
477
+ self.add_cross_attention = config.add_cross_attention
478
+ if self.add_cross_attention:
479
+ if not self.is_decoder:
480
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
481
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
482
+ self.intermediate = BertIntermediate(config)
483
+ self.output = BertOutput(config)
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.Tensor,
488
+ attention_mask: Optional[torch.FloatTensor] = None,
489
+ head_mask: Optional[torch.FloatTensor] = None,
490
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
491
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
492
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
493
+ output_attentions: Optional[bool] = False,
494
+ ) -> Tuple[torch.Tensor]:
495
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
496
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
497
+ self_attention_outputs = self.attention(
498
+ hidden_states,
499
+ attention_mask,
500
+ head_mask,
501
+ output_attentions=output_attentions,
502
+ past_key_value=self_attn_past_key_value,
503
+ )
504
+ attention_output = self_attention_outputs[0]
505
+
506
+ # if decoder, the last output is tuple of self-attn cache
507
+ if self.is_decoder:
508
+ outputs = self_attention_outputs[1:-1]
509
+ present_key_value = self_attention_outputs[-1]
510
+ else:
511
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
512
+
513
+ cross_attn_present_key_value = None
514
+ if self.is_decoder and encoder_hidden_states is not None:
515
+ if not hasattr(self, "crossattention"):
516
+ raise ValueError(
517
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
518
+ " by setting `config.add_cross_attention=True`"
519
+ )
520
+
521
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
522
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
523
+ cross_attention_outputs = self.crossattention(
524
+ attention_output,
525
+ attention_mask,
526
+ head_mask,
527
+ encoder_hidden_states,
528
+ encoder_attention_mask,
529
+ cross_attn_past_key_value,
530
+ output_attentions,
531
+ )
532
+ attention_output = cross_attention_outputs[0]
533
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
534
+
535
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
536
+ cross_attn_present_key_value = cross_attention_outputs[-1]
537
+ present_key_value = present_key_value + cross_attn_present_key_value
538
+
539
+ layer_output = apply_chunking_to_forward(
540
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
541
+ )
542
+ outputs = (layer_output,) + outputs
543
+
544
+ # if decoder, return the attn key/values as the last output
545
+ if self.is_decoder:
546
+ outputs = outputs + (present_key_value,)
547
+
548
+ return outputs
549
+
550
+ def feed_forward_chunk(self, attention_output):
551
+ intermediate_output = self.intermediate(attention_output)
552
+ layer_output = self.output(intermediate_output, attention_output)
553
+ return layer_output
554
+
555
+
556
+ class BertEncoder(nn.Module):
557
+ def __init__(self, config):
558
+ super().__init__()
559
+ self.config = config
560
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
561
+ self.gradient_checkpointing = False
562
+
563
+ def forward(
564
+ self,
565
+ hidden_states: torch.Tensor,
566
+ attention_mask: Optional[torch.FloatTensor] = None,
567
+ head_mask: Optional[torch.FloatTensor] = None,
568
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
569
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
570
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
571
+ use_cache: Optional[bool] = None,
572
+ output_attentions: Optional[bool] = False,
573
+ output_hidden_states: Optional[bool] = False,
574
+ return_dict: Optional[bool] = True,
575
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
576
+ all_hidden_states = () if output_hidden_states else None
577
+ all_self_attentions = () if output_attentions else None
578
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
579
+
580
+ if self.gradient_checkpointing and self.training:
581
+ if use_cache:
582
+ logger.warning_once(
583
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
584
+ )
585
+ use_cache = False
586
+
587
+ next_decoder_cache = () if use_cache else None
588
+ for i, layer_module in enumerate(self.layer):
589
+ if output_hidden_states:
590
+ all_hidden_states = all_hidden_states + (hidden_states,)
591
+
592
+ layer_head_mask = head_mask[i] if head_mask is not None else None
593
+ past_key_value = past_key_values[i] if past_key_values is not None else None
594
+
595
+ if self.gradient_checkpointing and self.training:
596
+
597
+ def create_custom_forward(module):
598
+ def custom_forward(*inputs):
599
+ return module(*inputs, past_key_value, output_attentions)
600
+
601
+ return custom_forward
602
+
603
+ layer_outputs = torch.utils.checkpoint.checkpoint(
604
+ create_custom_forward(layer_module),
605
+ hidden_states,
606
+ attention_mask,
607
+ layer_head_mask,
608
+ encoder_hidden_states,
609
+ encoder_attention_mask,
610
+ )
611
+ else:
612
+ layer_outputs = layer_module(
613
+ hidden_states,
614
+ attention_mask,
615
+ layer_head_mask,
616
+ encoder_hidden_states,
617
+ encoder_attention_mask,
618
+ past_key_value,
619
+ output_attentions,
620
+ )
621
+
622
+ hidden_states = layer_outputs[0]
623
+ if use_cache:
624
+ next_decoder_cache += (layer_outputs[-1],)
625
+ if output_attentions:
626
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
627
+ if self.config.add_cross_attention:
628
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
629
+
630
+ if output_hidden_states:
631
+ all_hidden_states = all_hidden_states + (hidden_states,)
632
+
633
+ if not return_dict:
634
+ return tuple(
635
+ v
636
+ for v in [
637
+ hidden_states,
638
+ next_decoder_cache,
639
+ all_hidden_states,
640
+ all_self_attentions,
641
+ all_cross_attentions,
642
+ ]
643
+ if v is not None
644
+ )
645
+ return BaseModelOutputWithPastAndCrossAttentions(
646
+ last_hidden_state=hidden_states,
647
+ past_key_values=next_decoder_cache,
648
+ hidden_states=all_hidden_states,
649
+ attentions=all_self_attentions,
650
+ cross_attentions=all_cross_attentions,
651
+ )
652
+
653
+
654
+ class BertPooler(nn.Module):
655
+ def __init__(self, config):
656
+ super().__init__()
657
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
658
+ self.activation = nn.Tanh()
659
+
660
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
661
+ # We "pool" the model by simply taking the hidden state corresponding
662
+ # to the first token.
663
+ first_token_tensor = hidden_states[:, 0]
664
+ pooled_output = self.dense(first_token_tensor)
665
+ pooled_output = self.activation(pooled_output)
666
+ return pooled_output
667
+
668
+
669
+ class BertPredictionHeadTransform(nn.Module):
670
+ def __init__(self, config):
671
+ super().__init__()
672
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
673
+ if isinstance(config.hidden_act, str):
674
+ self.transform_act_fn = ACT2FN[config.hidden_act]
675
+ else:
676
+ self.transform_act_fn = config.hidden_act
677
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
678
+
679
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
680
+ hidden_states = self.dense(hidden_states)
681
+ hidden_states = self.transform_act_fn(hidden_states)
682
+ hidden_states = self.LayerNorm(hidden_states)
683
+ return hidden_states
684
+
685
+
686
+ class BertLMPredictionHead(nn.Module):
687
+ def __init__(self, config):
688
+ super().__init__()
689
+ self.transform = BertPredictionHeadTransform(config)
690
+
691
+ # The output weights are the same as the input embeddings, but there is
692
+ # an output-only bias for each token.
693
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
694
+
695
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
696
+
697
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
698
+ self.decoder.bias = self.bias
699
+
700
+ def forward(self, hidden_states):
701
+ hidden_states = self.transform(hidden_states)
702
+ hidden_states = self.decoder(hidden_states)
703
+ return hidden_states
704
+
705
+
706
+ class BertOnlyMLMHead(nn.Module):
707
+ def __init__(self, config):
708
+ super().__init__()
709
+ self.predictions = BertLMPredictionHead(config)
710
+
711
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
712
+ prediction_scores = self.predictions(sequence_output)
713
+ return prediction_scores
714
+
715
+
716
+ class BertOnlyNSPHead(nn.Module):
717
+ def __init__(self, config):
718
+ super().__init__()
719
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
720
+
721
+ def forward(self, pooled_output):
722
+ seq_relationship_score = self.seq_relationship(pooled_output)
723
+ return seq_relationship_score
724
+
725
+
726
+ class BertPreTrainingHeads(nn.Module):
727
+ def __init__(self, config):
728
+ super().__init__()
729
+ self.predictions = BertLMPredictionHead(config)
730
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
731
+
732
+ def forward(self, sequence_output, pooled_output):
733
+ prediction_scores = self.predictions(sequence_output)
734
+ seq_relationship_score = self.seq_relationship(pooled_output)
735
+ return prediction_scores, seq_relationship_score
736
+
737
+
738
+ class BertPreTrainedModel(PreTrainedModel):
739
+ """
740
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
741
+ models.
742
+ """
743
+
744
+ config_class = BertConfig
745
+ load_tf_weights = load_tf_weights_in_bert
746
+ base_model_prefix = "bert"
747
+ supports_gradient_checkpointing = True
748
+
749
+ def _init_weights(self, module):
750
+ """Initialize the weights"""
751
+ if isinstance(module, nn.Linear):
752
+ # Slightly different from the TF version which uses truncated_normal for initialization
753
+ # cf https://github.com/pytorch/pytorch/pull/5617
754
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
755
+ if module.bias is not None:
756
+ module.bias.data.zero_()
757
+ elif isinstance(module, nn.Embedding):
758
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
759
+ if module.padding_idx is not None:
760
+ module.weight.data[module.padding_idx].zero_()
761
+ elif isinstance(module, nn.LayerNorm):
762
+ module.bias.data.zero_()
763
+ module.weight.data.fill_(1.0)
764
+
765
+ def _set_gradient_checkpointing(self, module, value=False):
766
+ if isinstance(module, BertEncoder):
767
+ module.gradient_checkpointing = value
768
+
769
+
770
+ @dataclass
771
+ class BertForPreTrainingOutput(ModelOutput):
772
+ """
773
+ Output type of [`BertForPreTraining`].
774
+
775
+ Args:
776
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
777
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
778
+ (classification) loss.
779
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
780
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
781
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
782
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
783
+ before SoftMax).
784
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
785
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
786
+ shape `(batch_size, sequence_length, hidden_size)`.
787
+
788
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
789
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
790
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
791
+ sequence_length)`.
792
+
793
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
794
+ heads.
795
+ """
796
+
797
+ loss: Optional[torch.FloatTensor] = None
798
+ prediction_logits: torch.FloatTensor = None
799
+ seq_relationship_logits: torch.FloatTensor = None
800
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
801
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
802
+
803
+
804
+ BERT_START_DOCSTRING = r"""
805
+
806
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
807
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
808
+ etc.)
809
+
810
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
811
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
812
+ and behavior.
813
+
814
+ Parameters:
815
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
816
+ Initializing with a config file does not load the weights associated with the model, only the
817
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
818
+ """
819
+
820
+ BERT_INPUTS_DOCSTRING = r"""
821
+ Args:
822
+ input_ids (`torch.LongTensor` of shape `({0})`):
823
+ Indices of input sequence tokens in the vocabulary.
824
+
825
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
826
+ [`PreTrainedTokenizer.__call__`] for details.
827
+
828
+ [What are input IDs?](../glossary#input-ids)
829
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
830
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
831
+
832
+ - 1 for tokens that are **not masked**,
833
+ - 0 for tokens that are **masked**.
834
+
835
+ [What are attention masks?](../glossary#attention-mask)
836
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
837
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
838
+ 1]`:
839
+
840
+ - 0 corresponds to a *sentence A* token,
841
+ - 1 corresponds to a *sentence B* token.
842
+
843
+ [What are token type IDs?](../glossary#token-type-ids)
844
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
845
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
846
+ config.max_position_embeddings - 1]`.
847
+
848
+ [What are position IDs?](../glossary#position-ids)
849
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
850
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
851
+
852
+ - 1 indicates the head is **not masked**,
853
+ - 0 indicates the head is **masked**.
854
+
855
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
856
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
857
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
858
+ model's internal embedding lookup matrix.
859
+ output_attentions (`bool`, *optional*):
860
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
861
+ tensors for more detail.
862
+ output_hidden_states (`bool`, *optional*):
863
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
864
+ more detail.
865
+ return_dict (`bool`, *optional*):
866
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
867
+ """
868
+
869
+
870
+ @add_start_docstrings(
871
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
872
+ BERT_START_DOCSTRING,
873
+ )
874
+ class BertModel(BertPreTrainedModel):
875
+ """
876
+
877
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
878
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
879
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
880
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
881
+
882
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
883
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
884
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
885
+ """
886
+
887
+ def __init__(self, config, add_pooling_layer=True):
888
+ super().__init__(config)
889
+ self.config = config
890
+
891
+ self.embeddings = BertEmbeddings(config)
892
+ self.encoder = BertEncoder(config)
893
+
894
+ self.pooler = BertPooler(config) if add_pooling_layer else None
895
+
896
+ # Initialize weights and apply final processing
897
+ self.post_init()
898
+
899
+ def get_input_embeddings(self):
900
+ return self.embeddings.word_embeddings
901
+
902
+ def set_input_embeddings(self, value):
903
+ self.embeddings.word_embeddings = value
904
+
905
+ def _prune_heads(self, heads_to_prune):
906
+ """
907
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
908
+ class PreTrainedModel
909
+ """
910
+ for layer, heads in heads_to_prune.items():
911
+ self.encoder.layer[layer].attention.prune_heads(heads)
912
+
913
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
914
+ @add_code_sample_docstrings(
915
+ checkpoint=_CHECKPOINT_FOR_DOC,
916
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
917
+ config_class=_CONFIG_FOR_DOC,
918
+ )
919
+ def forward(
920
+ self,
921
+ input_ids: Optional[torch.Tensor] = None,
922
+ attention_mask: Optional[torch.Tensor] = None,
923
+ token_type_ids: Optional[torch.Tensor] = None,
924
+ position_ids: Optional[torch.Tensor] = None,
925
+ head_mask: Optional[torch.Tensor] = None,
926
+ inputs_embeds: Optional[torch.Tensor] = None,
927
+ encoder_hidden_states: Optional[torch.Tensor] = None,
928
+ encoder_attention_mask: Optional[torch.Tensor] = None,
929
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
930
+ use_cache: Optional[bool] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
935
+ r"""
936
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
937
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
938
+ the model is configured as a decoder.
939
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
940
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
941
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
942
+
943
+ - 1 for tokens that are **not masked**,
944
+ - 0 for tokens that are **masked**.
945
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
946
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
947
+
948
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
949
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
950
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
951
+ use_cache (`bool`, *optional*):
952
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
953
+ `past_key_values`).
954
+ """
955
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
956
+ output_hidden_states = (
957
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
958
+ )
959
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
960
+
961
+ if self.config.is_decoder:
962
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
963
+ else:
964
+ use_cache = False
965
+
966
+ if input_ids is not None and inputs_embeds is not None:
967
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
968
+ elif input_ids is not None:
969
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
970
+ input_shape = input_ids.size()
971
+ elif inputs_embeds is not None:
972
+ input_shape = inputs_embeds.size()[:-1]
973
+ else:
974
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
975
+
976
+ batch_size, seq_length = input_shape
977
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
978
+
979
+ # past_key_values_length
980
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
981
+
982
+ # print('attention_mask before')
983
+ # print(attention_mask)
984
+ if attention_mask is None:
985
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
986
+ # print('attention_mask after')
987
+ # print(attention_mask)
988
+ # print(attention_mask.shape)
989
+ # print('input_ids')
990
+ # print(input_ids.shape)
991
+
992
+ if token_type_ids is None:
993
+ if hasattr(self.embeddings, "token_type_ids"):
994
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
995
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
996
+ token_type_ids = buffered_token_type_ids_expanded
997
+ else:
998
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
999
+
1000
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1001
+ # ourselves in which case we just need to make it broadcastable to all heads.
1002
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1003
+
1004
+ # If a 2D or 3D attention mask is provided for the cross-attention
1005
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1006
+ if self.config.is_decoder and encoder_hidden_states is not None:
1007
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1008
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1009
+ if encoder_attention_mask is None:
1010
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1011
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1012
+ else:
1013
+ encoder_extended_attention_mask = None
1014
+
1015
+ # Prepare head mask if needed
1016
+ # 1.0 in head_mask indicate we keep the head
1017
+ # attention_probs has shape bsz x n_heads x N x N
1018
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1019
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1020
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1021
+
1022
+ embedding_output = self.embeddings(
1023
+ input_ids=input_ids,
1024
+ position_ids=position_ids,
1025
+ token_type_ids=token_type_ids,
1026
+ inputs_embeds=inputs_embeds,
1027
+ past_key_values_length=past_key_values_length,
1028
+ )
1029
+ encoder_outputs = self.encoder(
1030
+ embedding_output,
1031
+ attention_mask=extended_attention_mask,
1032
+ head_mask=head_mask,
1033
+ encoder_hidden_states=encoder_hidden_states,
1034
+ encoder_attention_mask=encoder_extended_attention_mask,
1035
+ past_key_values=past_key_values,
1036
+ use_cache=use_cache,
1037
+ output_attentions=output_attentions,
1038
+ output_hidden_states=output_hidden_states,
1039
+ return_dict=return_dict,
1040
+ )
1041
+ sequence_output = encoder_outputs[0]
1042
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1043
+
1044
+ if not return_dict:
1045
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1046
+
1047
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1048
+ last_hidden_state=sequence_output,
1049
+ pooler_output=pooled_output,
1050
+ past_key_values=encoder_outputs.past_key_values,
1051
+ hidden_states=encoder_outputs.hidden_states,
1052
+ attentions=encoder_outputs.attentions,
1053
+ cross_attentions=encoder_outputs.cross_attentions,
1054
+ )
1055
+
1056
+
1057
+ @add_start_docstrings(
1058
+ """
1059
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1060
+ sentence prediction (classification)` head.
1061
+ """,
1062
+ BERT_START_DOCSTRING,
1063
+ )
1064
+ class BertForPreTraining(BertPreTrainedModel):
1065
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1066
+
1067
+ def __init__(self, config):
1068
+ super().__init__(config)
1069
+
1070
+ self.bert = BertModel(config)
1071
+ self.cls = BertPreTrainingHeads(config)
1072
+
1073
+ # Initialize weights and apply final processing
1074
+ self.post_init()
1075
+
1076
+ def get_output_embeddings(self):
1077
+ return self.cls.predictions.decoder
1078
+
1079
+ def set_output_embeddings(self, new_embeddings):
1080
+ self.cls.predictions.decoder = new_embeddings
1081
+
1082
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1083
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1084
+ def forward(
1085
+ self,
1086
+ input_ids: Optional[torch.Tensor] = None,
1087
+ attention_mask: Optional[torch.Tensor] = None,
1088
+ token_type_ids: Optional[torch.Tensor] = None,
1089
+ position_ids: Optional[torch.Tensor] = None,
1090
+ head_mask: Optional[torch.Tensor] = None,
1091
+ inputs_embeds: Optional[torch.Tensor] = None,
1092
+ labels: Optional[torch.Tensor] = None,
1093
+ next_sentence_label: Optional[torch.Tensor] = None,
1094
+ output_attentions: Optional[bool] = None,
1095
+ output_hidden_states: Optional[bool] = None,
1096
+ return_dict: Optional[bool] = None,
1097
+ ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
1098
+ r"""
1099
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1100
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1101
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1102
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1103
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1104
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
1105
+ pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
1106
+
1107
+ - 0 indicates sequence B is a continuation of sequence A,
1108
+ - 1 indicates sequence B is a random sequence.
1109
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1110
+ Used to hide legacy arguments that have been deprecated.
1111
+
1112
+ Returns:
1113
+
1114
+ Example:
1115
+
1116
+ ```python
1117
+ >>> from transformers import AutoTokenizer, BertForPreTraining
1118
+ >>> import torch
1119
+
1120
+ >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
1121
+ >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
1122
+
1123
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1124
+ >>> outputs = model(**inputs)
1125
+
1126
+ >>> prediction_logits = outputs.prediction_logits
1127
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1128
+ ```
1129
+ """
1130
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1131
+
1132
+ outputs = self.bert(
1133
+ input_ids,
1134
+ attention_mask=attention_mask,
1135
+ token_type_ids=token_type_ids,
1136
+ position_ids=position_ids,
1137
+ head_mask=head_mask,
1138
+ inputs_embeds=inputs_embeds,
1139
+ output_attentions=output_attentions,
1140
+ output_hidden_states=output_hidden_states,
1141
+ return_dict=return_dict,
1142
+ )
1143
+
1144
+ sequence_output, pooled_output = outputs[:2]
1145
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1146
+
1147
+ total_loss = None
1148
+ if labels is not None and next_sentence_label is not None:
1149
+ loss_fct = CrossEntropyLoss()
1150
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1151
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1152
+ total_loss = masked_lm_loss + next_sentence_loss
1153
+
1154
+ if not return_dict:
1155
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1156
+ return ((total_loss,) + output) if total_loss is not None else output
1157
+
1158
+ return BertForPreTrainingOutput(
1159
+ loss=total_loss,
1160
+ prediction_logits=prediction_scores,
1161
+ seq_relationship_logits=seq_relationship_score,
1162
+ hidden_states=outputs.hidden_states,
1163
+ attentions=outputs.attentions,
1164
+ )
1165
+
1166
+
1167
+ @add_start_docstrings(
1168
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
1169
+ )
1170
+ class BertLMHeadModel(BertPreTrainedModel):
1171
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1172
+
1173
+ def __init__(self, config):
1174
+ super().__init__(config)
1175
+
1176
+ if not config.is_decoder:
1177
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1178
+
1179
+ self.bert = BertModel(config, add_pooling_layer=False)
1180
+ self.cls = BertOnlyMLMHead(config)
1181
+
1182
+ # Initialize weights and apply final processing
1183
+ self.post_init()
1184
+
1185
+ def get_output_embeddings(self):
1186
+ return self.cls.predictions.decoder
1187
+
1188
+ def set_output_embeddings(self, new_embeddings):
1189
+ self.cls.predictions.decoder = new_embeddings
1190
+
1191
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1192
+ @add_code_sample_docstrings(
1193
+ checkpoint=_CHECKPOINT_FOR_DOC,
1194
+ output_type=CausalLMOutputWithCrossAttentions,
1195
+ config_class=_CONFIG_FOR_DOC,
1196
+ )
1197
+ def forward(
1198
+ self,
1199
+ input_ids: Optional[torch.Tensor] = None,
1200
+ attention_mask: Optional[torch.Tensor] = None,
1201
+ token_type_ids: Optional[torch.Tensor] = None,
1202
+ position_ids: Optional[torch.Tensor] = None,
1203
+ head_mask: Optional[torch.Tensor] = None,
1204
+ inputs_embeds: Optional[torch.Tensor] = None,
1205
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1206
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1207
+ labels: Optional[torch.Tensor] = None,
1208
+ past_key_values: Optional[List[torch.Tensor]] = None,
1209
+ use_cache: Optional[bool] = None,
1210
+ output_attentions: Optional[bool] = None,
1211
+ output_hidden_states: Optional[bool] = None,
1212
+ return_dict: Optional[bool] = None,
1213
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1214
+ r"""
1215
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1216
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1217
+ the model is configured as a decoder.
1218
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1219
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1220
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1221
+
1222
+ - 1 for tokens that are **not masked**,
1223
+ - 0 for tokens that are **masked**.
1224
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1225
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1226
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1227
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1228
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1229
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1230
+
1231
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1232
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1233
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1234
+ use_cache (`bool`, *optional*):
1235
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1236
+ `past_key_values`).
1237
+ """
1238
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1239
+ if labels is not None:
1240
+ use_cache = False
1241
+
1242
+ outputs = self.bert(
1243
+ input_ids,
1244
+ attention_mask=attention_mask,
1245
+ token_type_ids=token_type_ids,
1246
+ position_ids=position_ids,
1247
+ head_mask=head_mask,
1248
+ inputs_embeds=inputs_embeds,
1249
+ encoder_hidden_states=encoder_hidden_states,
1250
+ encoder_attention_mask=encoder_attention_mask,
1251
+ past_key_values=past_key_values,
1252
+ use_cache=use_cache,
1253
+ output_attentions=output_attentions,
1254
+ output_hidden_states=output_hidden_states,
1255
+ return_dict=return_dict,
1256
+ )
1257
+
1258
+ sequence_output = outputs[0]
1259
+ prediction_scores = self.cls(sequence_output)
1260
+
1261
+ lm_loss = None
1262
+ if labels is not None:
1263
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1264
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1265
+ labels = labels[:, 1:].contiguous()
1266
+ loss_fct = CrossEntropyLoss()
1267
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1268
+
1269
+ if not return_dict:
1270
+ output = (prediction_scores,) + outputs[2:]
1271
+ return ((lm_loss,) + output) if lm_loss is not None else output
1272
+
1273
+ return CausalLMOutputWithCrossAttentions(
1274
+ loss=lm_loss,
1275
+ logits=prediction_scores,
1276
+ past_key_values=outputs.past_key_values,
1277
+ hidden_states=outputs.hidden_states,
1278
+ attentions=outputs.attentions,
1279
+ cross_attentions=outputs.cross_attentions,
1280
+ )
+
+ def prepare_inputs_for_generation(
1282
+ self, input_ids, encoder_hidden_states, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
1283
+ ):
1284
+ input_shape = input_ids.shape
1285
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1286
+ if attention_mask is None:
1287
+ attention_mask = input_ids.new_ones(input_shape)
1288
+
1289
+ # cut decoder_input_ids if past_key_values is used
1290
+ if past_key_values is not None:
1291
+ input_ids = input_ids[:, -1:]
1292
+
1293
+ return {
1294
+ "input_ids": input_ids,
1295
+ "attention_mask": attention_mask,
1296
+ "past_key_values": past_key_values,
1297
+ "use_cache": use_cache,
1298
+ "encoder_hidden_states": encoder_hidden_states,
1299
+ }
1300
+
1301
+ # Note: the stock prepare_inputs_for_generation (previously kept here as a commented-out copy)
+ # does not accept encoder_hidden_states; the variant above adds it and returns it in the model
+ # inputs so that generate() can forward the projected image features to cross-attention.
+
1320
+ def _reorder_cache(self, past_key_values, beam_idx):
1321
+ reordered_past = ()
1322
+ for layer_past in past_key_values:
1323
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1324
+ return reordered_past
1325
+
1326
+
1327
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
1328
+ class BertForMaskedLM(BertPreTrainedModel):
1329
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1330
+
1331
+ def __init__(self, config):
1332
+ super().__init__(config)
1333
+
1334
+ if config.is_decoder:
1335
+ logger.warning(
1336
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1337
+ "bi-directional self-attention."
1338
+ )
1339
+
1340
+ self.bert = BertModel(config, add_pooling_layer=False)
1341
+ self.cls = BertOnlyMLMHead(config)
1342
+
1343
+ # Initialize weights and apply final processing
1344
+ self.post_init()
1345
+
1346
+ def get_output_embeddings(self):
1347
+ return self.cls.predictions.decoder
1348
+
1349
+ def set_output_embeddings(self, new_embeddings):
1350
+ self.cls.predictions.decoder = new_embeddings
1351
+
1352
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1353
+ @add_code_sample_docstrings(
1354
+ checkpoint=_CHECKPOINT_FOR_DOC,
1355
+ output_type=MaskedLMOutput,
1356
+ config_class=_CONFIG_FOR_DOC,
1357
+ expected_output="'paris'",
1358
+ expected_loss=0.88,
1359
+ )
1360
+ def forward(
1361
+ self,
1362
+ input_ids: Optional[torch.Tensor] = None,
1363
+ attention_mask: Optional[torch.Tensor] = None,
1364
+ token_type_ids: Optional[torch.Tensor] = None,
1365
+ position_ids: Optional[torch.Tensor] = None,
1366
+ head_mask: Optional[torch.Tensor] = None,
1367
+ inputs_embeds: Optional[torch.Tensor] = None,
1368
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1369
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1370
+ labels: Optional[torch.Tensor] = None,
1371
+ output_attentions: Optional[bool] = None,
1372
+ output_hidden_states: Optional[bool] = None,
1373
+ return_dict: Optional[bool] = None,
1374
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1375
+ r"""
1376
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1377
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1378
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1379
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1380
+ """
1381
+
1382
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1383
+
1384
+ outputs = self.bert(
1385
+ input_ids,
1386
+ attention_mask=attention_mask,
1387
+ token_type_ids=token_type_ids,
1388
+ position_ids=position_ids,
1389
+ head_mask=head_mask,
1390
+ inputs_embeds=inputs_embeds,
1391
+ encoder_hidden_states=encoder_hidden_states,
1392
+ encoder_attention_mask=encoder_attention_mask,
1393
+ output_attentions=output_attentions,
1394
+ output_hidden_states=output_hidden_states,
1395
+ return_dict=return_dict,
1396
+ )
1397
+
1398
+ sequence_output = outputs[0]
1399
+ prediction_scores = self.cls(sequence_output)
1400
+
1401
+ masked_lm_loss = None
1402
+ if labels is not None:
1403
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1404
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1405
+
1406
+ if not return_dict:
1407
+ output = (prediction_scores,) + outputs[2:]
1408
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1409
+
1410
+ return MaskedLMOutput(
1411
+ loss=masked_lm_loss,
1412
+ logits=prediction_scores,
1413
+ hidden_states=outputs.hidden_states,
1414
+ attentions=outputs.attentions,
1415
+ )
1416
+
1417
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1418
+ input_shape = input_ids.shape
1419
+ effective_batch_size = input_shape[0]
1420
+
1421
+ # add a dummy token
1422
+ if self.config.pad_token_id is None:
1423
+ raise ValueError("The PAD token should be defined for generation")
1424
+
1425
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1426
+ dummy_token = torch.full(
1427
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1428
+ )
1429
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1430
+
1431
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1432
+
1433
+
1434
+ @add_start_docstrings(
1435
+ """Bert Model with a `next sentence prediction (classification)` head on top.""",
1436
+ BERT_START_DOCSTRING,
1437
+ )
1438
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1439
+ def __init__(self, config):
1440
+ super().__init__(config)
1441
+
1442
+ self.bert = BertModel(config)
1443
+ self.cls = BertOnlyNSPHead(config)
1444
+
1445
+ # Initialize weights and apply final processing
1446
+ self.post_init()
1447
+
1448
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1449
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1450
+ def forward(
1451
+ self,
1452
+ input_ids: Optional[torch.Tensor] = None,
1453
+ attention_mask: Optional[torch.Tensor] = None,
1454
+ token_type_ids: Optional[torch.Tensor] = None,
1455
+ position_ids: Optional[torch.Tensor] = None,
1456
+ head_mask: Optional[torch.Tensor] = None,
1457
+ inputs_embeds: Optional[torch.Tensor] = None,
1458
+ labels: Optional[torch.Tensor] = None,
1459
+ output_attentions: Optional[bool] = None,
1460
+ output_hidden_states: Optional[bool] = None,
1461
+ return_dict: Optional[bool] = None,
1462
+ **kwargs,
1463
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1464
+ r"""
1465
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1466
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1467
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1468
+
1469
+ - 0 indicates sequence B is a continuation of sequence A,
1470
+ - 1 indicates sequence B is a random sequence.
1471
+
1472
+ Returns:
1473
+
1474
+ Example:
1475
+
1476
+ ```python
1477
+ >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
1478
+ >>> import torch
1479
+
1480
+ >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
1481
+ >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
1482
+
1483
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1484
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1485
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1486
+
1487
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1488
+ >>> logits = outputs.logits
1489
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1490
+ ```
1491
+ """
1492
+
1493
+ if "next_sentence_label" in kwargs:
1494
+ warnings.warn(
1495
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1496
+ " `labels` instead.",
1497
+ FutureWarning,
1498
+ )
1499
+ labels = kwargs.pop("next_sentence_label")
1500
+
1501
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1502
+
1503
+ outputs = self.bert(
1504
+ input_ids,
1505
+ attention_mask=attention_mask,
1506
+ token_type_ids=token_type_ids,
1507
+ position_ids=position_ids,
1508
+ head_mask=head_mask,
1509
+ inputs_embeds=inputs_embeds,
1510
+ output_attentions=output_attentions,
1511
+ output_hidden_states=output_hidden_states,
1512
+ return_dict=return_dict,
1513
+ )
1514
+
1515
+ pooled_output = outputs[1]
1516
+
1517
+ seq_relationship_scores = self.cls(pooled_output)
1518
+
1519
+ next_sentence_loss = None
1520
+ if labels is not None:
1521
+ loss_fct = CrossEntropyLoss()
1522
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1523
+
1524
+ if not return_dict:
1525
+ output = (seq_relationship_scores,) + outputs[2:]
1526
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1527
+
1528
+ return NextSentencePredictorOutput(
1529
+ loss=next_sentence_loss,
1530
+ logits=seq_relationship_scores,
1531
+ hidden_states=outputs.hidden_states,
1532
+ attentions=outputs.attentions,
1533
+ )
1534
+
1535
+
1536
+ @add_start_docstrings(
1537
+ """
1538
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1539
+ output) e.g. for GLUE tasks.
1540
+ """,
1541
+ BERT_START_DOCSTRING,
1542
+ )
1543
+ class BertForSequenceClassification(BertPreTrainedModel):
1544
+ def __init__(self, config):
1545
+ super().__init__(config)
1546
+ self.num_labels = config.num_labels
1547
+ self.config = config
1548
+
1549
+ self.bert = BertModel(config)
1550
+ classifier_dropout = (
1551
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1552
+ )
1553
+ self.dropout = nn.Dropout(classifier_dropout)
1554
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1555
+
1556
+ # Initialize weights and apply final processing
1557
+ self.post_init()
1558
+
1559
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1560
+ @add_code_sample_docstrings(
1561
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1562
+ output_type=SequenceClassifierOutput,
1563
+ config_class=_CONFIG_FOR_DOC,
1564
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1565
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1566
+ )
1567
+ def forward(
1568
+ self,
1569
+ input_ids: Optional[torch.Tensor] = None,
1570
+ attention_mask: Optional[torch.Tensor] = None,
1571
+ token_type_ids: Optional[torch.Tensor] = None,
1572
+ position_ids: Optional[torch.Tensor] = None,
1573
+ head_mask: Optional[torch.Tensor] = None,
1574
+ inputs_embeds: Optional[torch.Tensor] = None,
1575
+ labels: Optional[torch.Tensor] = None,
1576
+ output_attentions: Optional[bool] = None,
1577
+ output_hidden_states: Optional[bool] = None,
1578
+ return_dict: Optional[bool] = None,
1579
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1580
+ r"""
1581
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1582
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1583
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1584
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1585
+ """
1586
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1587
+
1588
+ outputs = self.bert(
1589
+ input_ids,
1590
+ attention_mask=attention_mask,
1591
+ token_type_ids=token_type_ids,
1592
+ position_ids=position_ids,
1593
+ head_mask=head_mask,
1594
+ inputs_embeds=inputs_embeds,
1595
+ output_attentions=output_attentions,
1596
+ output_hidden_states=output_hidden_states,
1597
+ return_dict=return_dict,
1598
+ )
1599
+
1600
+ pooled_output = outputs[1]
1601
+
1602
+ pooled_output = self.dropout(pooled_output)
1603
+ logits = self.classifier(pooled_output)
1604
+
1605
+ loss = None
1606
+ if labels is not None:
1607
+ if self.config.problem_type is None:
1608
+ if self.num_labels == 1:
1609
+ self.config.problem_type = "regression"
1610
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1611
+ self.config.problem_type = "single_label_classification"
1612
+ else:
1613
+ self.config.problem_type = "multi_label_classification"
1614
+
1615
+ if self.config.problem_type == "regression":
1616
+ loss_fct = MSELoss()
1617
+ if self.num_labels == 1:
1618
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1619
+ else:
1620
+ loss = loss_fct(logits, labels)
1621
+ elif self.config.problem_type == "single_label_classification":
1622
+ loss_fct = CrossEntropyLoss()
1623
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1624
+ elif self.config.problem_type == "multi_label_classification":
1625
+ loss_fct = BCEWithLogitsLoss()
1626
+ loss = loss_fct(logits, labels)
1627
+ if not return_dict:
1628
+ output = (logits,) + outputs[2:]
1629
+ return ((loss,) + output) if loss is not None else output
1630
+
1631
+ return SequenceClassifierOutput(
1632
+ loss=loss,
1633
+ logits=logits,
1634
+ hidden_states=outputs.hidden_states,
1635
+ attentions=outputs.attentions,
1636
+ )
1637
+
1638
+
1639
+ @add_start_docstrings(
1640
+ """
1641
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1642
+ softmax) e.g. for RocStories/SWAG tasks.
1643
+ """,
1644
+ BERT_START_DOCSTRING,
1645
+ )
1646
+ class BertForMultipleChoice(BertPreTrainedModel):
1647
+ def __init__(self, config):
1648
+ super().__init__(config)
1649
+
1650
+ self.bert = BertModel(config)
1651
+ classifier_dropout = (
1652
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1653
+ )
1654
+ self.dropout = nn.Dropout(classifier_dropout)
1655
+ self.classifier = nn.Linear(config.hidden_size, 1)
1656
+
1657
+ # Initialize weights and apply final processing
1658
+ self.post_init()
1659
+
1660
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1661
+ @add_code_sample_docstrings(
1662
+ checkpoint=_CHECKPOINT_FOR_DOC,
1663
+ output_type=MultipleChoiceModelOutput,
1664
+ config_class=_CONFIG_FOR_DOC,
1665
+ )
1666
+ def forward(
1667
+ self,
1668
+ input_ids: Optional[torch.Tensor] = None,
1669
+ attention_mask: Optional[torch.Tensor] = None,
1670
+ token_type_ids: Optional[torch.Tensor] = None,
1671
+ position_ids: Optional[torch.Tensor] = None,
1672
+ head_mask: Optional[torch.Tensor] = None,
1673
+ inputs_embeds: Optional[torch.Tensor] = None,
1674
+ labels: Optional[torch.Tensor] = None,
1675
+ output_attentions: Optional[bool] = None,
1676
+ output_hidden_states: Optional[bool] = None,
1677
+ return_dict: Optional[bool] = None,
1678
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1679
+ r"""
1680
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1681
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1682
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1683
+ `input_ids` above)
1684
+ """
1685
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1686
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1687
+
1688
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1689
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1690
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1691
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1692
+ inputs_embeds = (
1693
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1694
+ if inputs_embeds is not None
1695
+ else None
1696
+ )
1697
+
1698
+ outputs = self.bert(
1699
+ input_ids,
1700
+ attention_mask=attention_mask,
1701
+ token_type_ids=token_type_ids,
1702
+ position_ids=position_ids,
1703
+ head_mask=head_mask,
1704
+ inputs_embeds=inputs_embeds,
1705
+ output_attentions=output_attentions,
1706
+ output_hidden_states=output_hidden_states,
1707
+ return_dict=return_dict,
1708
+ )
1709
+
1710
+ pooled_output = outputs[1]
1711
+
1712
+ pooled_output = self.dropout(pooled_output)
1713
+ logits = self.classifier(pooled_output)
1714
+ reshaped_logits = logits.view(-1, num_choices)
1715
+
1716
+ loss = None
1717
+ if labels is not None:
1718
+ loss_fct = CrossEntropyLoss()
1719
+ loss = loss_fct(reshaped_logits, labels)
1720
+
1721
+ if not return_dict:
1722
+ output = (reshaped_logits,) + outputs[2:]
1723
+ return ((loss,) + output) if loss is not None else output
1724
+
1725
+ return MultipleChoiceModelOutput(
1726
+ loss=loss,
1727
+ logits=reshaped_logits,
1728
+ hidden_states=outputs.hidden_states,
1729
+ attentions=outputs.attentions,
1730
+ )
1731
+
1732
+
1733
+ @add_start_docstrings(
1734
+ """
1735
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1736
+ Named-Entity-Recognition (NER) tasks.
1737
+ """,
1738
+ BERT_START_DOCSTRING,
1739
+ )
1740
+ class BertForTokenClassification(BertPreTrainedModel):
1741
+ def __init__(self, config):
1742
+ super().__init__(config)
1743
+ self.num_labels = config.num_labels
1744
+
1745
+ self.bert = BertModel(config, add_pooling_layer=False)
1746
+ classifier_dropout = (
1747
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1748
+ )
1749
+ self.dropout = nn.Dropout(classifier_dropout)
1750
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1751
+
1752
+ # Initialize weights and apply final processing
1753
+ self.post_init()
1754
+
1755
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1756
+ @add_code_sample_docstrings(
1757
+ checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
1758
+ output_type=TokenClassifierOutput,
1759
+ config_class=_CONFIG_FOR_DOC,
1760
+ expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
1761
+ expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
1762
+ )
1763
+ def forward(
1764
+ self,
1765
+ input_ids: Optional[torch.Tensor] = None,
1766
+ attention_mask: Optional[torch.Tensor] = None,
1767
+ token_type_ids: Optional[torch.Tensor] = None,
1768
+ position_ids: Optional[torch.Tensor] = None,
1769
+ head_mask: Optional[torch.Tensor] = None,
1770
+ inputs_embeds: Optional[torch.Tensor] = None,
1771
+ labels: Optional[torch.Tensor] = None,
1772
+ output_attentions: Optional[bool] = None,
1773
+ output_hidden_states: Optional[bool] = None,
1774
+ return_dict: Optional[bool] = None,
1775
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1776
+ r"""
1777
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1778
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1779
+ """
1780
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1781
+
1782
+ outputs = self.bert(
1783
+ input_ids,
1784
+ attention_mask=attention_mask,
1785
+ token_type_ids=token_type_ids,
1786
+ position_ids=position_ids,
1787
+ head_mask=head_mask,
1788
+ inputs_embeds=inputs_embeds,
1789
+ output_attentions=output_attentions,
1790
+ output_hidden_states=output_hidden_states,
1791
+ return_dict=return_dict,
1792
+ )
1793
+
1794
+ sequence_output = outputs[0]
1795
+
1796
+ sequence_output = self.dropout(sequence_output)
1797
+ logits = self.classifier(sequence_output)
1798
+
1799
+ loss = None
1800
+ if labels is not None:
1801
+ loss_fct = CrossEntropyLoss()
1802
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1803
+
1804
+ if not return_dict:
1805
+ output = (logits,) + outputs[2:]
1806
+ return ((loss,) + output) if loss is not None else output
1807
+
1808
+ return TokenClassifierOutput(
1809
+ loss=loss,
1810
+ logits=logits,
1811
+ hidden_states=outputs.hidden_states,
1812
+ attentions=outputs.attentions,
1813
+ )
1814
+
1815
+
1816
+ @add_start_docstrings(
1817
+ """
1818
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1820
+ """,
1821
+ BERT_START_DOCSTRING,
1822
+ )
1823
+ class BertForQuestionAnswering(BertPreTrainedModel):
1824
+ def __init__(self, config):
1825
+ super().__init__(config)
1826
+ self.num_labels = config.num_labels
1827
+
1828
+ self.bert = BertModel(config, add_pooling_layer=False)
1829
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1830
+
1831
+ # Initialize weights and apply final processing
1832
+ self.post_init()
1833
+
1834
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1835
+ @add_code_sample_docstrings(
1836
+ checkpoint=_CHECKPOINT_FOR_QA,
1837
+ output_type=QuestionAnsweringModelOutput,
1838
+ config_class=_CONFIG_FOR_DOC,
1839
+ qa_target_start_index=_QA_TARGET_START_INDEX,
1840
+ qa_target_end_index=_QA_TARGET_END_INDEX,
1841
+ expected_output=_QA_EXPECTED_OUTPUT,
1842
+ expected_loss=_QA_EXPECTED_LOSS,
1843
+ )
1844
+ def forward(
1845
+ self,
1846
+ input_ids: Optional[torch.Tensor] = None,
1847
+ attention_mask: Optional[torch.Tensor] = None,
1848
+ token_type_ids: Optional[torch.Tensor] = None,
1849
+ position_ids: Optional[torch.Tensor] = None,
1850
+ head_mask: Optional[torch.Tensor] = None,
1851
+ inputs_embeds: Optional[torch.Tensor] = None,
1852
+ start_positions: Optional[torch.Tensor] = None,
1853
+ end_positions: Optional[torch.Tensor] = None,
1854
+ output_attentions: Optional[bool] = None,
1855
+ output_hidden_states: Optional[bool] = None,
1856
+ return_dict: Optional[bool] = None,
1857
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1858
+ r"""
1859
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1860
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1861
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1862
+ are not taken into account for computing the loss.
1863
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1864
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1865
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1866
+ are not taken into account for computing the loss.
1867
+ """
1868
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1869
+
1870
+ outputs = self.bert(
1871
+ input_ids,
1872
+ attention_mask=attention_mask,
1873
+ token_type_ids=token_type_ids,
1874
+ position_ids=position_ids,
1875
+ head_mask=head_mask,
1876
+ inputs_embeds=inputs_embeds,
1877
+ output_attentions=output_attentions,
1878
+ output_hidden_states=output_hidden_states,
1879
+ return_dict=return_dict,
1880
+ )
1881
+
1882
+ sequence_output = outputs[0]
1883
+
1884
+ logits = self.qa_outputs(sequence_output)
1885
+ start_logits, end_logits = logits.split(1, dim=-1)
1886
+ start_logits = start_logits.squeeze(-1).contiguous()
1887
+ end_logits = end_logits.squeeze(-1).contiguous()
1888
+
1889
+ total_loss = None
1890
+ if start_positions is not None and end_positions is not None:
1891
+ # If we are on multi-GPU, the split adds an extra dimension; squeeze it
1892
+ if len(start_positions.size()) > 1:
1893
+ start_positions = start_positions.squeeze(-1)
1894
+ if len(end_positions.size()) > 1:
1895
+ end_positions = end_positions.squeeze(-1)
1896
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1897
+ ignored_index = start_logits.size(1)
1898
+ start_positions = start_positions.clamp(0, ignored_index)
1899
+ end_positions = end_positions.clamp(0, ignored_index)
1900
+
1901
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1902
+ start_loss = loss_fct(start_logits, start_positions)
1903
+ end_loss = loss_fct(end_logits, end_positions)
1904
+ total_loss = (start_loss + end_loss) / 2
1905
+
1906
+ if not return_dict:
1907
+ output = (start_logits, end_logits) + outputs[2:]
1908
+ return ((total_loss,) + output) if total_loss is not None else output
1909
+
1910
+ return QuestionAnsweringModelOutput(
1911
+ loss=total_loss,
1912
+ start_logits=start_logits,
1913
+ end_logits=end_logits,
1914
+ hidden_states=outputs.hidden_states,
1915
+ attentions=outputs.attentions,
1916
+ )
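Not part of the commit, but a minimal sketch of what the modification above buys: every keyword argument passed to `generate()` is handed to `prepare_inputs_for_generation`, and because the modified version also returns `encoder_hidden_states` in the model inputs, the projected image features reach the decoder's cross-attention at every decoding step. The configuration values, vocabulary size and feature shape below are made up for illustration, and the snippet assumes the patched `modeling_bert.py` shown above has replaced the stock file.
```python
import torch
from transformers import BertConfig, BertLMHeadModel, GenerationConfig

# Toy, randomly initialized decoder; all sizes here are illustrative.
config = BertConfig(
    is_decoder=True,           # causal LM that uses past_key_values
    add_cross_attention=True,  # attend over the (projected) image features
    vocab_size=32000,
    hidden_size=512,
    num_hidden_layers=6,
    num_attention_heads=8,
)
decoder = BertLMHeadModel(config)

# Stand-in for the projected encoder output: (batch, num_patches, hidden_size).
encoder_hidden_states = torch.randn(1, 588, config.hidden_size)

# With the patched prepare_inputs_for_generation, this keyword is re-injected into
# every forward call during decoding instead of being silently dropped.
output_ids = decoder.generate(
    encoder_hidden_states=encoder_hidden_states,
    generation_config=GenerationConfig(
        max_length=32, bos_token_id=0, eos_token_id=1, pad_token_id=2
    ),
)
print(output_ids.shape)  # (1, <=32) token ids from the randomly initialized model
```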
baseline/my_model.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import torch
+ from transformers import PreTrainedModel, GenerationConfig, BertLMHeadModel
+ from transformers.modeling_outputs import Seq2SeqLMOutput
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from typing import Optional, Tuple, Union
+ from torch.utils.data import Dataset
+ from PIL import Image
+
+
+ class MyModel(PreTrainedModel):
+     """End-to-end document image translation model: Nougat image encoder +
+     BERT-style translation decoder, bridged by a linear projection."""
+
+     def __init__(self, config, trans_model, nougat_model):
+         super().__init__(config)
+         self.encoder = nougat_model.encoder
+         self.decoder = trans_model.decoder
+         # project encoder states to the decoder hidden size
+         self.project = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict=True,
+         **kwargs,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+
+         # encode the page image and project to the decoder hidden size
+         encoder_outputs = self.encoder(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         encoder_hidden_states = encoder_outputs.last_hidden_state
+         encoder_hidden_states_proj = self.project(encoder_hidden_states)
+
+         # decode the target-language text, attending over the projected image features
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_hidden_states_proj,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             use_cache=use_cache,
+             past_key_values=past_key_values,
+             return_dict=return_dict,
+         )
+
+         # Compute the loss independently of the decoder (some decoders shift the logits internally);
+         # the labels are already shifted in MyDataset, so the logits are used as-is.
+         loss = None
+         if labels is not None:
+             logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+             loss_fct_trans = CrossEntropyLoss()
+             loss_trans = loss_fct_trans(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1).long())
+             loss = loss_trans
+
+         if not return_dict:
+             if loss is not None:
+                 return (loss,) + decoder_outputs + encoder_outputs
+             else:
+                 return decoder_outputs + encoder_outputs
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=decoder_outputs.logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_hidden_states,
+         )
+
+     def generate(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict=True,
+         generation_config: Optional[GenerationConfig] = None,
+         **kwargs,
+     ):
+
+         encoder_outputs = self.encoder(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         encoder_hidden_states = encoder_outputs.last_hidden_state
+         encoder_hidden_states_proj = self.project(encoder_hidden_states)
+
+         # the decoder's generate() forwards encoder_hidden_states through the modified
+         # prepare_inputs_for_generation in modeling_bert.py, so cross-attention sees the image features
+         generation_outputs = self.decoder.generate(
+             encoder_hidden_states=encoder_hidden_states_proj,
+             generation_config=generation_config,
+         )
+
+         return generation_outputs
+
+
+ class MyDataset(Dataset):
+     """Pairs a page image (<name>.png) with its translated markdown text (<name>.mmd)."""
+
+     def __init__(self, processor, tokenizer, name_list, max_length, image_dir, text_dir):
+         self.processor = processor
+         self.tokenizer = tokenizer
+         self.name_list = name_list
+         self.max_length = max_length
+         self.image_dir = image_dir
+         self.text_dir = text_dir
+
+     def __len__(self):
+         return len(self.name_list)
+
+     def __getitem__(self, index):
+         encoding = {}
+         image_file_path = os.path.join(self.image_dir, self.name_list[index] + '.png')
+         image = Image.open(image_file_path)
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+         pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)
+         encoding['pixel_values'] = pixel_values
+
+         text_file_path = os.path.join(self.text_dir, self.name_list[index] + '.mmd')
+         with open(text_file_path, 'r') as f:
+             lines = f.readlines()
+         text = ''.join(lines)
+         input_ids = self.tokenizer(text, max_length=self.max_length, truncation=True).input_ids
+         input_ids = [x for x in input_ids if x != 6]  # drop occurrences of token id 6 (tokenizer-specific filtering)
+         input_ids = [self.tokenizer.bos_token_id] + input_ids[1:]  # force <bos> as the first token
+
+         # decoder inputs are padded to max_length; labels are the inputs shifted left by one,
+         # with -100 so that padded positions are ignored by the loss
+         decoder_input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
+         decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long)
+         labels = input_ids[1:] + [-100] * (self.max_length - len(input_ids) + 1)
+         labels = torch.tensor(labels, dtype=torch.long)
+         encoding['decoder_input_ids'] = decoder_input_ids
+         encoding['labels'] = labels
+
+         return encoding
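A small worked example (not in the repository) of the target-side alignment that `MyDataset.__getitem__` builds: `labels` are the decoder inputs shifted left by one and padded with `-100`, so the loss in `MyModel.forward` is a next-token objective that ignores padding. The token ids are hypothetical, with 0/1/2 standing in for `<bos>`/`<eos>`/`<pad>`; they are not real `zh_tokenizer` ids.
```python
# Hypothetical ids only; mirrors the padding/shifting logic in MyDataset.__getitem__.
max_length = 8
bos_token_id, eos_token_id, pad_token_id = 0, 1, 2

input_ids = [bos_token_id, 11, 12, 13, eos_token_id]  # tokenized target text

decoder_input_ids = input_ids + [pad_token_id] * (max_length - len(input_ids))
labels = input_ids[1:] + [-100] * (max_length - len(input_ids) + 1)

print(decoder_input_ids)  # [0, 11, 12, 13, 1, 2, 2, 2]
print(labels)             # [11, 12, 13, 1, -100, -100, -100, -100]
# At step t the decoder sees decoder_input_ids[t] and is trained to predict labels[t],
# i.e. the next token; CrossEntropyLoss skips the -100 positions.
```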
baseline/train.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import json
+ import torch
+ from transformers import DonutProcessor, AutoTokenizer
+ import argparse
+ from transformers import VisionEncoderDecoderModel, EncoderDecoderModel, EncoderDecoderConfig, BertConfig
+ from my_model import MyModel, MyDataset
+ from transformers import Trainer, TrainingArguments
+
+
+ def train(args):
+     # image preprocessing: Donut processor with 896x672 input size and ImageNet mean/std normalization
+     processor = DonutProcessor.from_pretrained(args.donut_dir)
+     processor.image_processor.size = {'height': 896, 'width': 672}
+     processor.image_processor.image_mean = [0.485, 0.456, 0.406]
+     processor.image_processor.image_std = [0.229, 0.224, 0.225]
+     tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))
+
+     image_dir = os.path.join(args.dataset_dir, 'imgs')
+     text_dir = os.path.join(args.dataset_dir, 'zh_mmd')
+
+     json_file_path = os.path.join(args.dataset_dir, 'split_dataset.json')
+     with open(json_file_path, 'r') as f:
+         json_dict = json.load(f)
+     train_name_list = json_dict['train_name_list']
+     valid_name_list = json_dict['valid_name_list']
+
+     train_dataset = MyDataset(processor, tokenizer, train_name_list, args.max_length, image_dir, text_dir)
+     valid_dataset = MyDataset(processor, tokenizer, valid_name_list, args.max_length, image_dir, text_dir)
+
+     # translation decoder configuration: a randomly initialized BERT decoder (6 layers, hidden size 512, 8 heads)
+     encoder_config = BertConfig()
+     decoder_config = BertConfig()
+     encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
+     encoder_decoder_config.decoder.bos_token_id = tokenizer.bos_token_id
+     encoder_decoder_config.decoder.decoder_start_token_id = tokenizer.bos_token_id
+     encoder_decoder_config.decoder.eos_token_id = tokenizer.eos_token_id
+     encoder_decoder_config.decoder.hidden_size = 512
+     encoder_decoder_config.decoder.intermediate_size = 2048
+     encoder_decoder_config.decoder.max_length = args.max_length
+     encoder_decoder_config.decoder.max_position_embeddings = args.max_length
+     encoder_decoder_config.decoder.num_attention_heads = 8
+     encoder_decoder_config.decoder.num_hidden_layers = 6
+     encoder_decoder_config.decoder.pad_token_id = tokenizer.pad_token_id
+     encoder_decoder_config.decoder.type_vocab_size = 1
+     encoder_decoder_config.decoder.vocab_size = tokenizer.vocab_size
+
+     trans_model = EncoderDecoderModel(config=encoder_decoder_config)
+     nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)
+
+     # assemble the pre-trained Nougat encoder with the translation decoder
+     model = MyModel(nougat_model.config, trans_model, nougat_model)
+
+     # keep the effective batch size at args.batch_size regardless of the GPU count:
+     # effective = num_gpu * batch_size_per_gpu * gradient_accumulation_steps
+     num_gpu = max(torch.cuda.device_count(), 1)  # at least 1 so the division below is safe on CPU-only machines
+     gradient_accumulation_steps = args.batch_size // (num_gpu * args.batch_size_per_gpu)
+
+     training_args = TrainingArguments(
+         output_dir=os.path.join(args.base_dir, 'models'),
+         per_device_train_batch_size=args.batch_size_per_gpu,
+         per_device_eval_batch_size=args.batch_size_per_gpu,
+         gradient_accumulation_steps=gradient_accumulation_steps,
+         logging_strategy='steps',
+         logging_steps=1,
+         evaluation_strategy='steps',
+         eval_steps=args.eval_steps,
+         save_strategy='steps',
+         save_steps=args.save_steps,
+         fp16=args.fp16,
+         learning_rate=args.learning_rate,
+         max_steps=args.max_steps,
+         warmup_steps=args.warmup_steps,
+         dataloader_num_workers=args.dataloader_num_workers,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=valid_dataset,
+     )
+
+     trainer.train()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_dir", type=str)
+     parser.add_argument("--dataset_dir", type=str)
+     parser.add_argument("--donut_dir", type=str)
+     parser.add_argument("--nougat_dir", type=str)
+
+     parser.add_argument("--max_length", type=int, default=1536)
+     parser.add_argument("--batch_size", type=int, default=64)
+     parser.add_argument("--batch_size_per_gpu", type=int, default=4)
+     parser.add_argument("--eval_steps", type=int, default=1000)
+     parser.add_argument("--save_steps", type=int, default=1000)
+     # note: argparse's type=bool treats any non-empty string as True (pass an empty string to disable)
+     parser.add_argument("--fp16", type=bool, default=True)
+     parser.add_argument("--learning_rate", type=float, default=5e-5)
+     parser.add_argument("--max_steps", type=int, default=10000)
+     parser.add_argument("--warmup_steps", type=int, default=1000)
+     parser.add_argument("--dataloader_num_workers", type=int, default=8)
+
+     args = parser.parse_args()
+
+     train(args)
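As a quick check of the batch-size bookkeeping in `train()`: the product of GPU count, per-GPU batch size and accumulation steps stays at `--batch_size` (64 by default), whatever the hardware. The GPU counts below are illustrative.
```python
# Illustrative arithmetic only; mirrors the gradient_accumulation_steps computation in train().
batch_size = 64          # default --batch_size
batch_size_per_gpu = 4   # default --batch_size_per_gpu

for num_gpu in (1, 2, 4, 8):
    gradient_accumulation_steps = batch_size // (num_gpu * batch_size_per_gpu)
    effective = num_gpu * batch_size_per_gpu * gradient_accumulation_steps
    print(f"{num_gpu} GPU(s): accumulate {gradient_accumulation_steps} steps -> effective batch {effective}")
# 1 GPU -> 16 steps, 2 -> 8, 4 -> 4, 8 -> 2; the effective batch size is 64 in every case.
```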
baseline/zh_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "<bos>",
+   "eos_token": "<eos>",
+   "pad_token": "<pad>",
+   "unk_token": "<unk>"
+ }
baseline/zh_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
baseline/zh_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "bos_token": "<bos>",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<eos>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "unk_token": "<unk>"
+ }
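Together with `tokenizer.json` (whose diff is not rendered above), these two files are all `AutoTokenizer` needs to rebuild the Chinese target-side tokenizer. A small sanity check, assuming the `zh_tokenizer` directory from this commit is available locally:
```python
# Assumes ./baseline/zh_tokenizer from this commit exists on disk.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("baseline/zh_tokenizer")
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token, tokenizer.unk_token)
# expected: <bos> <eos> <pad> <unk>, matching special_tokens_map.json
print(tokenizer.bos_token_id, tokenizer.pad_token_id)  # ids come from tokenizer.json
```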