.gitattributes CHANGED
@@ -29,4 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zstandard filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
- pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zstandard filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,8 +5,8 @@ tags:
5
  - PyTorch
6
  - Transformers
7
  license: apache-2.0
8
- widget:
9
- - text: "sbert punc case расставляет точки запятые и знаки вопроса вам нравится"
10
  ---
11
 
12
  # SbertPuncCase
 
5
  - PyTorch
6
  - Transformers
7
  license: apache-2.0
8
+ base_model: sberbank-ai/sbert_large_nlu_ru
9
+ inference: false
10
  ---
11
 
12
  # SbertPuncCase
config.json CHANGED
@@ -55,8 +55,8 @@
55
  "pooler_type": "first_token_transform",
56
  "position_embedding_type": "absolute",
57
  "torch_dtype": "float16",
58
- "transformers_version": "4.20.1",
59
  "type_vocab_size": 2,
60
  "use_cache": true,
61
  "vocab_size": 120138
62
- }
 
55
  "pooler_type": "first_token_transform",
56
  "position_embedding_type": "absolute",
57
  "torch_dtype": "float16",
58
+ "transformers_version": "4.36.2",
59
  "type_vocab_size": 2,
60
  "use_cache": true,
61
  "vocab_size": 120138
62
+ }
pytorch_model.bin → model.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a0928d162fd53b902c8aa1704cb29f904d777398c152e0c6a5cdc676d6cf397c
3
- size 851804225
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06173fe7aed01a3a58385f6e724502f634909bee0818f6a8514b9eb6eb869be8
3
+ size 851791402
sbert_punc_case_ru/sbertpunccase.py CHANGED
@@ -8,62 +8,66 @@ import numpy as np
8
  from transformers import AutoTokenizer, AutoModelForTokenClassification
9
 
10
  # Прогнозируемые знаки препинания
11
- PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
12
 
13
  # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа,
14
  # UPPER_TOTAL - верхний регистр для всех символов
15
- LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
16
  # Добавим в пунктуацию метку O означающий отсутсвие пунктуации
17
- LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
18
 
19
  # Сформируем метки на основе комбинаций регистра и пунктуации
20
  LABELS_list = []
21
  for case in LABELS_CASE:
22
  for punc in LABELS_PUNC:
23
- LABELS_list.append(f'{case}_{punc}')
24
- LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
25
- LABELS['O'] = -100
26
  INVERSE_LABELS = {i: label for label, i in LABELS.items()}
27
 
28
- LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
29
- LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
 
 
 
 
30
 
31
 
32
  def token_to_label(token, label):
33
  if type(label) == int:
34
  label = INVERSE_LABELS[label]
35
- if label == 'LOWER_O':
36
  return token
37
- if label == 'LOWER_PERIOD':
38
- return token + '.'
39
- if label == 'LOWER_COMMA':
40
- return token + ','
41
- if label == 'LOWER_QUESTION':
42
- return token + '?'
43
- if label == 'UPPER_O':
44
  return token.capitalize()
45
- if label == 'UPPER_PERIOD':
46
- return token.capitalize() + '.'
47
- if label == 'UPPER_COMMA':
48
- return token.capitalize() + ','
49
- if label == 'UPPER_QUESTION':
50
- return token.capitalize() + '?'
51
- if label == 'UPPER_TOTAL_O':
52
  return token.upper()
53
- if label == 'UPPER_TOTAL_PERIOD':
54
- return token.upper() + '.'
55
- if label == 'UPPER_TOTAL_COMMA':
56
- return token.upper() + ','
57
- if label == 'UPPER_TOTAL_QUESTION':
58
- return token.upper() + '?'
59
- if label == 'O':
60
  return token
61
 
62
 
63
- def decode_label(label, classes='all'):
64
- if classes == 'punc':
65
  return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
66
- if classes == 'case':
67
  return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
68
  else:
69
  return INVERSE_LABELS[label]
@@ -76,14 +80,12 @@ class SbertPuncCase(nn.Module):
76
  def __init__(self):
77
  super().__init__()
78
 
79
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
80
- strip_accents=False)
81
  self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
82
  self.model.eval()
83
 
84
  def forward(self, input_ids, attention_mask):
85
- return self.model(input_ids=input_ids,
86
- attention_mask=attention_mask)
87
 
88
  def punctuate(self, text):
89
  text = text.strip().lower()
@@ -94,10 +96,23 @@ class SbertPuncCase(nn.Module):
94
  tokenizer_output = self.tokenizer(words, is_split_into_words=True)
95
 
96
  if len(tokenizer_output.input_ids) > 512:
97
- return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
98
-
99
- predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
100
- torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  predictions = np.argmax(predictions, axis=2)
102
 
103
  # decode punctuation and casing
@@ -108,16 +123,31 @@ class SbertPuncCase(nn.Module):
108
  label_id = predictions[0][label_pos]
109
  label = decode_label(label_id)
110
  splitted_text.append(token_to_label(word, label))
111
- capitalized_text = ' '.join(splitted_text)
112
  return capitalized_text
113
 
114
 
115
- if __name__ == '__main__':
116
- parser = argparse.ArgumentParser("Punctuation and case restoration model sbert_punc_case_ru")
117
- parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
118
- parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  args = parser.parse_args()
120
  print(f"Source text: {args.input}\n")
121
  sbertpunc = SbertPuncCase().to(args.device)
122
  punctuated_text = sbertpunc.punctuate(args.input)
123
- print(f"Restored text: {punctuated_text}")
 
8
  from transformers import AutoTokenizer, AutoModelForTokenClassification
9
 
10
  # Прогнозируемые знаки препинания
11
+ PUNK_MAPPING = {".": "PERIOD", ",": "COMMA", "?": "QUESTION"}
12
 
13
  # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа,
14
  # UPPER_TOTAL - верхний регистр для всех символов
15
+ LABELS_CASE = ["LOWER", "UPPER", "UPPER_TOTAL"]
16
  # Добавим в пунктуацию метку O означающий отсутсвие пунктуации
17
+ LABELS_PUNC = ["O"] + list(PUNK_MAPPING.values())
18
 
19
  # Сформируем метки на основе комбинаций регистра и пунктуации
20
  LABELS_list = []
21
  for case in LABELS_CASE:
22
  for punc in LABELS_PUNC:
23
+ LABELS_list.append(f"{case}_{punc}")
24
+ LABELS = {label: i + 1 for i, label in enumerate(LABELS_list)}
25
+ LABELS["O"] = -100
26
  INVERSE_LABELS = {i: label for label, i in LABELS.items()}
27
 
28
+ LABEL_TO_PUNC_LABEL = {
29
+ label: label.split("_")[-1] for label in LABELS.keys() if label != "O"
30
+ }
31
+ LABEL_TO_CASE_LABEL = {
32
+ label: "_".join(label.split("_")[:-1]) for label in LABELS.keys() if label != "O"
33
+ }
34
 
35
 
36
  def token_to_label(token, label):
37
  if type(label) == int:
38
  label = INVERSE_LABELS[label]
39
+ if label == "LOWER_O":
40
  return token
41
+ if label == "LOWER_PERIOD":
42
+ return token + "."
43
+ if label == "LOWER_COMMA":
44
+ return token + ","
45
+ if label == "LOWER_QUESTION":
46
+ return token + "?"
47
+ if label == "UPPER_O":
48
  return token.capitalize()
49
+ if label == "UPPER_PERIOD":
50
+ return token.capitalize() + "."
51
+ if label == "UPPER_COMMA":
52
+ return token.capitalize() + ","
53
+ if label == "UPPER_QUESTION":
54
+ return token.capitalize() + "?"
55
+ if label == "UPPER_TOTAL_O":
56
  return token.upper()
57
+ if label == "UPPER_TOTAL_PERIOD":
58
+ return token.upper() + "."
59
+ if label == "UPPER_TOTAL_COMMA":
60
+ return token.upper() + ","
61
+ if label == "UPPER_TOTAL_QUESTION":
62
+ return token.upper() + "?"
63
+ if label == "O":
64
  return token
65
 
66
 
67
+ def decode_label(label, classes="all"):
68
+ if classes == "punc":
69
  return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
70
+ if classes == "case":
71
  return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
72
  else:
73
  return INVERSE_LABELS[label]
 
80
  def __init__(self):
81
  super().__init__()
82
 
83
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, strip_accents=False)
 
84
  self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
85
  self.model.eval()
86
 
87
  def forward(self, input_ids, attention_mask):
88
+ return self.model(input_ids=input_ids, attention_mask=attention_mask)
 
89
 
90
  def punctuate(self, text):
91
  text = text.strip().lower()
 
96
  tokenizer_output = self.tokenizer(words, is_split_into_words=True)
97
 
98
  if len(tokenizer_output.input_ids) > 512:
99
+ return " ".join(
100
+ [
101
+ self.punctuate(" ".join(text_part))
102
+ for text_part in np.array_split(words, 2)
103
+ ]
104
+ )
105
+
106
+ predictions = (
107
+ self(
108
+ torch.tensor([tokenizer_output.input_ids], device=self.model.device),
109
+ torch.tensor(
110
+ [tokenizer_output.attention_mask], device=self.model.device
111
+ ),
112
+ )
113
+ .logits.cpu()
114
+ .data.numpy()
115
+ )
116
  predictions = np.argmax(predictions, axis=2)
117
 
118
  # decode punctuation and casing
 
123
  label_id = predictions[0][label_pos]
124
  label = decode_label(label_id)
125
  splitted_text.append(token_to_label(word, label))
126
+ capitalized_text = " ".join(splitted_text)
127
  return capitalized_text
128
 
129
 
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser(
132
+ "Punctuation and case restoration model sbert_punc_case_ru"
133
+ )
134
+ parser.add_argument(
135
+ "-i",
136
+ "--input",
137
+ type=str,
138
+ help="text to restore",
139
+ default="sbert punc case расставляет точки запятые и знаки вопроса вам нравится",
140
+ )
141
+ parser.add_argument(
142
+ "-d",
143
+ "--device",
144
+ type=str,
145
+ help="run model on cpu or gpu",
146
+ choices=["cpu", "cuda"],
147
+ default="cpu",
148
+ )
149
  args = parser.parse_args()
150
  print(f"Source text: {args.input}\n")
151
  sbertpunc = SbertPuncCase().to(args.device)
152
  punctuated_text = sbertpunc.punctuate(args.input)
153
+ print(f"Restored text: {punctuated_text}")
setup.py CHANGED
@@ -1,19 +1,24 @@
1
  from distutils.core import setup
2
 
3
- setup(name='sbert_punc_case_ru',
4
- version='0.1',
5
- description='Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru',
6
- author='Almira Murtazina',
7
- author_email='[email protected]',
8
- packages=['sbert_punc_case_ru'],
9
- install_requires=['transformers>=4.18.3'],
10
- classifiers=[
11
- "Operating System :: OS Independent",
12
- "Programming Language :: Python :: 3",
13
- "Programming Language :: Python :: 3.6",
14
- "Programming Language :: Python :: 3.7",
15
- "Programming Language :: Python :: 3.8",
16
- "Programming Language :: Python :: 3.9",
17
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
18
- ]
19
- )
 
 
 
 
 
 
1
  from distutils.core import setup
2
 
3
+ setup(
4
+ name="sbert_punc_case_ru",
5
+ version="0.2",
6
+ description="Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru",
7
+ author="Almira Murtazina",
8
+ author_email="[email protected]",
9
+ packages=["sbert_punc_case_ru"],
10
+ install_requires=[
11
+ "transformers>=4.36.2",
12
+ "torch",
13
+ "numpy"
14
+ ],
15
+ classifiers=[
16
+ "Operating System :: OS Independent",
17
+ "Programming Language :: Python :: 3",
18
+ "Programming Language :: Python :: 3.6",
19
+ "Programming Language :: Python :: 3.7",
20
+ "Programming Language :: Python :: 3.8",
21
+ "Programming Language :: Python :: 3.9",
22
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
23
+ ],
24
+ )
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -1,13 +1,55 @@
1
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  "cls_token": "[CLS]",
3
  "do_basic_tokenize": true,
4
  "do_lower_case": true,
5
  "mask_token": "[MASK]",
6
- "name_or_path": "sberbank-ai/sbert_large_nlu_ru",
7
  "never_split": null,
8
  "pad_token": "[PAD]",
9
  "sep_token": "[SEP]",
10
- "special_tokens_map_file": null,
11
  "strip_accents": false,
12
  "tokenize_chinese_chars": true,
13
  "tokenizer_class": "BertTokenizer",
 
1
  {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
  "cls_token": "[CLS]",
46
  "do_basic_tokenize": true,
47
  "do_lower_case": true,
48
  "mask_token": "[MASK]",
49
+ "model_max_length": 1000000000000000019884624838656,
50
  "never_split": null,
51
  "pad_token": "[PAD]",
52
  "sep_token": "[SEP]",
 
53
  "strip_accents": false,
54
  "tokenize_chinese_chars": true,
55
  "tokenizer_class": "BertTokenizer",