kertser committed on
Commit
05d6778
1 Parent(s): 3291ddb

Upload WarBot.py

Browse files
Files changed (1) hide show
  1. WarBot.py +51 -38
WarBot.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from transformers import AutoTokenizer ,AutoModelForCausalLM
2
  import re
3
  # Speller and punctuation:
@@ -6,7 +8,7 @@ import yaml
6
  import torch
7
  from torch import package
8
  # not very necessary
9
- import textwrap
10
  from textwrap3 import wrap
11
 
12
  # util function to get expected len after tokenizing
@@ -65,7 +67,6 @@ def prepare_punct():
65
 
66
  def initialize():
67
  """ Loading the model """
68
- torch.backends.quantized.engine = 'qnnpack' # Just for the specific machine architecture
69
  fit_checkpoint = "WarBot"
70
  tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
71
  model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
@@ -75,10 +76,13 @@ def initialize():
75
  def split_string(string,n=256):
76
  return [string[i:i+n] for i in range(0, len(string), n)]
77
 
78
- def get_response(quote:str,model,tokenizer,model_punct):
79
  # encode the input, add the eos_token and return a tensor in Pytorch
80
- user_inpit_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
81
- + quote + tokenizer.eos_token, return_tensors="pt")
 
 
 
82
 
83
  chat_history_ids = user_inpit_ids # To be changed
84
 
@@ -88,22 +92,22 @@ def get_response(quote:str,model,tokenizer,model_punct):
88
  else:
89
  no_repeat_ngram_size = 1
90
 
91
- output_id = model.generate(
92
- chat_history_ids,
93
- num_return_sequences=1, # use for more variants, but have to print [i]
94
- max_length=200, #512
95
- no_repeat_ngram_size=no_repeat_ngram_size, #3
96
- do_sample=True, #True
97
- top_k=50,#50
98
- top_p=0.9, #0.9
99
- temperature = 0.4, # was 0.6, 0 for greedy
100
- #mask_token_id=tokenizer.mask_token_id,
101
- eos_token_id=tokenizer.eos_token_id,
102
- #unk_token_id=tokenizer.unk_token_id,
103
- pad_token_id=tokenizer.pad_token_id,
104
- #pad_token_id=tokenizer.eos_token_id,
105
- #device='cpu'
106
- )
107
 
108
  response = tokenizer.decode(output_id[0], skip_special_tokens=True)
109
  response = removeSigns(response)
@@ -113,26 +117,35 @@ def get_response(quote:str,model,tokenizer,model_punct):
113
  response = remove_duplicates(re.sub(r"\d{4,}", "", response)) # Remove runs of 4 or more consecutive digits
114
  response = re.sub(r'\.\.+', '', response) # Remove the "....." thing
115
 
116
- maxLen = 170
117
-
118
- try:
119
- if len(response)>maxLen: # We shall play with it
120
- resps = wrap(response,maxLen)
121
- for i in range(len(resps)):
122
  resps[i] = model_punct.enhance_text(resps[i], lan='ru')
123
  response = ''.join(resps)
124
- else:
125
- response = model_punct.enhance_text(response, lan='ru')
126
- except:
127
- pass # sometimes the string is getting too long
128
 
 
129
  response = re.sub(r'[UNK]', '', response) # Remove the [UNK] thing
 
 
 
 
 
 
 
 
130
  return response
131
 
132
- #if __name__ == '__main__':
133
- #model,tokenizer,model_punct = initialize()
134
- #quote = "Это хорошо, но глядя на ролик, когда ефиопские толпы в Израиле громят машины и нападают на улице на израильтян - задумаешься, куда все движется"
135
- #print('please wait...')
136
- #response = wrap(get_response(quote,model,tokenizer,model_punct),60)
137
- #for phrase in response:
138
- # print(phrase)
 
 
 
1
+ # Main library for WarBot
2
+
3
  from transformers import AutoTokenizer ,AutoModelForCausalLM
4
  import re
5
  # Speller and punctuation:
 
8
  import torch
9
  from torch import package
10
  # not very necessary
11
+ #import textwrap
12
  from textwrap3 import wrap
13
 
14
  # util function to get expected len after tokenizing
 
67
 
68
  def initialize():
69
  """ Loading the model """
 
70
  fit_checkpoint = "WarBot"
71
  tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
72
  model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
 
76
  def split_string(string,n=256):
77
  return [string[i:i+n] for i in range(0, len(string), n)]
78
 
79
+ def get_response(quote:str,model,tokenizer,model_punct,temperature=0.2):
80
  # encode the input, add the eos_token and return a tensor in Pytorch
81
+ try:
82
+ user_inpit_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
83
+ + quote + tokenizer.eos_token, return_tensors="pt")
84
+ except:
85
+ return "" # Exception in tokenization
86
 
87
  chat_history_ids = user_inpit_ids # To be changed
88
 
 
92
  else:
93
  no_repeat_ngram_size = 1
94
 
95
+ try:
96
+ output_id = model.generate(
97
+ chat_history_ids,
98
+ num_return_sequences=1, # use for more variants, but have to print [i]
99
+ max_length=200, #512
100
+ no_repeat_ngram_size=no_repeat_ngram_size, #3
101
+ do_sample=True, #True
102
+ top_k=50,#50
103
+ top_p=0.9, #0.9
104
+ temperature = temperature, # was 0.6, 0 for greedy
105
+ eos_token_id=tokenizer.eos_token_id,
106
+ pad_token_id=tokenizer.pad_token_id,
107
+ #device='cpu'
108
+ )
109
+ except:
110
+ return "" # Exception in generation
111
 
112
  response = tokenizer.decode(output_id[0], skip_special_tokens=True)
113
  response = removeSigns(response)
 
117
  response = remove_duplicates(re.sub(r"\d{4,}", "", response)) # Remove runs of 4 or more consecutive digits
118
  response = re.sub(r'\.\.+', '', response) # Remove the "....." thing
119
 
120
+ if len(response)>200:
121
+ resps = wrap(response,200)
122
+ for i in range(len(resps)):
123
+ try:
 
 
124
  resps[i] = model_punct.enhance_text(resps[i], lan='ru')
125
  response = ''.join(resps)
126
+ except:
127
+ return "" # Exception in punctuation
128
+ else:
129
+ response = model_punct.enhance_text(response, lan='ru')
130
 
131
+ # Immanent postprocessing of the response
132
  response = re.sub(r'[UNK]', '', response) # Remove the [UNK] thing
133
+ response = re.sub(r',+', ',', response) # Replace multi-commas with single one
134
+ response = re.sub(r'-+', ',', response) # Replace each run of dashes with a comma (NOTE: a single dash may have been intended; confirm)
135
+ response = re.sub(r'\.\?', '?', response) # Fix the .? issue
136
+ response = re.sub(r'\.\!', '!', response) # Fix the .! issue
137
+ response = re.sub(r'\.\,', ',', response) # Fix the ., issue
138
+ response = re.sub(r'\.\)', '.', response) # Fix the .) issue
139
+ response = response.replace('[]', '') # Fix the [] issue
140
+
141
  return response
142
 
143
+ if __name__ == '__main__':
144
+ """
145
+ quote = "Здравствуй, Жопа, Новый Год, выходи на ёлку!"
146
+ model, tokenizer, model_punct = initialize()
147
+ response = ""
148
+ while not response:
149
+ response = get_response(quote, model, tokenizer, model_punct,temperature=0.2)
150
+ print(response)
151
+ """