KaleiNeely commited on
Commit
9b9dc85
·
1 Parent(s): 8c48767

Update tokenization_rwkv_world.py

Browse files
Files changed (1) hide show
  1. tokenization_rwkv_world.py +19 -3
tokenization_rwkv_world.py CHANGED
@@ -202,13 +202,18 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
202
  return tokens
203
 
204
  def decodeBytes(self, tokens):
205
- byte_sequence = [self.encoder[i] for i in tokens if i != 0]
206
- return b''.join(byte_sequence)
207
 
208
  def _tokenize(self, text, **kwargs):
209
  """Tokenize a string."""
210
  return self.encodeBytes(text.encode("utf-8"))
211
 
 
 
 
 
 
 
212
  def _decode(self,
213
  token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
214
  skip_special_tokens: bool = False,
@@ -222,7 +227,18 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
222
  return ""
223
  return self.encoder.get(token_ids, self.unk_token)
224
  elif isinstance(token_ids, list):
225
- return self.decodeBytes(token_ids).decode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
226
  else:
227
  return token_ids
228
 
 
202
  return tokens
203
 
204
  def decodeBytes(self, tokens):
205
+ return b''.join(map(lambda i: self.encoder[i], tokens))
 
206
 
207
  def _tokenize(self, text, **kwargs):
208
  """Tokenize a string."""
209
  return self.encodeBytes(text.encode("utf-8"))
210
 
211
+ def _decode_tokens(self, tokens):
212
+ try:
213
+ return self.decodeBytes(tokens).decode('utf-8')
214
+ except:
215
+ return '\ufffd' # bad utf-8
216
+
217
  def _decode(self,
218
  token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
219
  skip_special_tokens: bool = False,
 
227
  return ""
228
  return self.encoder.get(token_ids, self.unk_token)
229
  elif isinstance(token_ids, list):
230
+ out_str = ""
231
+ out_last = 0
232
+ out_tokens = []
233
+ for i, token in enumerate(token_ids):
234
+ if token == 0:
235
+ break
236
+ out_tokens += [token]
237
+ tmp = self._decode_tokens(out_tokens[out_last:])
238
+ if '\ufffd' not in tmp:
239
+ out_str += tmp
240
+ out_last = i + 1
241
+ return out_str
242
  else:
243
  return token_ids
244