Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,332 Bytes
8824f88 e8f48ed 8824f88 2213339 8824f88 bdbcb4a 245a9f1 8824f88 7680f74 8824f88 a5c0568 a174343 8824f88 31bf44d 0737a9d 34353a1 0737a9d 8824f88 c5ac75a bdbcb4a c5ac75a 615fd05 9604647 8824f88 ecbb198 fe36abc 8824f88 ae2bb64 8824f88 bdbcb4a ae2bb64 fe36abc ae2bb64 8824f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
#!/usr/bin/env python
import os
from threading import Event, Thread
from typing import Iterator

import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# Prompts longer than this many tokens are trimmed before generation;
# the limit can be overridden through the MAX_INPUT_TOKEN_LENGTH env var.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "8192"))
# Hugging Face Hub checkpoint served by this demo. The tokenizer and the
# weights are loaded once, at import time, and shared by every request.
model_id = "utter-project/EuroLLM-1.7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" leaves device placement of the weights to transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once *event* is set.

    Lets the request side cancel a background generation thread, e.g. when
    the consumer stops iterating the ``generate`` stream early.
    """

    def __init__(self, event: Event) -> None:
        super().__init__()
        self.event = event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # One "done" flag per sequence in the batch, on the ids' device.
        return torch.full(
            (input_ids.shape[0],),
            self.event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to *message*, given the earlier *chat_history*.

    Builds a chat-template prompt from the history plus the new user message,
    runs ``model.generate`` on a background thread, and yields the reply
    accumulated so far after every decoded chunk.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Generation budget, capped at ``MAX_MAX_NEW_TOKENS``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The full reply text generated so far (not just the newest chunk).

    Raises:
        Any ``Exception`` raised by ``model.generate`` in the worker thread,
        re-raised here once streaming has stopped; the streamer's timeout
        error if no new text arrives within 10 seconds.
    """
    conversation = []
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens. NOTE(review): this also cuts off
        # the start of the chat template, so the prompt begins mid-conversation.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # The UI slider stops at MAX_MAX_NEW_TOKENS, but direct callers of this
    # function are not bound by the slider, so enforce the cap here too.
    max_new_tokens = min(int(max_new_tokens), MAX_MAX_NEW_TOKENS)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    stop_event = Event()
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )

    worker_errors: list[Exception] = []

    def _run_generation() -> None:
        """Run ``model.generate`` and hand any failure back to the caller."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # A failed ``generate`` may never have ended the stream; end it
            # here so the loop below stops now instead of waiting out the
            # streamer timeout.
            streamer.end()

    worker = Thread(target=_run_generation, daemon=True)
    worker.start()

    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
    finally:
        # Runs on normal completion, on error, and when the caller closes this
        # iterator early: signal the worker to stop and wait for it to exit.
        stop_event.set()
        worker.join()
    if worker_errors:
        raise worker_errors[0]
# Chat UI around ``generate``. The ``additional_inputs`` are passed to
# ``generate`` after (message, history) in list order, so the sliders must stay
# in signature order: max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
# NOTE(review): the slider defaults (temperature 0.2, top-p 0.9, top-k 50)
# differ from the keyword defaults of ``generate`` (0.06, 0.95, 40); in the UI
# the slider values are what reach the model.
# The model name shown in the UI is derived from ``model_id`` so it cannot
# drift from the checkpoint actually loaded.
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(
        height=450,
        label=model_id,
        show_share_button=True,
    ),
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["Describe the significance of the Eiffel Tower in French culture and history."],
        ["Что такое 'загадочная русская душа' и как это понятие отражается в русской литературе?"],  # Russian: What is the "mysterious Russian soul" and how is this concept reflected in Russian literature?
        ["Jakie są najbardziej znane polskie tradycje bożonarodzeniowe?"],  # Polish: What are the most well-known Polish Christmas traditions?
        ["Welche Rolle spielte die Hanse im mittelalterlichen Europa?"],  # German: What role did the Hanseatic League play in medieval Europe?
        ["日本の茶道の精神と作法について説明してください。"],  # Japanese: Please explain the spirit and etiquette of Japanese tea ceremony.
    ],
    title=model_id,
    description=f"{model_id} quick demo",
    submit_btn="Generate",
    stop_btn="Stop",
    retry_btn="🔄 Retry",
    undo_btn="↩️ Undo",
    clear_btn="🗑️ Clear",
)
# Host the chat interface inside a Blocks app that loads the custom stylesheet.
demo = gr.Blocks(css="style.css")
with demo:
    chat_interface.render()

if __name__ == "__main__":
    # Cap the pending-request queue at 20 events, then start the server.
    demo.queue(max_size=20)
    demo.launch()
|