import json
from functools import partial

import gradio as gr
from transformers import AutoTokenizer

# The tokenizer only supplies special tokens (bos/eos) and the template
# renderer; the template itself is always supplied per request.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

demo_conversation = """[
    {"role": "system", "content": "You are a helpful chatbot."},
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Hello, human!"},
    {"role": "user", "content": "Can I ask a question?"}
]"""

# Preset templates, laid out one Jinja tag per line so the "Cleanup template
# whitespace" option (strip every line, then join) yields a compact template.
# NOTE: newlines that must appear in the *output* are written as the Jinja
# escape "\\n" (backslash + n in the template text). A real newline inside a
# Jinja string literal would be removed by the whitespace cleanup.
chat_templates = {
    "chatml": """{% for message in messages %}
    {{ "<|im_start|>" + message["role"] + "\\n" + message["content"] + "<|im_end|>\\n" }}
{% endfor %}
{% if add_generation_prompt %}
    {{ "<|im_start|>assistant\\n" }}
{% endif %}""",
    "zephyr": """{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ '<|user|>\\n' + message['content'] + eos_token }}
    {% elif message['role'] == 'system' %}
        {{ '<|system|>\\n' + message['content'] + eos_token }}
    {% elif message['role'] == 'assistant' %}
        {{ '<|assistant|>\\n' + message['content'] + eos_token }}
    {% endif %}
    {% if loop.last and add_generation_prompt %}
        {{ '<|assistant|>' }}
    {% endif %}
{% endfor %}""",
    "llama": """{% if messages[0]['role'] == 'system' %}
    {% set loop_messages = messages[1:] %}
    {% set system_message = messages[0]['content'] %}
{% else %}
    {% set loop_messages = messages %}
    {% set system_message = false %}
{% endif %}
{% for message in loop_messages %}
    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {% endif %}
    {% if loop.index0 == 0 and system_message != false %}
        {% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
    {% else %}
        {% set content = message['content'] %}
    {% endif %}
    {% if message['role'] == 'user' %}
        {{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
    {% elif message['role'] == 'assistant' %}
        {{ ' ' + content.strip() + ' ' + eos_token }}
    {% endif %}
{% endfor %}""",
    "alpaca": """{% for message in messages %}
    {% if message['role'] == 'system' %}
        {{ message['content'] + '\\n\\n' }}
    {% elif message['role'] == 'user' %}
        {{ '### Instruction:\\n' + message['content'] + '\\n\\n' }}
    {% elif message['role'] == 'assistant' %}
        {{ '### Response:\\n' + message['content'] + '\\n\\n' }}
    {% endif %}
    {% if loop.last and add_generation_prompt %}
        {{ '### Response:\\n' }}
    {% endif %}
{% endfor %}""",
    "vicuna": """{% for message in messages %}
    {% if message['role'] == 'system' %}
        {{ message['content'] + '\\n' }}
    {% elif message['role'] == 'user' %}
        {{ 'USER:\\n' + message['content'] + '\\n' }}
    {% elif message['role'] == 'assistant' %}
        {{ 'ASSISTANT:\\n' + message['content'] + '\\n' }}
    {% endif %}
    {% if loop.last and add_generation_prompt %}
        {{ 'ASSISTANT:\\n' }}
    {% endif %}
{% endfor %}""",
    "falcon": """{% for message in messages %}
    {% if not loop.first %}
        {{ '\\n' }}
    {% endif %}
    {% if message['role'] == 'system' %}
        {{ 'System: ' }}
    {% elif message['role'] == 'user' %}
        {{ 'User: ' }}
    {% elif message['role'] == 'assistant' %}
        {{ 'Falcon: ' }}
    {% endif %}
    {{ message['content'] }}
{% endfor %}
{% if add_generation_prompt %}
    {{ '\\n' + 'Falcon:' }}
{% endif %}""",
}

description_text = """# Chat Template Creator

### This space is a helper app for writing [Chat Templates](https://huggingface.co/docs/transformers/main/en/chat_templating).

### When you're happy with the outputs from your template, you can use the code block at the end to add it to a PR!"""


def _build_pr_snippet(template):
    """Return Python source that pushes *template* to a Hub repo as a PR."""
    return "\n".join(
        (
            "from transformers import AutoTokenizer",
            "",
            'CHECKPOINT = "big-ai-company/cool-new-model"',
            "tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)",
            # repr() always produces a valid Python string literal, whatever
            # quotes, backslashes or newlines the template contains.
            f"tokenizer.chat_template = {template!r}",
            "tokenizer.push_to_hub(CHECKPOINT, create_pr=True)",
        )
    )


def apply_chat_template(template, test_conversation, add_generation_prompt, cleanup_whitespace):
    """Render a test conversation with a Jinja chat template.

    Args:
        template: Jinja chat template source, as typed by the user.
        test_conversation: JSON text for a list of ``{"role", "content"}`` dicts.
        add_generation_prompt: forwarded to ``tokenizer.apply_chat_template``.
        cleanup_whitespace: if true, strip every template line and join the
            lines with no separator before rendering.

    Returns:
        ``(formatted_conversation, pr_code_snippet)``.

    Raises:
        gr.Error: if ``test_conversation`` is not valid JSON.
    """
    if cleanup_whitespace:
        template = "".join(line.strip() for line in template.split("\n"))

    try:
        conversation = json.loads(test_conversation)
    except json.JSONDecodeError as err:
        raise gr.Error(f"Conversation is not valid JSON: {err}") from err

    # Pass the template per call instead of mutating the shared global
    # tokenizer, so concurrent sessions cannot see each other's templates.
    formatted = tokenizer.apply_chat_template(
        conversation,
        chat_template=template,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )
    return formatted, _build_pr_snippet(template)


def load_template(template_name):
    """Return the preset template *template_name* (a key of ``chat_templates``).

    The return value is routed into the template textbox by the button's
    ``outputs``; assigning ``component.value`` in a handler would not update
    the UI.
    """
    return chat_templates[template_name]


with gr.Blocks() as demo:
    gr.Markdown(description_text)
    with gr.Row():
        gr.Markdown("### Pick an existing template to start:")
    with gr.Row():
        load_chatml = gr.Button("ChatML")
        load_zephyr = gr.Button("Zephyr")
        load_llama = gr.Button("LLaMA")
    with gr.Row():
        load_alpaca = gr.Button("Alpaca")
        load_vicuna = gr.Button("Vicuna")
        load_falcon = gr.Button("Falcon")
    with gr.Row():
        with gr.Column():
            template_in = gr.TextArea(value=chat_templates["chatml"], lines=10, max_lines=30, label="Chat Template")
            conversation_in = gr.TextArea(value=demo_conversation, lines=6, label="Conversation")
            generation_prompt_check = gr.Checkbox(value=False, label="Add generation prompt")
            cleanup_whitespace_check = gr.Checkbox(value=True, label="Cleanup template whitespace")
            submit = gr.Button("Apply template", variant="primary")
        with gr.Column():
            formatted_out = gr.TextArea(label="Formatted conversation")
            code_snippet_out = gr.TextArea(label="Code snippet to create PR", lines=3, show_label=True, show_copy_button=True)

    submit.click(
        fn=apply_chat_template,
        inputs=[template_in, conversation_in, generation_prompt_check, cleanup_whitespace_check],
        outputs=[formatted_out, code_snippet_out],
    )

    preset_buttons = {
        "chatml": load_chatml,
        "zephyr": load_zephyr,
        "llama": load_llama,
        "alpaca": load_alpaca,
        "vicuna": load_vicuna,
        "falcon": load_falcon,
    }
    for preset_name, preset_button in preset_buttons.items():
        # partial() binds preset_name now, avoiding the late-binding closure trap.
        preset_button.click(fn=partial(load_template, preset_name), outputs=template_in)


if __name__ == "__main__":
    demo.launch()