beyoru committed
Commit 47cf318 · verified · 1 Parent(s): 170291e

Update app.py

Files changed (1)
  1. app.py +52 -28
app.py CHANGED
@@ -7,29 +7,32 @@ from huggingface_hub import InferenceClient
 import os
 
 # Initialize Qwen client
-qwen_client = InferenceClient(api_key=os.environ.get("HF_TOKEN"))
+client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
 
 # Model and ONNX setup
 HG_MODEL = "livekit/turn-detector"
 ONNX_FILENAME = "model_quantized.onnx"
 PUNCS = string.punctuation.replace("'", "")
-MAX_HISTORY = 4
+MAX_HISTORY = 4  # Adjusted to use the last 4 messages
 MAX_HISTORY_TOKENS = 512
-EOU_THRESHOLD = 0.5
+EOU_THRESHOLD = 0.5  # Updated threshold to match original
 
 # Initialize ONNX model
 tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
 onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
 
+# Softmax function
 def softmax(logits):
     exp_logits = np.exp(logits - np.max(logits))
     return exp_logits / np.sum(exp_logits)
 
+# Normalize text
 def normalize_text(text):
     def strip_puncs(text):
         return text.translate(str.maketrans("", "", PUNCS))
     return " ".join(strip_puncs(text).lower().split())
 
+# Format chat context
 def format_chat_ctx(chat_ctx):
     new_chat_ctx = []
     for msg in chat_ctx:
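Note: the hunk above shows the model setup and the softmax helper, but the lines that actually run the ONNX session fall outside the diff. A minimal sketch of how these pieces plausibly fit together, assuming the quantized turn-detector graph takes an "input_ids" tensor and returns logits of shape (batch, seq_len, vocab); the input name and output layout are assumptions, not taken from the commit:

import numpy as np

def eou_probability(formatted_text):
    # Tokenize the formatted conversation (mirrors calculate_eou below)
    inputs = tokenizer(formatted_text, return_tensors="np")
    # Assumed input name and single logits output for the ONNX graph
    outputs = onnx_session.run(None, {"input_ids": inputs["input_ids"]})
    last_logits = outputs[0][0, -1, :]  # logits at the final token position
    probs = softmax(last_logits)        # the numerically stable softmax above
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]          # probability that the turn has ended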
@@ -38,14 +41,19 @@ def format_chat_ctx(chat_ctx):
         if content:
             msg["content"] = content
             new_chat_ctx.append(msg)
+
+    # Tokenize with chat template
     convo_text = tokenizer.apply_chat_template(
         new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
     )
+
+    # Remove EOU token from the current utterance
     ix = convo_text.rfind("<|im_end|>")
     return convo_text[:ix]
 
+# Calculate EOU probability
 def calculate_eou(chat_ctx, session):
-    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
+    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # Use the last 4 messages
     inputs = tokenizer(
         formatted_text,
         return_tensors="np",
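The rfind("<|im_end|>") truncation in format_chat_ctx leaves the final turn open-ended, so the detector scores an unfinished utterance rather than a closed one. A hypothetical gating check built on the functions in this diff (the messages are illustrative, not from the commit):

# Hypothetical usage, not part of the commit
chat_ctx = [
    {"role": "assistant", "content": "How can I help?"},
    {"role": "user", "content": "so what I was thinking is"},
]
prob = calculate_eou(chat_ctx, onnx_session)
if prob < EOU_THRESHOLD:
    print("Likely mid-utterance; keep waiting for input.")
else:
    print("Turn complete; safe to respond.")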
@@ -59,50 +67,66 @@ def calculate_eou(chat_ctx, session):
     eou_token_id = tokenizer.encode("<|im_end|>")[-1]
     return probs[eou_token_id]
 
+
+# Respond function
 def respond(
     message,
     history: list[tuple[str, str]],
-    max_tokens=2048,
-    temperature=0.6,
-    top_p=0.95,
+    max_tokens,
+    temperature,
+    top_p,
 ):
-    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC", "You are a helpful assistant.")}]
-
-    for val in history[-MAX_HISTORY:]:
+    # Keep the last 4 conversation pairs (user-assistant)
+    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC")}]
+
+    for val in history[-10:]:  # Only use the last 4 pairs
         if val[0]:
             messages.append({"role": "user", "content": val[0]})
         if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
-
+
+    # Add the new user message to the context
     messages.append({"role": "user", "content": message})
 
+    # Calculate EOU probability
     eou_prob = calculate_eou(messages, onnx_session)
+    print(f"EOU Probability: {eou_prob}")  # Debug output
+
+    # If EOU is below the threshold, ask for more input
     if eou_prob < EOU_THRESHOLD:
-        yield "[Wait... Keep typing...]"
+        yield "[Waiting for user to continue input...]"
         return
 
-
-    stream = qwen_client.chat.completions.create(
-        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-        messages=messages,
+    # Generate response with Qwen
+    response = ""
+    for message in client.chat_completion(
+        messages,
         max_tokens=max_tokens,
+        stream=True,
         temperature=temperature,
         top_p=top_p,
-    )
+    ):
+        token = message.choices[0].delta.content
+        response += token
+        yield response
+
+    print(f"Generated response: {response}")
+
 
-    for chunk in stream:
-        new_token = chunk.choices[0].delta.content
-        yield "".join(new_token)
-
-# Create Gradio interface
+# Gradio interface
 demo = gr.ChatInterface(
     respond,
-    additional_inputs=[
-        gr.Slider(1, 4096, value=512, label="Max Tokens"),
-        gr.Slider(0.1, 4.0, value=0.66, label="Temperature"),
-        gr.Slider(0.1, 1.0, value=0.95, label="Top-p"),
-    ]
+    # additional_inputs=[
+    #     # Commented out to disable user modification of the system message
+    #     # gr.Textbox(value="You are an assistant.", label="System message"),
+    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
+    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+    #     gr.Slider(
+    #         minimum=0.1,
+    #         maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
+    #     ),
+    # ],
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
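For reference, the streaming call the new respond function relies on can be exercised standalone. A minimal sketch using the documented huggingface_hub InferenceClient.chat_completion API; the prompt is illustrative, and the or "" guard covers chunks whose delta carries no content:

from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
messages = [{"role": "user", "content": "Say hello in one short sentence."}]

response = ""
for chunk in client.chat_completion(messages, max_tokens=64, stream=True):
    token = chunk.choices[0].delta.content or ""  # some chunks have empty deltas
    response += token
print(response)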