sickcell committed
Commit: 0ded0f0
1 Parent(s): 8670050

Update app.py

Files changed (1)
  1. app.py +120 -24
app.py CHANGED
@@ -6,6 +6,7 @@ import faiss
 import numpy as np
 import gradio as gr
 import openai
+import re
 
 # Set the OpenAI API key
 openai.api_key = 'sk-zK6OrDxP5DvDdAQqnR_nEuUL3UrZf_4W7qvYj1uphjT3BlbkFJdmZAxlxUCFv92NnnMwSB15FhpmiDZSfG2QPueobSQA'
@@ -47,30 +48,123 @@ index.add(embeddings.cpu().numpy().astype('float32'))
 def get_entity_groups(entities):
     return list(set(entity['entity_group'] for entity in entities))
 
-def semantic_search(query, top_k=3):
-    query_vector = model.encode([query], convert_to_tensor=True)
-    distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), top_k)
+def get_color_for_entity(entity_group):
+    colors = {
+        'SamFile': '#EE8434',   # Orange (wheel)
+        'Way': '#C95D63',       # Indian red
+        'Idus': '#AE8799',      # Mountbatten pink
+        'Tool': '#9083AE',      # African Violet
+        'Features': '#8181B9',  # Tropical indigo
+        'HackOrg': '#496DDB',   # Royal Blue (web color)
+        'Purp': '#BCD8C1',      # Celadon
+        'OffAct': '#D6DBB2',    # Vanilla
+        'Org': '#E3D985',       # Flax
+        'SecTeam': '#E57A44',   # Orange (Crayola)
+        'Time': '#E3D985',      # Flax (same value as Org)
+        'Exp': '#5D76CF',       # Glaucous
+        'Area': '#757FC1',      # Another shade of blue
+    }
+    return colors.get(entity_group, '#000000')  # Default to black if the entity group is not found
+
+def semantic_search(query, top_k=5):
+    query_embedding = model.encode([query], convert_to_tensor=True)
+    distances, indices = index.search(query_embedding.cpu().numpy().astype('float32'), top_k)
+
     results = []
-    for i, idx in enumerate(indices[0]):
-        results.append({
-            'text': data[idx]['text'],
-            'similarity_score': 1 - distances[0][i] / 2,
-            'entity_groups': get_entity_groups(data[idx]['entities'])
-        })
+    for distance, idx in zip(distances[0], indices[0]):
+        similarity_score = 1 - distance / 2  # Convert the distance to a similarity score
+        if similarity_score >= 0.45:  # Only keep results with a similarity score of at least 0.45
+            results.append({
+                'text': data[idx]['text'],
+                'entities': data[idx]['entities'],
+                'similarity_score': similarity_score,
+                'entity_groups': get_entity_groups(data[idx]['entities'])
+            })
+
     return results
 
 def search_and_format(query):
     results = semantic_search(query)
-    formatted_results = ""
+
+    if not results:
+        return "<div class='search-result'><p>查無相關資訊。</p></div>"
+
+    formatted_results = """
+    <style>
+    .search-result {
+        font-size: 24px;
+        line-height: 1.6;
+    }
+    .search-result h2 {
+        font-size: 24px;
+        color: #333;
+    }
+    .search-result h3 {
+        font-size: 24px;
+        color: #444;
+    }
+    .search-result p {
+        margin-bottom: 24px;
+    }
+    .result-separator {
+        border-top: 2px solid #ccc;
+        margin: 20px 0;
+    }
+    </style>
+    <div class="search-result">
+    """
     for i, result in enumerate(results, 1):
-        formatted_results += f"{i}. 相似度分數: {result['similarity_score']:.4f}\n"
-        formatted_results += f"   情資: {result['text']}\n"
-        formatted_results += f"   命名實體: {', '.join(result['entity_groups'])}\n\n"
+        if i > 1:
+            formatted_results += '<div class="result-separator"></div>'
+        formatted_results += f"<p><strong>相似度分數:</strong> {result['similarity_score']:.4f}</p>"
+        formatted_results += f"<p><strong>情資:</strong> {format_text_with_entities_markdown(result['text'], result['entities'])}</p>"
+        formatted_results += f"<p><strong>命名實體:</strong> {'、'.join(result['entity_groups'])}</p>"
+    formatted_results += "</div>"
     return formatted_results
 
+def format_text_with_entities_markdown(text, entities):
+    # Sort the entities by start position
+    entity_spans = sorted(entities, key=lambda x: x['start'])
+
+    # Dictionary mapping each word to the entity groups that cover it
+    word_entities = {}
+
+    # Split the text into words and whitespace runs with a regular expression
+    words = re.findall(r'\S+|\s+', text)
+    current_pos = 0
+
+    for word in words:
+        word_start = current_pos
+        word_end = current_pos + len(word)
+        word_entities[word] = []
+
+        # Check whether each entity overlaps the current word
+        for entity in entity_spans:
+            if entity['start'] < word_end and entity['end'] > word_start:
+                word_entities[word].append(entity['entity_group'])
+
+        current_pos = word_end
+
+    # Format each word
+    formatted_text = []
+    for word in words:
+        if word_entities[word]:
+            unique_entity_groups = list(dict.fromkeys(word_entities[word]))  # Drop duplicate entity groups
+            entity_tags = []
+            for i, group in enumerate(unique_entity_groups):
+                entity_tag = f'<sup style="color: {get_color_for_entity(group)}; font-size: 14px;">{group}</sup>'
+                if i > 0:  # Not the first tag, so add a separator first
+                    entity_tags.append('<sup style="font-size: 14px;">、</sup>')
                entity_tags.append(entity_tag)
+            formatted_word = f'<strong>{word}</strong>{"".join(entity_tags)}'
+        else:
+            formatted_word = word
+        formatted_text.append(formatted_word)
+
+    return ''.join(formatted_text)
+
 def transcribe_audio(audio):
     try:
-        # Upload the audio file to the Whisper API
         with open(audio, "rb") as audio_file:
             transcript = openai.Audio.transcribe("whisper-1", audio_file)
         return transcript.text
@@ -80,7 +174,8 @@ def transcribe_audio(audio):
 def audio_to_search(audio):
     transcription = transcribe_audio(audio)
     search_results = search_and_format(transcription)
-    return search_results, transcription, transcription
+    combined_output = f""
+    return combined_output, transcription
 
 # Example queries
 example_queries = [
@@ -103,35 +198,36 @@ custom_css = """
 .output-column {flex: 2;}
 .examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
 .examples-list > * {flex-basis: calc(50% - 5px);}
+footer {display:none !important}
+.gradio-container {font-size: 16px;}
 """
 
 # Create the Gradio interface
 with gr.Blocks(css=custom_css) as iface:
-    gr.Markdown("# AskCTI")
-    gr.Markdown("輸入查詢或使用語音輸入以查詢相關情資威脅情報,將顯示前3個最相關的結果。")
+    gr.Markdown("# AskCTI", elem_classes=["text-3xl"])
+    gr.Markdown("輸入查詢或使用語音輸入問題或關鍵字查詢相關情資威脅情報,將顯示前5個最相關的結果。", elem_classes=["text-xl"])
 
     with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
-            query_input = gr.Textbox(lines=3, label="文字查詢")
+            query_input = gr.Textbox(lines=3, label="", elem_classes=["text-lg"])
            with gr.Row():
-                submit_btn = gr.Button("查詢")
+                submit_btn = gr.Button("查詢", elem_classes=["text-lg"])
                audio_input = gr.Audio(type="filepath", label="語音輸入")
 
-            gr.Markdown("### 範例查詢")
+            gr.Markdown("### 範例查詢", elem_classes=["text-xl"])
            for i in range(0, len(example_queries), 2):
                with gr.Row():
                    for j in range(2):
                        if i + j < len(example_queries):
-                            gr.Button(example_queries[i+j]).click(
+                            gr.Button(example_queries[i+j], elem_classes=["text-lg"]).click(
                                lambda x: x, inputs=[gr.Textbox(value=example_queries[i+j], visible=False)], outputs=[query_input]
                            )
 
        with gr.Column(scale=2):
-            output = gr.Textbox(lines=20, label="查詢結果")
-            transcription_output = gr.Textbox(lines=3, label="語音轉錄結果")
+            output = gr.HTML(elem_classes=["text-lg"])
 
    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])
-    audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output, transcription_output, query_input])
+    audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output, query_input])
 
 # Launch the Gradio interface
 iface.launch()
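
Note on the new scoring in semantic_search: the diff converts each FAISS distance to a similarity via 1 - distance / 2 and keeps hits scoring at least 0.45. That mapping equals cosine similarity only if the index returns squared L2 distances (e.g. faiss.IndexFlatL2) over unit-normalized embeddings, since for unit vectors ||a - b||^2 = 2 - 2*cos(a, b). The index construction is not part of this diff, so the following is a minimal sketch under that assumption, with purely illustrative data:

import faiss
import numpy as np

dim = 4
rng = np.random.default_rng(0)
vecs = rng.normal(size=(10, dim)).astype('float32')
faiss.normalize_L2(vecs)                      # make every row unit length

index = faiss.IndexFlatL2(dim)                # returns squared L2 distances
index.add(vecs)

query = vecs[:1].copy()
distances, indices = index.search(query, 3)

# For unit vectors: ||a - b||^2 = 2 - 2*cos(a, b)  =>  cos = 1 - d/2
cosine = 1 - distances[0] / 2
print(indices[0], cosine)                     # the self-match scores ~1.0

Under that reading, the 0.45 cutoff simply drops hits whose cosine similarity falls below 0.45.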
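
The added format_text_with_entities_markdown assumes each entity dict carries 'start' and 'end' character offsets into the text plus an 'entity_group' key; that schema is inferred from how the function reads the entities, and the call below is a hypothetical illustration rather than data from the repository:

# Hypothetical input mirroring the keys the function reads.
text = "APT41 deployed Cobalt Strike beacons."
entities = [
    {'start': 0, 'end': 5, 'entity_group': 'HackOrg'},   # "APT41"
    {'start': 15, 'end': 28, 'entity_group': 'Tool'},    # "Cobalt Strike"
]

html = format_text_with_entities_markdown(text, entities)
# Tokens overlapping an entity span are wrapped in <strong> with a colored <sup> label.

One quirk worth noting: word_entities is keyed by the token string, so if the same token appears more than once in the text, every occurrence is rendered with the entity tags computed for its last occurrence.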