Fix decode method for torch tensor
Files changed: tokenization_chatglm.py (+6 -15)
tokenization_chatglm.py
CHANGED
@@ -253,29 +253,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|
253 |
|
254 |
return seq
|
255 |
|
256 |
-
def
|
257 |
self,
|
258 |
-
token_ids: Union[
|
259 |
skip_special_tokens: bool = False,
|
260 |
clean_up_tokenization_spaces: bool = True,
|
261 |
-
spaces_between_special_tokens: bool = True,
|
262 |
**kwargs
|
263 |
) -> str:
|
264 |
-
if
|
265 |
token_ids = [token_ids]
|
266 |
if len(token_ids) == 0:
|
267 |
return ""
|
268 |
-
if
|
269 |
-
|
270 |
-
|
271 |
-
if self.pad_token_id in single_token_ids: # remove pad
|
272 |
-
single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
|
273 |
-
tokens.append(self.sp_tokenizer.decode(single_token_ids))
|
274 |
-
return (tokens)
|
275 |
-
else:
|
276 |
-
if self.pad_token_id in token_ids: # remove pad
|
277 |
-
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
278 |
-
return self.sp_tokenizer.decode(token_ids)
|
279 |
|
280 |
def _convert_token_to_id(self, token):
|
281 |
""" Converts a token (str) in an id using the vocab. """
|
|
|
253 |
|
254 |
return seq
|
255 |
|
256 |
+
def _decode(
    self,
    token_ids: Union[int, List[int]],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    **kwargs
) -> str:
    """Decode one token id or a sequence of token ids, dropping pad ids first.

    Args:
        token_ids: A single id or a sequence of ids. A bare ``int`` is
            wrapped into a one-element list.
        skip_special_tokens: Accepted for signature compatibility; not used
            by this implementation.
        clean_up_tokenization_spaces: Accepted for signature compatibility;
            not used by this implementation.
        **kwargs: Accepted and ignored.

    Returns:
        ``""`` for an empty input; otherwise whatever
        ``self.sp_tokenizer.decode`` returns for the ids that remain after
        every occurrence of ``self.pad_token_id`` has been removed.
    """
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    if len(token_ids) == 0:
        return ""

    pad_id = self.pad_token_id
    # Compare with the `!=` operator rather than the bound `pad_id.__ne__`:
    # the bound dunder returns NotImplemented (which is truthy) for elements
    # that are not plain Python ints (e.g. numpy integers), so pads would
    # silently survive the filter. The operator falls back to the element's
    # own comparison. A single pass also replaces the separate `in` pre-scan.
    token_ids = [tid for tid in token_ids if tid != pad_id]
    return self.sp_tokenizer.decode(token_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
def _convert_token_to_id(self, token):
|
272 |
""" Converts a token (str) in an id using the vocab. """
|