tianxie-sf committed
Commit 9936980 · 1 Parent(s): 3ea8253
update _convert_id_to_token (#7)
- update _convert_id_to_token (4bb73d6b6d162f8d32b8fe3d898b6ecc315776ad)
- tokenization_xgen.py: +13 -7
tokenization_xgen.py CHANGED

@@ -149,20 +149,22 @@ class XgenTokenizer(PreTrainedTokenizer):
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         if isinstance(token, str):
-
-
-
+            return self.encoder.encode_single_token(token)
+        else:
+            return token

     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        return self.encoder.decode_single_token_bytes(index)
+        return self.encoder.decode_single_token_bytes(index).decode("utf-8")

     def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
+        if skip_special_tokens:
+            token_ids = [t for t in token_ids if t not in self.all_special_ids]
         return self.encoder.decode(token_ids)

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
-        eos_token_id = [
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

         output = token_ids_0 + eos_token_id

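The substantive fix in this hunk: tiktoken's decode_single_token_bytes returns raw bytes, while Hugging Face expects _convert_id_to_token to return a str, hence the added .decode("utf-8") (which assumes each token's bytes are valid UTF-8 on their own). A minimal sketch of the round-trip, using tiktoken's stock "cl100k_base" encoding as a stand-in for the XGen tokenizer's actual self.encoder (an assumption for illustration only):

    # Sketch only: "cl100k_base" stands in for XGen's real encoder.
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")

    tid = enc.encode_single_token("the")        # str -> int, as in _convert_token_to_id
    raw = enc.decode_single_token_bytes(tid)    # int -> bytes (the old return value)
    tok = raw.decode("utf-8")                   # bytes -> str (the added .decode("utf-8"))
    assert tok == "the"

    # The _decode change, sketched: drop special ids before decoding.
    all_special_ids = {100257}                  # invented special id for illustration
    token_ids = [t for t in [tid, 100257] if t not in all_special_ids]
    assert enc.decode(token_ids) == "the"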
@@ -218,11 +220,15 @@ class XgenTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
         """
-        eos_token_id = [
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

         output = [0] * len(token_ids_0 + eos_token_id)

         if token_ids_1 is not None:
             output += [1] * len(token_ids_1 + eos_token_id)

-        return output
+        return output
+
+    # has no vocab file
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
+        return ()
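For intuition on this second hunk, a sketch of how build_inputs_with_special_tokens and the token-type-id logic behave once add_eos_token gates the EOS id; the ids below are invented for illustration (50256 stands in for whatever self.eos_token_id actually is):

    # Invented ids; mirrors only the logic shown in the diff, not the full methods.
    add_eos_token = True
    eos_token_id = [50256] if add_eos_token else []

    token_ids_0 = [10, 11, 12]
    token_ids_1 = [20, 21]

    inputs = token_ids_0 + eos_token_id                    # -> [10, 11, 12, 50256]
    type_ids = [0] * len(token_ids_0 + eos_token_id)       # first sequence -> type 0
    if token_ids_1 is not None:
        type_ids += [1] * len(token_ids_1 + eos_token_id)  # second sequence -> type 1
    # type_ids == [0, 0, 0, 0, 1, 1, 1]

Since the tokenizer is backed by a tiktoken encoding rather than a vocab file, save_vocabulary has nothing to write and returns an empty tuple, as the new "# has no vocab file" comment notes.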