import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient

# Initialize Qwen client (served via the HF Inference API)
qwen_client = InferenceClient("EVA-UNIT-01/EVA-Qwen2.5-1.5B-v0.0")

# Turn-detector model and ONNX setup
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
PUNCS = string.punctuation.replace("'", "")  # punctuation to strip (apostrophes kept)
MAX_HISTORY = 4  # feed only the last 4 messages to the turn detector
MAX_HISTORY_TOKENS = 512
EOU_THRESHOLD = 0.5  # minimum end-of-utterance probability before replying

# Initialize the turn-detector tokenizer and ONNX runtime session (CPU)
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
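
# How the turn detector works: it is a small language model, and the probability
# it assigns to the end-of-turn token "<|im_end|>" at the final position is read
# as the probability that the user has finished their utterance.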

# Numerically stable softmax (shift logits by their max before exponentiating)
def softmax(logits):
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)

# Normalize text: strip punctuation (keeping apostrophes), lowercase, collapse whitespace
def normalize_text(text):
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))

    return " ".join(strip_puncs(text).lower().split())

# Format chat context into the turn detector's expected prompt
def format_chat_ctx(chat_ctx):
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                # Append a copy rather than mutating msg, so the original
                # messages (later sent to the LLM) keep their punctuation
                new_chat_ctx.append({"role": msg["role"], "content": content})

    # Render with the chat template (no generation prompt appended)
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )

    # Remove the trailing EOU token of the current utterance so the model
    # can predict whether it belongs there
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix] if ix != -1 else convo_text
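
# The rendered string uses ChatML-style markers, roughly
#   <|im_start|>user\n...<|im_end|>
# (the exact format depends on the tokenizer's chat template), so trimming the
# final "<|im_end|>" leaves the model to predict whether that token comes next.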

# Calculate EOU (end-of-utterance) probability for the current chat context
def calculate_eou(chat_ctx, session):
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # last MAX_HISTORY messages
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]  # logits for the final position only
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]
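
# Rough intuition (hypothetical inputs, not measured values):
#   calculate_eou([{"role": "user", "content": "what i meant was"}], onnx_session)
# should be low (the user sounds mid-sentence), while a complete question such as
#   calculate_eou([{"role": "user", "content": "what time is it?"}], onnx_session)
# should be high.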

# Read the system message (character card) from file
with open("character/herta.txt", "r", encoding="utf-8") as f:
    system_message = f.read()

# Respond function
def respond(
    message,
    history: list[tuple[str, str]],
    # Defaults are needed because the corresponding sliders in
    # additional_inputs are commented out below
    max_tokens=256,
    temperature=0.7,
    top_p=0.95,
):
    # Rebuild the model context: system message plus recent history
    messages = [{"role": "system", "content": system_message}]
    for val in history[-10:]:  # keep up to the last 10 user/assistant pairs
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Add the new user message to the context
    messages.append({"role": "user", "content": message})

    # Calculate EOU probability before answering
    eou_prob = calculate_eou(messages, onnx_session)
    print(f"EOU Probability: {eou_prob}")  # Debug output

    # If EOU is below the threshold, wait for the user to finish their thought
    if eou_prob < EOU_THRESHOLD:
        yield "[Waiting for user to continue input...]"
        return

    # Generate response with Qwen, streaming partial output as it arrives
    response = ""
    for chunk in qwen_client.chat_completion(  # "chunk" avoids shadowing `message`
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # delta.content can be None on some stream events
            response += token
            yield response
    print(f"Generated response: {response}")

# Gradio interface
demo = gr.ChatInterface(
    respond,
    # additional_inputs=[
    #     # Commented out to disable user modification of the system message
    #     # gr.Textbox(value="You are an assistant.", label="System message"),
    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    #     gr.Slider(
    #         minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
    #     ),
    # ],
    theme=gr.themes.Default().set(
        button_primary_background_fill="#FF0000",
        button_primary_background_fill_dark="#AAAAAA",
        button_primary_border="*button_primary_background_fill",
        button_primary_border_dark="*button_primary_background_fill_dark",
    ),
)

if __name__ == "__main__":
    demo.launch()
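
# To run locally (assumes model_quantized.onnx and character/herta.txt sit
# alongside this script, and that Hugging Face Inference API access is
# configured for the Qwen endpoint):
#   python app.py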