LHC88 commited on
Commit
3f2dddb
·
1 Parent(s): 5b38d1b

tokenization_function

Browse files
Files changed (1) hide show
  1. tokenization_functionary.py +537 -0
tokenization_functionary.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, MeetKai Inc. All rights reserved.
2
+
3
+ from copy import deepcopy
4
+ import json
5
+ from typing import Any, Dict, List, Literal, Optional, Union
6
+
7
+ import jsonref
8
+ from pydantic import BaseModel, Field, model_validator
9
+ from typing_extensions import Self
10
+
11
+ from transformers.tokenization_utils_base import BatchEncoding
12
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
13
+ from transformers.utils import TensorType, logging
14
+
15
+
16
# Module-level logger, following the transformers convention.
logger = logging.get_logger(__name__)

# Default system prompts for the Functionary chat format.
# NOTE(review): `apply_chat_template` below builds its own local system
# strings, so these module constants appear unused within this file —
# presumably consumed by code outside this view; confirm.
SYSTEM_PROMPT = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
CODE_INTERPRETER_SYSTEM_PROMPT = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""
19
+
20
+
21
class Function(BaseModel):
    """Schema of a single callable function exposed to the model."""

    # Function identifier used as the TypeScript type name in the schema.
    name: str
    # Human-readable summary; plain default is equivalent to Field(default="").
    description: Optional[str] = ""
    # JSON-schema dict describing the arguments, if any.
    parameters: Optional[dict] = None
25
+
26
+
27
class Tool(BaseModel):
    """A tool entry: either a function definition or the code interpreter."""

    type: Literal["function", "code_interpreter"]
    function: Optional[Function] = None

    @model_validator(mode="after")
    def check_type_function_matches(self) -> Self:
        """Ensure `function` is present iff `type == "function"`.

        Asserts (not raises) so pydantic wraps failures the same way as the
        original implementation.
        """
        if self.type == "function":
            assert self.function is not None, (
                '"function" must contain function description when `"type": "function"`'
            )
        else:
            assert self.function is None, (
                '"function" must not be provided when `"type": "code_interpreter"`'
            )
        return self
42
+
43
+
44
def convert_data_type(param_type: str) -> str:
    """Map a JSON-schema primitive type name to its TypeScript equivalent.

    Args:
        param_type (str): JSON-schema type name.

    Returns:
        str: the TypeScript type name ("number" for numeric types,
        otherwise unchanged).
    """
    return "number" if param_type in ("integer", "float") else param_type
54
+
55
+
56
def get_param_type(param: Dict) -> str:
    """Resolve the TypeScript type string for a parameter schema.

    Args:
        param (Dict): parameter dict from a JSON-schema "properties" map.

    Returns:
        str: TypeScript type, e.g. "number", "string | number", or "any"
        when no type information is present.
    """
    param_type = "any"
    if "type" in param:
        raw_param_type = param["type"]
        if isinstance(raw_param_type, list):
            # Bug fix: convert each member individually (e.g. "integer" ->
            # "number"), matching the oneOf branch below. Previously the raw
            # names were joined first, and the final convert_data_type call
            # cannot translate a combined string like "integer | string".
            param_type = " | ".join(convert_data_type(t) for t in raw_param_type)
        else:
            param_type = raw_param_type
    elif "oneOf" in param:
        # In many cases the schema contains "oneOf" instead of "type".
        one_of_types = []
        for item in param["oneOf"]:
            if "type" in item:
                one_of_types.append(convert_data_type(item["type"]))
        # De-duplicate; NOTE: set order is unspecified, preserved from original.
        one_of_types = list(set(one_of_types))
        param_type = " | ".join(one_of_types)
    return convert_data_type(param_type)
80
+
81
+
82
def get_format_param(param: Dict) -> Optional[str]:
    """Extract the "format" hint from a parameter schema.

    Checks the top level first; if absent, gathers formats from the
    "oneOf" alternatives and joins them with " or ".

    Args:
        param (Dict): parameter schema.

    Returns:
        Optional[str]: the format string, or None if none is declared.
    """
    if "format" in param:
        return param["format"]
    if "oneOf" in param:
        found = [item["format"] for item in param["oneOf"] if "format" in item]
        if found:
            return " or ".join(found)
    return None
99
+
100
+
101
def get_param_info(param: Dict) -> Optional[str]:
    """Build a single "// ..." comment describing a parameter.

    Collects, in order: description, default value, format, and numeric /
    length bounds.

    Args:
        param (Dict): parameter schema.

    Returns:
        Optional[str]: one comment line, or None when there is nothing to say.
    """
    param_type = param.get("type", "any")
    pieces = []

    if "description" in param:
        desc = param["description"]
        if not desc.endswith("."):
            desc += "."
        pieces.append(desc)

    if "default" in param:
        default_value = param["default"]
        if param_type == "string":
            # Quote string defaults so they read as literals.
            default_value = f'"{default_value}"'
        pieces.append(f"Default={default_value}.")

    format_param = get_format_param(param)
    if format_param is not None:
        pieces.append("Format=" + format_param)

    for key, label in (
        ("maximum", "Maximum"),
        ("minimum", "Minimum"),
        ("maxLength", "Maximum length"),
        ("minLength", "Minimum length"),
    ):
        if key in param:
            pieces.append(f"{label}=" + str(param[key]))

    if not pieces:
        return None
    # Flatten multi-line descriptions into a single comment line.
    return ("// " + " ".join(pieces)).replace("\n", " ")
140
+
141
+
142
def append_new_param_info(
    info_list: List[str],
    param_declaration: str,
    comment_info: Optional[str],
    examples_info: List,
    depth: int,
):
    """Append one parameter declaration to `info_list`, indented by depth.

    When a comment is present it is emitted first, followed by any example
    lines, then the declaration itself; without a comment only the
    declaration is appended (examples are skipped, matching the original).

    Args:
        info_list (List[str]): accumulator of schema lines (mutated in place).
        param_declaration (str): "name: type" declaration text.
        comment_info (Optional[str]): "// ..." comment, or None.
        examples_info (List): "// ..." example lines.
        depth (int): nesting level; one space of indent per level.
    """
    indent = " " * depth  # empty string at depth 0
    if comment_info is not None:
        info_list.append(f"{indent}{comment_info}")
        for example_line in examples_info:
            info_list.append(f"{indent}{example_line}")
    info_list.append(f"{indent}{param_declaration}")
171
+
172
+
173
def get_examples_info(param_name: str, examples: List) -> List:
    """Render the provided examples as "//"-prefixed comment lines.

    Args:
        param_name (str): parameter name used in the header line.
        examples (List): example values (dicts/lists are JSON-encoded).

    Returns:
        List: header line followed by one comment line per example.
    """
    lines = [f"// Example {param_name}:"]
    for example in examples:
        if isinstance(example, (dict, list)):
            rendered = json.dumps(example, ensure_ascii=False)
        else:
            rendered = str(example)
        # Keep each example on a single comment line.
        lines.append("// " + rendered.replace("\n", "\\n"))
    return lines
190
+
191
+
192
def get_enum_option_str(enum_options: List) -> str:
    """Join enum options with " | ", double-quoting string options.

    Args:
        enum_options (List): enum values from the schema.

    Returns:
        str: e.g. '"a" | "b" | 3'.
    """
    rendered = []
    for option in enum_options:
        # `type(...) is str` (not isinstance) preserved from the original.
        rendered.append(f'"{option}"' if type(option) is str else str(option))
    return " | ".join(rendered)
201
+
202
+
203
def get_array_typescript(
    param_name: Optional[str], param_dic: dict, depth: int = 0
) -> str:
    """Recursively render the TypeScript type of an array schema.

    Args:
        param_name (Optional[str]): declaration name, or None for an
            anonymous (nested) item type.
        param_dic (dict): array schema (uses its "items" entry).
        depth (int, optional): nesting level for indentation. Defaults to 0.

    Returns:
        str: TypeScript declaration text for the array.
    """
    indent = " " * depth  # empty at depth 0
    items_info = param_dic.get("items", {})

    # No item schema -> untyped "[]".
    if len(items_info) == 0:
        return "[]" if param_name is None else f"{indent}{param_name}: []"

    array_type = get_param_type(items_info)

    if array_type == "object":
        # Object items: render the nested properties as a "{ ... }[]" block.
        body_lines = get_parameter_typescript(
            items_info.get("properties", {}),
            items_info.get("required", []),
            depth + 1,
        )
        out = []
        if param_name is not None:
            out.append(f"{indent}{param_name}" + ": {")
        else:
            out.append(f"{indent}" + "{")
        out.extend(body_lines)
        out.append(f"{indent}" + "}[]")
        return "\n".join(out)

    if array_type == "array":
        # Array of arrays: recurse anonymously and append "[]".
        inner = get_array_typescript(None, items_info, depth + 1)
        if param_name is None:
            return f"{inner}[]"
        return f"{indent}{param_name}: {inner.strip()}[]"

    if "enum" in items_info:
        enum_type = get_enum_option_str(items_info["enum"])
        if param_name is None:
            return f"({enum_type})[]"
        return f"{indent}{param_name}: ({enum_type})[]"

    if param_name is None:
        return f"{array_type}[]"
    # NOTE(review): trailing comma here differs from the enum branch above;
    # preserved as in the original.
    return f"{indent}{param_name}: {array_type}[],"
258
+
259
+
260
def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
    """Recursively render parameter schemas as TypeScript declaration lines.

    Each parameter contributes its comment (description/default/format),
    example lines, and a "name[?]: type," declaration; objects and arrays
    recurse into their children.

    Args:
        properties: "properties" map of a JSON schema.
        required_params: list of required parameter names (optional params
            get a "?" suffix).
        depth (int, optional): nesting level for indentation. Defaults to 0.

    Returns:
        List[str]: schema lines for all parameters.
    """
    tp_lines: List[str] = []
    indent = " " * depth  # constant per call; hoisted out of the loop

    for param_name, param in properties.items():
        # Some schemas misplace a "required" list inside properties;
        # skip anything that is not a dict.
        if not isinstance(param, dict):
            continue

        comment_info = get_param_info(param)
        examples_info = (
            get_examples_info(param_name, param["examples"])
            if "examples" in param
            else []
        )

        declaration = f"{param_name}"
        if isinstance(required_params, list) and param_name not in required_params:
            declaration += "?"  # optional parameter

        param_type = get_param_type(param)

        if param_type == "object":
            # Nested object: recurse into its properties.
            child_lines = get_parameter_typescript(
                param.get("properties", {}), param.get("required", []), depth + 1
            )
            if comment_info is not None:
                tp_lines.append(f"{indent}{comment_info}")
            for example_line in examples_info:
                tp_lines.append(f"{indent}{example_line}")
            tp_lines.append(f"{indent}{declaration}" + ": {")
            tp_lines.extend(child_lines)
            tp_lines.append(f"{indent}" + "},")
        elif param_type == "array":
            item_info = param.get("items", {})
            if "type" not in item_info:
                # Unknown item type -> untyped "[]".
                append_new_param_info(
                    tp_lines, declaration + ": [],", comment_info, examples_info, depth
                )
            else:
                array_declaration = get_array_typescript(declaration, param, depth)
                if not array_declaration.endswith(","):
                    array_declaration += ","
                if comment_info is not None:
                    tp_lines.append(f"{indent}{comment_info}")
                for example_line in examples_info:
                    tp_lines.append(f"{indent}{example_line}")
                tp_lines.append(array_declaration)
        else:
            if "enum" in param:
                param_type = get_enum_option_str(param["enum"])
            if param.get("nullable") is True:
                param_type += " | null"
            append_new_param_info(
                tp_lines,
                declaration + f": {param_type},",
                comment_info,
                examples_info,
                depth,
            )

    return tp_lines
339
+
340
+
341
def generate_schema_from_functions(
    functions: List[Function], namespace="functions"
) -> str:
    """Render the function list as a TypeScript-style namespace declaration.

    This is the textual tool schema the language model reads in its system
    prompt.

    Args:
        functions (List[Function]): Function models or plain function dicts.
        namespace: namespace name to wrap the declarations in.

    Returns:
        str: the full "namespace { ... }" schema text.
    """
    parts = [
        "// Supported function definitions that should be called when necessary.\n",
        f"namespace {namespace} {{\n\n",
    ]

    for function in functions:
        # Accept both Function models and plain dicts.
        if not isinstance(function, dict):
            function = function.model_dump()

        function_name = function.get("name", None)
        if function_name is None:
            # Nameless entries cannot be declared; skip them.
            continue

        description = function.get("description", "")
        parts.append(f"// {description}\n")
        parts.append(f"type {function_name}")

        parameters = function.get("parameters", None)
        if parameters is not None and parameters.get("properties") is not None:
            # Inline any JSON-schema $ref entries before rendering.
            parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
            tp_lines = get_parameter_typescript(
                parameters.get("properties"),
                parameters.get("required", []),
                0,
            )
            parts.append(" = (_: {\n" + "\n".join(tp_lines) + "\n}) => any;\n\n")
        else:
            # No parameters at all.
            parts.append(" = () => any;\n\n")

    parts.append(f"}} // namespace {namespace}")
    return "".join(parts)
382
+
383
+
384
class FunctionaryTokenizer(PreTrainedTokenizerFast):
    """Tokenizer that injects Functionary's function-calling system prompt
    (a TypeScript tool schema) before rendering the chat template."""

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
        tools: Optional[List[Dict[str, Any]]],
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Render (and optionally tokenize) a conversation with tool schema.

        NOTE(review): unlike the base class, `tools` is a required positional
        argument here (no default) — confirm all callers pass it.
        """
        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
            if chat_template is not None and chat_template in template_dict:
                # The user can pass the name of a template to the chat template argument instead of an entire template
                chat_template = template_dict[chat_template]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
        elif chat_template is None:
            # These are the cases when the model has a single template
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
            if self.chat_template is not None:
                chat_template = self.chat_template
            else:
                chat_template = self.default_chat_template
                using_default_template = True

        if using_default_template:
            logger.warning_once(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        # Local prompt strings; NOTE(review): these shadow the similar
        # module-level constants rather than reusing them — confirm intended.
        PYTHON_RUN_SYS_MSG = "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."
        SYSTEM_CONTENT = """You are capable of executing available function(s) if required.
Only execute function(s) when absolutely necessary.
Ask for the required input to:recipient==all
Use JSON for function arguments.
Respond in this format:
>>>${recipient}
${content}
Available functions:
"""

        # Prepare tools/functions into schema
        functions_pydantic_to_render = []
        has_code_interpreter = False
        if tools is not None:
            for item in tools:
                if (
                    "function" in item and item["function"] is not None
                ):  # new data format: tools: [{"type": xx, "function": xxx}]
                    functions_pydantic_to_render.append(item["function"])
                elif "type" in item and item["type"] == "code_interpreter":
                    has_code_interpreter = True
                else:
                    functions_pydantic_to_render.append(item)  # old format

        # NOTE(review): this mutates the caller's `conversation` list in
        # place, and for batched input (a list of conversations) it prepends
        # the system message to the batch list rather than to each chat —
        # confirm intended.
        conversation.insert(
            0,
            {
                "role": "system",
                "content": SYSTEM_CONTENT
                + generate_schema_from_functions(functions_pydantic_to_render),
            },
        )
        if has_code_interpreter:
            conversation.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG})

        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = self._compile_jinja_template(chat_template)

        # Detect batched input: list of message-lists or of Conversation objects.
        if isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple))
            or hasattr(conversation[0], "messages")
        ):
            conversations = conversation
            is_batched = True
        else:
            conversations = [conversation]
            is_batched = False

        rendered = []
        template_kwargs = {
            **self.special_tokens_map,
            **kwargs,
        }  # kwargs overwrite special tokens if both are present
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Indicates it's a Conversation object
                chat = chat.messages
            rendered_chat = compiled_template.render(
                messages=chat,
                add_generation_prompt=add_generation_prompt,
                **template_kwargs,
            )
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if tokenize:
            out = self(
                rendered,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=False,
                return_tensors=return_tensors,
                **tokenizer_kwargs,
            )
            if return_dict:
                return out
            else:
                return out["input_ids"]
        else:
            return rendered