Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,20 +6,26 @@ import faiss
|
|
6 |
import numpy as np
|
7 |
import gradio as gr
|
8 |
|
9 |
-
def load_or_create_model_and_embeddings(model_name, data_file):
|
10 |
model_path = os.path.join(output_dir, 'saved_model')
|
11 |
embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
|
12 |
if os.path.exists(model_path) and os.path.exists(embeddings_path):
|
13 |
print("載入已保存的模型和嵌入...")
|
14 |
model = SentenceTransformer(model_path)
|
15 |
embeddings = torch.load(embeddings_path)
|
|
|
|
|
16 |
else:
|
|
|
17 |
model = SentenceTransformer(model_name)
|
18 |
with open(data_file, 'r', encoding='utf-8') as f:
|
19 |
data = json.load(f)
|
20 |
texts = [item['text'] for item in data]
|
21 |
embeddings = model.encode(texts, convert_to_tensor=True)
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
# 設置參數
|
25 |
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
|
@@ -27,21 +33,25 @@ data_file = 'labeled_cti_data.json'
|
|
27 |
output_dir = '.'
|
28 |
|
29 |
# 載入或創建模型和嵌入
|
30 |
-
model, embeddings= load_or_create_model_and_embeddings(model_name, data_file)
|
31 |
|
32 |
# 創建 Faiss 索引
|
33 |
dimension = embeddings.shape[1]
|
34 |
index = faiss.IndexFlatL2(dimension)
|
35 |
index.add(embeddings.cpu().numpy().astype('float32'))
|
36 |
|
|
|
|
|
|
|
37 |
def semantic_search(query, top_k=3):
|
38 |
query_vector = model.encode([query], convert_to_tensor=True)
|
39 |
distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), top_k)
|
40 |
results = []
|
41 |
for i, idx in enumerate(indices[0]):
|
42 |
results.append({
|
43 |
-
'text':
|
44 |
-
'similarity_score': 1 - distances[0][i] / 2
|
|
|
45 |
})
|
46 |
return results
|
47 |
|
@@ -50,17 +60,57 @@ def search_and_format(query):
|
|
50 |
formatted_results = ""
|
51 |
for i, result in enumerate(results, 1):
|
52 |
formatted_results += f"{i}. 相似度分數: {result['similarity_score']:.4f}\n"
|
53 |
-
formatted_results += f"
|
|
|
54 |
return formatted_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
# 創建Gradio界面
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
# 啟動Gradio界面
|
66 |
iface.launch()
|
|
|
6 |
import numpy as np
|
7 |
import gradio as gr
|
8 |
|
9 |
def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
    """Load a cached SentenceTransformer and corpus embeddings, or build them.

    Looks for ``saved_model/`` and ``corpus_embeddings.pt`` under *output_dir*;
    if both exist they are loaded, otherwise a fresh model is downloaded, the
    corpus is embedded, and both artifacts are persisted for the next start.

    Args:
        model_name: Hugging Face model id used when no cache exists.
        data_file: Path to a JSON file holding a list of records; each record
            has a 'text' field (and, per callers, an 'entities' field).
        output_dir: Directory where the model and embeddings are cached.

    Returns:
        Tuple ``(model, embeddings, data)`` — the SentenceTransformer, a
        tensor of corpus embeddings, and the parsed JSON corpus.
    """
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')

    # The corpus records are needed on both paths (search results return the
    # matching record), so load them once instead of in each branch.
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        print("載入已保存的模型和嵌入...")
        model = SentenceTransformer(model_path)
        # map_location='cpu' so embeddings saved on a GPU machine still load
        # on a CPU-only host (a likely cause of the Space's runtime error).
        embeddings = torch.load(embeddings_path, map_location='cpu')
    else:
        print("創建新的模型和嵌入...")
        model = SentenceTransformer(model_name)
        texts = [item['text'] for item in data]
        embeddings = model.encode(texts, convert_to_tensor=True)
        print("保存模型和嵌入...")
        model.save(model_path)
        torch.save(embeddings, embeddings_path)
    return model, embeddings, data
# Configuration
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
data_file = 'labeled_cti_data.json'
output_dir = '.'

# Load (or build and cache) the model, corpus embeddings, and corpus records.
model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)

# Build an exact-L2 Faiss index over the corpus; Faiss wants float32 on CPU.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.cpu().numpy().astype('float32'))
def get_entity_groups(entities):
    """Return the unique 'entity_group' labels of *entities*.

    Args:
        entities: Iterable of dicts, each carrying an 'entity_group' key.

    Returns:
        List of unique entity-group strings in first-seen order.
    """
    # dict.fromkeys deduplicates while preserving insertion order, so the
    # result is deterministic; the original list(set(...)) ordering varied
    # run-to-run under Python's hash randomization.
    return list(dict.fromkeys(entity['entity_group'] for entity in entities))
def semantic_search(query, top_k=3):
    """Return the *top_k* corpus entries most similar to *query*.

    Embeds the query with the module-level ``model`` and searches the Faiss
    L2 ``index`` built over ``data`` at import time.

    Args:
        query: Free-text query string.
        top_k: Maximum number of hits to return (default 3).

    Returns:
        List of dicts with keys 'text', 'similarity_score' (``1 - L2/2`` —
        higher is more similar for normalized embeddings), and
        'entity_groups'.
    """
    query_vector = model.encode([query], convert_to_tensor=True)
    distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), top_k)
    results = []
    for i, idx in enumerate(indices[0]):
        # Faiss pads ids with -1 when the index holds fewer than top_k
        # vectors; skip those instead of silently wrapping to data[-1].
        if idx < 0:
            continue
        results.append({
            'text': data[idx]['text'],
            'similarity_score': 1 - distances[0][i] / 2,
            'entity_groups': get_entity_groups(data[idx]['entities'])
        })
    return results
def search_and_format(query):
    """Run semantic_search on *query* and render the hits as numbered text."""
    # NOTE(review): the def line and the semantic_search call are recovered
    # from the diff hunk header; inner f-string spacing may have been
    # collapsed by extraction — confirm against the deployed app.py.
    chunks = []
    for rank, hit in enumerate(semantic_search(query), 1):
        chunks.append(f"{rank}. 相似度分數: {hit['similarity_score']:.4f}\n")
        chunks.append(f" 情資: {hit['text']}\n")
        chunks.append(f" 實體組: {', '.join(hit['entity_groups'])}\n\n")
    return "".join(chunks)
# Example queries shown as one-click buttons in the UI.
example_queries = [
    "Tell me about recent cyber attacks from Russia",
    "What APT groups are targeting Ukraine?",
    "Explain the Log4j vulnerability",
    "Chinese state-sponsored hacking activities",
    "How does Ransomware-as-a-Service work?",
    "Latest North Korean cryptocurrency thefts",
    "Describe the SolarWinds supply chain attack",
    "What is the Lazarus Group known for?",
    "Common attack vectors used against critical infrastructure",
    "Pegasus spyware capabilities and targets",
]
# Custom CSS: two-column layout (inputs left, results right) and a
# two-per-row grid for the example-query buttons.
custom_css = """
.container {display: flex; flex-direction: row;}
.input-column {flex: 1; padding-right: 20px;}
.output-column {flex: 2;}
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
.examples-list > * {flex-basis: calc(50% - 5px);}
"""
# Build the Gradio UI: query column on the left, results on the right.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# AskCTI")
    gr.Markdown("輸入查詢以搜索相關威脅情報,將顯示前3個最相關的結果,包括實體組。")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            query_input = gr.Textbox(lines=3, label="")
            submit_btn = gr.Button("查詢")

            gr.Markdown("### 範例查詢")
            # Example buttons, two per row; clicking one copies its text
            # into the query box.
            for i in range(0, len(example_queries), 2):
                with gr.Row():
                    for j in range(2):
                        if i + j < len(example_queries):
                            # Bind the query text via a lambda default
                            # argument with no inputs — replaces the original
                            # hidden-Textbox workaround (a throwaway invisible
                            # component per button) and avoids the
                            # late-binding-closure pitfall inside the loop.
                            gr.Button(example_queries[i + j]).click(
                                lambda q=example_queries[i + j]: q,
                                inputs=None,
                                outputs=[query_input],
                            )

        with gr.Column(scale=2):
            output = gr.Textbox(lines=20, label="")

    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])

# Launch the Gradio app
iface.launch()