jtmuller commited on
Commit
00407e6
·
1 Parent(s): b1aea93

Update Space

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -19,16 +19,20 @@ PUNCS = string.punctuation.replace("'", "")
19
  # ------------------------------------------------
20
  # Utility functions
21
  # ------------------------------------------------
 
 
22
  def softmax(logits: np.ndarray) -> np.ndarray:
23
  exp_logits = np.exp(logits - np.max(logits))
24
  return exp_logits / np.sum(exp_logits)
25
 
 
26
  def normalize_text(text: str) -> str:
27
  """Lowercase, strip punctuation (except single quotes), and collapse whitespace."""
28
  def strip_puncs(text_in):
29
  return text_in.translate(str.maketrans("", "", PUNCS))
30
  return " ".join(strip_puncs(text).lower().split())
31
 
 
32
  def calculate_eou(chat_ctx, session, tokenizer) -> float:
33
  """
34
  Given a conversation context (list of dicts with 'role' and 'content'),
@@ -62,11 +66,13 @@ def calculate_eou(chat_ctx, session, tokenizer) -> float:
62
  eou_token_id = tokenizer.encode("<|im_end|>")[-1]
63
  return probs[eou_token_id]
64
 
 
65
  # ------------------------------------------------
66
  # Load ONNX session & tokenizer once
67
  # ------------------------------------------------
68
  print("Loading ONNX model session...")
69
- onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
 
70
 
71
  print("Loading tokenizer...")
72
  turn_detector_tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
@@ -80,6 +86,8 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
80
  # ------------------------------------------------
81
  # Gradio Chat Handler
82
  # ------------------------------------------------
 
 
83
  def respond(message, history, system_message, max_tokens, temperature, top_p):
84
  """
85
  This function is called on each new user message in the ChatInterface.
@@ -93,19 +101,22 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
93
  # [{'role': 'system', 'content': ...},
94
  # {'role': 'user', 'content': ...}, ...]
95
 
96
- messages = []
 
 
 
97
  if system_message.strip():
98
- messages.append({"role": "system", "content": system_message})
99
 
100
  # history is a list of tuples: [(user1, assistant1), (user2, assistant2), ...]
101
- for user_text, assistant_text in history:
102
  if user_text:
103
  messages.append({"role": "user", "content": user_text})
104
  if assistant_text:
105
  messages.append({"role": "assistant", "content": assistant_text})
106
 
107
  # Append the new user message
108
- messages.append({"role": "user", "content": message})
109
 
110
  # 2) Calculate EOU probability on the entire conversation
111
  eou_prob = calculate_eou(messages, onnx_session, turn_detector_tokenizer)
@@ -113,9 +124,10 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
113
  # 3) Generate the assistant response from your HF model.
114
  # (This code streams token-by-token.)
115
  response = ""
116
-
117
  yield f"[EOU Probability: {eou_prob:.4f}]"
118
 
 
119
  # ------------------------------------------------
120
  # Gradio ChatInterface
121
  # ------------------------------------------------
@@ -158,4 +170,4 @@ demo = gr.ChatInterface(
158
  )
159
 
160
  if __name__ == "__main__":
161
- demo.launch()
 
19
  # ------------------------------------------------
20
  # Utility functions
21
  # ------------------------------------------------
22
+
23
+
24
  def softmax(logits: np.ndarray) -> np.ndarray:
25
  exp_logits = np.exp(logits - np.max(logits))
26
  return exp_logits / np.sum(exp_logits)
27
 
28
+
29
  def normalize_text(text: str) -> str:
30
  """Lowercase, strip punctuation (except single quotes), and collapse whitespace."""
31
  def strip_puncs(text_in):
32
  return text_in.translate(str.maketrans("", "", PUNCS))
33
  return " ".join(strip_puncs(text).lower().split())
34
 
35
+
36
  def calculate_eou(chat_ctx, session, tokenizer) -> float:
37
  """
38
  Given a conversation context (list of dicts with 'role' and 'content'),
 
66
  eou_token_id = tokenizer.encode("<|im_end|>")[-1]
67
  return probs[eou_token_id]
68
 
69
+
70
  # ------------------------------------------------
71
  # Load ONNX session & tokenizer once
72
  # ------------------------------------------------
73
  print("Loading ONNX model session...")
74
+ onnx_session = ort.InferenceSession(
75
+ ONNX_FILENAME, providers=["CPUExecutionProvider"])
76
 
77
  print("Loading tokenizer...")
78
  turn_detector_tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
 
86
  # ------------------------------------------------
87
  # Gradio Chat Handler
88
  # ------------------------------------------------
89
+
90
+
91
  def respond(message, history, system_message, max_tokens, temperature, top_p):
92
  """
93
  This function is called on each new user message in the ChatInterface.
 
101
  # [{'role': 'system', 'content': ...},
102
  # {'role': 'user', 'content': ...}, ...]
103
 
104
+ messages = [
105
+ {"role": "user",
106
+ "content": message}
107
+ ]
108
  if system_message.strip():
109
+ messages.insert(0, {"role": "system", "content": system_message})
110
 
111
  # history is a list of tuples: [(user1, assistant1), (user2, assistant2), ...]
112
+ """ for user_text, assistant_text in history:
113
  if user_text:
114
  messages.append({"role": "user", "content": user_text})
115
  if assistant_text:
116
  messages.append({"role": "assistant", "content": assistant_text})
117
 
118
  # Append the new user message
119
+ messages.append({"role": "user", "content": message}) """
120
 
121
  # 2) Calculate EOU probability on the entire conversation
122
  eou_prob = calculate_eou(messages, onnx_session, turn_detector_tokenizer)
 
124
  # 3) Generate the assistant response from your HF model.
125
  # (This code streams token-by-token.)
126
  response = ""
127
+
128
  yield f"[EOU Probability: {eou_prob:.4f}]"
129
 
130
+
131
  # ------------------------------------------------
132
  # Gradio ChatInterface
133
  # ------------------------------------------------
 
170
  )
171
 
172
  if __name__ == "__main__":
173
+ demo.launch()