beyoru committed
Commit c7c557b · verified · 1 Parent(s): beef6e2

Update app.py

Files changed (1)
  1. app.py +79 -104
app.py CHANGED
@@ -2,115 +2,80 @@ import gradio as gr
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import string
 
-# Constants
-PUNCS = string.punctuation.replace("'", "")
-MAX_HISTORY = 4
-MAX_HISTORY_TOKENS = 512
-
-class EOUDetector:
-    def __init__(self, model_name="livekit/turn-detector"):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name)
-        self.eou_token_id = self.tokenizer.encode("<|im_end|>")[-1]
-
-    def _normalize_text(self, text: str) -> str:
-        """Normalize text by removing punctuation and extra spaces."""
-        text = text.translate(str.maketrans("", "", PUNCS))
-        return " ".join(text.lower().split())
-
-    def _format_chat_context(self, messages: list[dict]) -> str:
-        """Format chat context using the model's chat template."""
-        normalized_messages = []
-
-        for msg in messages[-MAX_HISTORY:]:  # Only keep last MAX_HISTORY messages
-            if msg["role"] not in ("user", "assistant"):
-                continue
-
-            content = self._normalize_text(msg["content"])
-            if content:
-                normalized_messages.append({
-                    "role": msg["role"],
-                    "content": content
-                })
-
-        # Apply chat template without generation prompt
-        conversation = self.tokenizer.apply_chat_template(
-            normalized_messages,
-            add_generation_prompt=False,
-            add_special_tokens=False,
-            tokenize=False
-        )
-
-        # Remove the EOU token from current utterance if present
-        ix = conversation.rfind("<|im_end|>")
-        if ix >= 0:
-            conversation = conversation[:ix]
-
-        return conversation
-
-    def compute_eou_probability(self, messages: list[dict]) -> float:
-        """Compute the probability of end of utterance."""
-        # Format the conversation
-        conversation = self._format_chat_context(messages)
-
-        # Tokenize with proper truncation
-        inputs = self.tokenizer(
-            conversation,
-            add_special_tokens=False,
-            return_tensors="pt",
-            max_length=MAX_HISTORY_TOKENS,
-            truncation=True,
-            truncation_side="left"
-        )
-
-        # Get model predictions
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-
-        # Get logits for the last token
-        logits = outputs.logits[0, -1, :]
-
-        # Compute softmax properly
-        probabilities = torch.nn.functional.softmax(logits, dim=-1)
-
-        # Get probability for EOU token
-        eou_probability = probabilities[self.eou_token_id].item()
-
-        return eou_probability
 
 def respond(
-    message: str,
     history: list[tuple[str, str]],
-    system_message: str,
-    max_tokens: int,
-    temperature: float,
-    top_p: float,
-    eou_threshold: float = 0.2,
-) -> str:
-    # Initialize clients
-    eou_detector = EOUDetector()
-    client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
-
-    # Prepare messages
     messages = [{"role": "system", "content": system_message}]
-    for user_msg, assistant_msg in history:
-        if user_msg:
-            messages.append({"role": "user", "content": user_msg})
-        if assistant_msg:
-            messages.append({"role": "assistant", "content": assistant_msg})
-
-    # Add current message
-    messages.append({"role": "user", "content": message})
 
-    # Check EOU probability
-    eou_probability = eou_detector.compute_eou_probability(messages)
-    print(f"EOU Probability: {eou_probability}")
 
     if eou_probability >= eou_threshold:
-        # Generate response
         response = ""
         for message in client.chat_completion(
             messages,
             max_tokens=max_tokens,
@@ -122,19 +87,29 @@
             response += token
             yield response
     else:
         yield "Waiting for user to finish... Please continue."
 
-# Gradio Interface
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a helpful assistant", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
-        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="EOU Threshold"),
     ],
 )
 
 if __name__ == "__main__":
-    demo.launch()
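For reference, the removed class-based detector is self-contained and can be driven directly. A minimal sketch (pre-change app.py assumed; the example messages are invented):

    detector = EOUDetector()
    chat = [
        {"role": "assistant", "content": "How can I help you today?"},
        {"role": "user", "content": "What's the weather like in"},  # trails off mid-utterance
    ]
    print(detector.compute_eou_probability(chat))  # expect a low EOU probability here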
 
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+
+# Load the Inference Client for the response model
+client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
+
+# Load the tokenizer and model for EOU detection
+tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
+model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
+
+import re
 import string
 
+def normalize_text(text: str) -> str:
+    """Normalize text by removing punctuation and extra whitespace and lowercasing."""
+    text = text.strip().lower()  # Lowercase and trim leading/trailing whitespace
+    text = re.sub(f"[{re.escape(string.punctuation)}]", "", text)  # Remove punctuation
+    return re.sub(r"\s+", " ", text)  # Collapse extra whitespace
+
+def compute_eou_probability(chat_ctx: list[dict[str, str]], max_tokens: int = 512) -> float:
+    """Compute the probability of End of Utterance (EOU) after normalizing text."""
+    conversation = ["Assistant ready to help."]  # Add the system message directly
+
+    for msg in chat_ctx:
+        content = msg.get("content", "")
+        if content:
+            normalized_content = normalize_text(content)  # Normalize the text
+            conversation.append(normalized_content)
+
+    # Tokenize the conversation as a single sequence
+    inputs = tokenizer(
+        "\n".join(conversation), truncation=True, max_length=max_tokens, return_tensors="pt"
+    )
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    logits = outputs.logits[0, -1, :]
+    probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+    # Get EOU token probability
+    eou_token_id = tokenizer.encode("<|im_end|>")[0]
+    if eou_token_id not in tokenizer.get_vocab().values():
+        raise ValueError("EOU token '<|im_end|>' not found in tokenizer vocabulary.")
+
+    return probabilities[eou_token_id].item()
+
+
+# Respond function with EOU checking logic
 def respond(
+    message,
     history: list[tuple[str, str]],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+    eou_threshold: float = 0.2,  # Default EOU threshold
+):
     messages = [{"role": "system", "content": system_message}]
 
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
 
+    # Add the current message so the detector scores the full context,
+    # then compute the EOU probability before responding
+    messages.append({"role": "user", "content": message})
+    eou_probability = compute_eou_probability(messages, max_tokens=max_tokens)
+    print(f"EOU probability: {eou_probability}")
+    # Only respond if EOU probability exceeds threshold
     if eou_probability >= eou_threshold:
         response = ""
+
         for message in client.chat_completion(
             messages,
             max_tokens=max_tokens,
 
             response += token
             yield response
     else:
+        # Let the user continue typing if the EOU probability is low
         yield "Waiting for user to finish... Please continue."
+        print("Waiting for user to finish... Please continue.")
 
+# Gradio UI
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
+        gr.Textbox(value="You are a helpful assistant", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-p (nucleus sampling)",
+        ),
+        gr.Slider(
+            minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="EOU Threshold"
+        ),  # Add EOU threshold slider
     ],
 )
 
 if __name__ == "__main__":
+    demo.launch()
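The replacement module-level functions can be smoke-tested without launching the UI. A minimal sketch, assuming this file is saved as app.py on the import path and the livekit/turn-detector weights download on first use (importing app also constructs the InferenceClient; the example message is invented):

    from app import compute_eou_probability

    chat = [{"role": "user", "content": "Can you explain how turn detection"}]  # unfinished utterance
    prob = compute_eou_probability(chat)
    print(f"EOU probability: {prob:.3f}")
    if prob >= 0.2:  # same default threshold as respond()
        print("Utterance looks complete; the assistant would reply.")
    else:
        print("Utterance looks unfinished; the app keeps waiting.")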