Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import faiss
|
|
6 |
import numpy as np
|
7 |
import gradio as gr
|
8 |
import openai
|
|
|
9 |
|
10 |
# 設置OpenAI API密鑰
|
11 |
openai.api_key = 'sk-zK6OrDxP5DvDdAQqnR_nEuUL3UrZf_4W7qvYj1uphjT3BlbkFJdmZAxlxUCFv92NnnMwSB15FhpmiDZSfG2QPueobSQA'
|
@@ -47,30 +48,123 @@ index.add(embeddings.cpu().numpy().astype('float32'))
|
|
47 |
def get_entity_groups(entities):
|
48 |
return list(set(entity['entity_group'] for entity in entities))
|
49 |
|
50 |
-
def
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
results = []
|
54 |
-
for
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
60 |
return results
|
61 |
|
62 |
def search_and_format(query):
|
63 |
results = semantic_search(query)
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
for i, result in enumerate(results, 1):
|
66 |
-
|
67 |
-
|
68 |
-
formatted_results += f"
|
|
|
|
|
|
|
69 |
return formatted_results
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
def transcribe_audio(audio):
|
72 |
try:
|
73 |
-
# 將音頻文件上傳到Whisper API
|
74 |
with open(audio, "rb") as audio_file:
|
75 |
transcript = openai.Audio.transcribe("whisper-1", audio_file)
|
76 |
return transcript.text
|
@@ -80,7 +174,8 @@ def transcribe_audio(audio):
|
|
80 |
def audio_to_search(audio):
|
81 |
transcription = transcribe_audio(audio)
|
82 |
search_results = search_and_format(transcription)
|
83 |
-
|
|
|
84 |
|
85 |
# 示例問題
|
86 |
example_queries = [
|
@@ -103,35 +198,36 @@ custom_css = """
|
|
103 |
.output-column {flex: 2;}
|
104 |
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
|
105 |
.examples-list > * {flex-basis: calc(50% - 5px);}
|
|
|
|
|
106 |
"""
|
107 |
|
108 |
# 創建Gradio界面
|
109 |
with gr.Blocks(css=custom_css) as iface:
|
110 |
-
gr.Markdown("# AskCTI")
|
111 |
-
gr.Markdown("
|
112 |
|
113 |
with gr.Row(equal_height=True):
|
114 |
with gr.Column(scale=1, min_width=300):
|
115 |
-
query_input = gr.Textbox(lines=3, label="
|
116 |
with gr.Row():
|
117 |
-
submit_btn = gr.Button("查詢")
|
118 |
audio_input = gr.Audio(type="filepath", label="語音輸入")
|
119 |
|
120 |
-
gr.Markdown("### 範例查詢")
|
121 |
for i in range(0, len(example_queries), 2):
|
122 |
with gr.Row():
|
123 |
for j in range(2):
|
124 |
if i + j < len(example_queries):
|
125 |
-
gr.Button(example_queries[i+j]).click(
|
126 |
lambda x: x, inputs=[gr.Textbox(value=example_queries[i+j], visible=False)], outputs=[query_input]
|
127 |
)
|
128 |
|
129 |
with gr.Column(scale=2):
|
130 |
-
output = gr.
|
131 |
-
transcription_output = gr.Textbox(lines=3, label="語音轉錄結果")
|
132 |
|
133 |
submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])
|
134 |
-
audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output,
|
135 |
|
136 |
# 啟動Gradio界面
|
137 |
iface.launch()
|
|
|
6 |
import numpy as np
|
7 |
import gradio as gr
|
8 |
import openai
|
9 |
+
import re
|
10 |
|
11 |
# 設置OpenAI API密鑰
|
12 |
openai.api_key = 'sk-zK6OrDxP5DvDdAQqnR_nEuUL3UrZf_4W7qvYj1uphjT3BlbkFJdmZAxlxUCFv92NnnMwSB15FhpmiDZSfG2QPueobSQA'
|
|
|
48 |
def get_entity_groups(entities):
|
49 |
return list(set(entity['entity_group'] for entity in entities))
|
50 |
|
51 |
+
def get_color_for_entity(entity_group):
|
52 |
+
colors = {
|
53 |
+
'SamFile': '#EE8434', # Orange (wheel)
|
54 |
+
'Way': '#C95D63', # Indian red
|
55 |
+
'Idus': '#AE8799', # Mountbatten pink
|
56 |
+
'Tool': '#9083AE', # African Violet
|
57 |
+
'Features': '#8181B9', # Tropical indigo
|
58 |
+
'HackOrg': '#496DDB', # Royal Blue (web color)
|
59 |
+
'Purp': '#BCD8C1', # Celadon
|
60 |
+
'OffAct': '#D6DBB2', # Vanilla
|
61 |
+
'Org': '#E3D985', # Flax
|
62 |
+
'SecTeam': '#E57A44', # Orange (Crayola)
|
63 |
+
'Time': '#E3D985', # Dark purple
|
64 |
+
'Exp': '#5D76CF', # Glaucous
|
65 |
+
'Area': '#757FC1', # Another shade of blue
|
66 |
+
}
|
67 |
+
return colors.get(entity_group, '#000000') # Default to black if entity group not found
|
68 |
+
|
69 |
+
def semantic_search(query, top_k=5):
|
70 |
+
query_embedding = model.encode([query], convert_to_tensor=True)
|
71 |
+
distances, indices = index.search(query_embedding.cpu().numpy().astype('float32'), top_k)
|
72 |
+
|
73 |
results = []
|
74 |
+
for distance, idx in zip(distances[0], indices[0]):
|
75 |
+
similarity_score = 1 - distance / 2 # 將距離轉換為相似度分數
|
76 |
+
if similarity_score >= 0.45: # 只添加相似度大於等於0.3的結果
|
77 |
+
results.append({
|
78 |
+
'text': data[idx]['text'],
|
79 |
+
'entities': data[idx]['entities'],
|
80 |
+
'similarity_score': similarity_score,
|
81 |
+
'entity_groups': get_entity_groups(data[idx]['entities'])
|
82 |
+
})
|
83 |
+
|
84 |
return results
|
85 |
|
86 |
def search_and_format(query):
|
87 |
results = semantic_search(query)
|
88 |
+
|
89 |
+
if not results:
|
90 |
+
return "<div class='search-result'><p>查無相關資訊。</p></div>"
|
91 |
+
|
92 |
+
formatted_results = """
|
93 |
+
<style>
|
94 |
+
.search-result {
|
95 |
+
font-size: 24px;
|
96 |
+
line-height: 1.6;
|
97 |
+
}
|
98 |
+
.search-result h2 {
|
99 |
+
font-size: 24px;
|
100 |
+
color: #333;
|
101 |
+
}
|
102 |
+
.search-result h3 {
|
103 |
+
font-size: 24px;
|
104 |
+
color: #444;
|
105 |
+
}
|
106 |
+
.search-result p {
|
107 |
+
margin-bottom: 24px;
|
108 |
+
}
|
109 |
+
.result-separator {
|
110 |
+
border-top: 2px solid #ccc;
|
111 |
+
margin: 20px 0;
|
112 |
+
}
|
113 |
+
</style>
|
114 |
+
<div class="search-result">
|
115 |
+
"""
|
116 |
for i, result in enumerate(results, 1):
|
117 |
+
if i > 1:
|
118 |
+
formatted_results += '<div class="result-separator"></div>'
|
119 |
+
formatted_results += f"<p><strong>相似度分數:</strong> {result['similarity_score']:.4f}</p>"
|
120 |
+
formatted_results += f"<p><strong>情資:</strong> {format_text_with_entities_markdown(result['text'], result['entities'])}</p>"
|
121 |
+
formatted_results += f"<p><strong>命名實體:</strong> {'、'.join(result['entity_groups'])}</p>"
|
122 |
+
formatted_results += "</div>"
|
123 |
return formatted_results
|
124 |
|
125 |
+
def format_text_with_entities_markdown(text, entities):
|
126 |
+
# 將實體按照起始位置排序
|
127 |
+
entity_spans = sorted(entities, key=lambda x: x['start'])
|
128 |
+
|
129 |
+
# 創建一個字典來存儲每個單詞的實體
|
130 |
+
word_entities = {}
|
131 |
+
|
132 |
+
# 使用正則表達式分割文本為單詞
|
133 |
+
words = re.findall(r'\S+|\s+', text)
|
134 |
+
current_pos = 0
|
135 |
+
|
136 |
+
for word in words:
|
137 |
+
word_start = current_pos
|
138 |
+
word_end = current_pos + len(word)
|
139 |
+
word_entities[word] = []
|
140 |
+
|
141 |
+
# 檢查每個實體是否與當前單詞重疊
|
142 |
+
for entity in entity_spans:
|
143 |
+
if entity['start'] < word_end and entity['end'] > word_start:
|
144 |
+
word_entities[word].append(entity['entity_group'])
|
145 |
+
|
146 |
+
current_pos = word_end
|
147 |
+
|
148 |
+
# 處理每個單詞
|
149 |
+
formatted_text = []
|
150 |
+
for word in words:
|
151 |
+
if word_entities[word]:
|
152 |
+
unique_entity_groups = list(dict.fromkeys(word_entities[word])) # 去除重複的實體
|
153 |
+
entity_tags = []
|
154 |
+
for i, group in enumerate(unique_entity_groups):
|
155 |
+
entity_tag = f'<sup style="color: {get_color_for_entity(group)}; font-size: 14px;">{group}</sup>'
|
156 |
+
if i > 0: # 如果不是第一個標籤,添加逗號分隔符
|
157 |
+
entity_tags.append('<sup style="font-size: 14px;">、</sup>')
|
158 |
+
entity_tags.append(entity_tag)
|
159 |
+
formatted_word = f'<strong>{word}</strong>{"".join(entity_tags)}'
|
160 |
+
else:
|
161 |
+
formatted_word = word
|
162 |
+
formatted_text.append(formatted_word)
|
163 |
+
|
164 |
+
return ''.join(formatted_text)
|
165 |
+
|
166 |
def transcribe_audio(audio):
|
167 |
try:
|
|
|
168 |
with open(audio, "rb") as audio_file:
|
169 |
transcript = openai.Audio.transcribe("whisper-1", audio_file)
|
170 |
return transcript.text
|
|
|
174 |
def audio_to_search(audio):
|
175 |
transcription = transcribe_audio(audio)
|
176 |
search_results = search_and_format(transcription)
|
177 |
+
combined_output = f""
|
178 |
+
return combined_output, transcription
|
179 |
|
180 |
# 示例問題
|
181 |
example_queries = [
|
|
|
198 |
.output-column {flex: 2;}
|
199 |
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
|
200 |
.examples-list > * {flex-basis: calc(50% - 5px);}
|
201 |
+
footer {display:none !important}
|
202 |
+
.gradio-container {font-size: 16px;}
|
203 |
"""
|
204 |
|
205 |
# 創建Gradio界面
|
206 |
with gr.Blocks(css=custom_css) as iface:
|
207 |
+
gr.Markdown("# AskCTI", elem_classes=["text-3xl"])
|
208 |
+
gr.Markdown("輸入查詢或使用語音輸入問題或關鍵字查詢相關情資威脅情報,將顯示前5個最相關的結果。", elem_classes=["text-xl"])
|
209 |
|
210 |
with gr.Row(equal_height=True):
|
211 |
with gr.Column(scale=1, min_width=300):
|
212 |
+
query_input = gr.Textbox(lines=3, label="", elem_classes=["text-lg"])
|
213 |
with gr.Row():
|
214 |
+
submit_btn = gr.Button("查詢", elem_classes=["text-lg"])
|
215 |
audio_input = gr.Audio(type="filepath", label="語音輸入")
|
216 |
|
217 |
+
gr.Markdown("### 範例查詢", elem_classes=["text-xl"])
|
218 |
for i in range(0, len(example_queries), 2):
|
219 |
with gr.Row():
|
220 |
for j in range(2):
|
221 |
if i + j < len(example_queries):
|
222 |
+
gr.Button(example_queries[i+j], elem_classes=["text-lg"]).click(
|
223 |
lambda x: x, inputs=[gr.Textbox(value=example_queries[i+j], visible=False)], outputs=[query_input]
|
224 |
)
|
225 |
|
226 |
with gr.Column(scale=2):
|
227 |
+
output = gr.HTML(elem_classes=["text-lg"])
|
|
|
228 |
|
229 |
submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])
|
230 |
+
audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output, query_input])
|
231 |
|
232 |
# 啟動Gradio界面
|
233 |
iface.launch()
|