import os
import string

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

client = InferenceClient(api_key=os.environ.get("HF_TOKEN"))
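
# Configuration comes from environment variables: HF_TOKEN authenticates the
# InferenceClient, MODEL_ID names the hosted chat model, and CHARACTER_DESC
# supplies the system prompt.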

# Model and ONNX setup
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
PUNCS = string.punctuation.replace("'", "")  # punctuation to strip, keeping apostrophes
MAX_HISTORY = 4  # score end-of-utterance on the last 4 messages only
MAX_HISTORY_TOKENS = 512
EOU_THRESHOLD = 0.5  # probability above which the utterance is treated as complete

# Initialize ONNX model
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
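
# Optional fetch (a sketch, not part of the original app): if the quantized
# export is not on disk, pull it from the model repo. The "onnx/" subfolder
# path is an assumption about the repo layout; adjust if the file lives elsewhere.
if not os.path.exists(ONNX_FILENAME):
    from huggingface_hub import hf_hub_download
    ONNX_FILENAME = hf_hub_download(repo_id=HG_MODEL, filename=f"onnx/{ONNX_FILENAME}")
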
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

# Numerically stable softmax
def softmax(logits):
    exp_logits = np.exp(logits - np.max(logits))  # subtract max to avoid overflow
    return exp_logits / np.sum(exp_logits)
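# e.g. softmax(np.array([0.0, 1.0])) -> approximately [0.269, 0.731]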

# Normalize text: strip punctuation (apostrophes kept), lowercase, collapse whitespace
def normalize_text(text):
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))

    return " ".join(strip_puncs(text).lower().split())

# Format chat context for the EOU model
def format_chat_ctx(chat_ctx):
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                # Copy instead of mutating, so the caller's message dicts are untouched
                new_chat_ctx.append({"role": msg["role"], "content": content})

    # Render with the chat template, without tokenizing
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )

    # Remove the final EOU token so the model can predict whether it comes next
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix] if ix != -1 else convo_text
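# With this checkpoint's ChatML-style template, the rendered context looks
# roughly like (illustrative; the exact layout depends on the tokenizer config):
#   <|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\nhi there<|im_end|>\n<|im_start|>user\nare you
# i.e. everything from the final <|im_end|> onward is dropped so the model can
# score whether the last utterance is complete.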

# Calculate EOU probability: P(next token == <|im_end|>) at the end of the context
def calculate_eou(chat_ctx, session):
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # last MAX_HISTORY messages
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]  # logits at the final sequence position
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]
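
# Usage sketch (assumed wiring; mirrors the disabled turn-detection block in
# respond() below):
#   if calculate_eou(messages, onnx_session) < EOU_THRESHOLD:
#       ...  # treat the turn as unfinished and wait for more user input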

# Respond function streamed to the Gradio chat
def respond(
    message,
    history: list[tuple[str, str]],
    # Defaults are needed here: the sliders below are commented out, so Gradio
    # calls respond(message, history) with no additional arguments.
    max_tokens=2048,
    temperature=0.6,
    top_p=0.9,
):
    # System prompt is supplied via the CHARACTER_DESC environment variable
    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC")}]

    # Keep the last 20 user/assistant exchanges as context
    for val in history[-20:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Add the new user message to the context
    messages.append({"role": "user", "content": message})

    # Turn detection is currently disabled; if re-enabled, a low EOU probability
    # means the user is likely still typing.
    # eou_prob = calculate_eou(messages, onnx_session)
    # print(f"EOU Probability: {eou_prob}")  # Debug output
    #
    # # If EOU is below the threshold, ask for more input
    # if eou_prob < EOU_THRESHOLD:
    #     yield "[Waiting for user to continue input...]"
    #     return
    # Stream the completion from the hosted model named by MODEL_ID
    stream = client.chat.completions.create(
        model=os.environ.get("MODEL_ID"),
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )

    # Yield the accumulated text so Gradio renders a growing response
    bot_response = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:  # the final chunk's delta can be None
            bot_response += delta
        yield bot_response

# Gradio interface
demo = gr.ChatInterface(
    respond,
    # additional_inputs are disabled so users cannot change the sampling
    # settings; if re-enabled, their values are passed to respond() in order.
    # additional_inputs=[
    #     # Commented out to disable user modification of the system message
    #     # gr.Textbox(value="You are an assistant.", label="System message"),
    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    #     gr.Slider(
    #         minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
    #     ),
    # ],
)

if __name__ == "__main__":
    demo.launch()