# NOTE(review): the lines below appear to be residue from the Hugging Face
# Spaces file viewer (page header and blame gutter), not program code; they
# are kept as comments so the module parses.
# Spaces:
# Sleeping
# Sleeping
# File size: 4,483 Bytes
# cf22cb1 f99b8f7 cf22cb1 f99b8f7 6767ad2 cf22cb1 6767ad2 cf22cb1 6767ad2 cf22cb1 6767ad2 f99b8f7 cf22cb1 f99b8f7 cf22cb1 6767ad2 f99b8f7 cf22cb1 6767ad2 cf22cb1 f99b8f7 cf22cb1 6767ad2 04c1d54 cf22cb1 04c1d54 cf22cb1 6767ad2 04c1d54 6767ad2 fc7d349 cf22cb1 6767ad2 04c1d54 721cbd2 04c1d54 6767ad2 04c1d54 f99b8f7 cf22cb1 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import json
import os
import torch
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
    """Load the sentence-embedding model, the corpus embeddings and the records.

    The model and the embedding tensor are cached in *output_dir* as
    ``saved_model`` and ``corpus_embeddings.pt``. When both exist and the
    cached tensor has one row per record in *data_file*, they are reused.
    Otherwise the texts are (re-)encoded and both artifacts are written back.
    Only the record count is checked, so a data file edited in place with
    the same number of records still reuses the old embeddings.

    Args:
        model_name: SentenceTransformer model id or path, used only when no
            cached model exists in *output_dir*.
        data_file: UTF-8 JSON file containing a list of records, each with a
            ``'text'`` field.
        output_dir: Directory for the cached model and embeddings; created
            if it does not exist.

    Returns:
        ``(model, embeddings, data)`` where row ``i`` of ``embeddings`` is the
        encoding of ``data[i]['text']``.
    """
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')

    # The records are needed on every path, so read them once up front.
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        print("載入已保存的模型和嵌入...")
        model = SentenceTransformer(model_path)
        # map_location lets a cache written on a GPU machine load on a
        # CPU-only host; callers move the tensor to CPU anyway.
        embeddings = torch.load(embeddings_path, map_location='cpu')
        if embeddings.shape[0] == len(data):
            return model, embeddings, data
        # The data file changed since the cache was written: index row i would
        # no longer describe data[i] (and could even point past its end).
        print("已保存的嵌入與資料筆數不符,重新計算嵌入...")
    else:
        print("創建新的模型和嵌入...")
        model = SentenceTransformer(model_name)

    texts = [item['text'] for item in data]
    embeddings = model.encode(texts, convert_to_tensor=True)
    print("保存模型和嵌入...")
    os.makedirs(output_dir, exist_ok=True)
    model.save(model_path)
    torch.save(embeddings, embeddings_path)
    return model, embeddings, data
# --- Configuration -----------------------------------------------------------
# Model to start from when output_dir holds no cached model/embeddings yet.
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
# JSON list of records; each record's 'text' is embedded and its 'entities'
# are shown alongside search hits.
data_file = 'labeled_cti_data.json'
# Where the cached model and corpus embeddings live.
output_dir = '.'

# --- Model, corpus embeddings and the matching records -----------------------
model, embeddings, data = load_or_create_model_and_embeddings(
    model_name=model_name,
    data_file=data_file,
    output_dir=output_dir,
)

# --- Exact (brute-force) L2 Faiss index over the corpus embeddings -----------
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
# Faiss takes float32 host arrays, so bring the tensor to the CPU first.
index.add(embeddings.cpu().numpy().astype('float32'))
def get_entity_groups(entities):
    """Return the distinct ``'entity_group'`` labels found in *entities*, sorted.

    Sorting makes the displayed order stable: the order of ``list(set(...))``
    over strings depends on hash randomization and changes between runs.

    Args:
        entities: Iterable of entity dicts, each with an ``'entity_group'``
            key. ``None`` or an empty iterable yields ``[]``.

    Returns:
        Sorted list of unique group labels.
    """
    return sorted({entity['entity_group'] for entity in entities or ()})
def semantic_search(query, top_k=3):
    """Return up to *top_k* corpus records closest to *query*.

    Relies on the module-level ``model``, ``index`` and ``data``.

    Args:
        query: Free-text search string.
        top_k: Maximum number of hits; clamped to the number of indexed
            vectors.

    Returns:
        List of dicts with keys ``'text'``, ``'similarity_score'`` and
        ``'entity_groups'``, in the order Faiss returns them (nearest first).
    """
    # Faiss pads results with label -1 when fewer than k vectors exist, and
    # data[-1] would then silently return the last record. Clamp k to avoid it.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    query_vector = model.encode([query], convert_to_tensor=True)
    distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), k)

    results = []
    for distance, idx in zip(distances[0], indices[0]):
        if idx < 0:  # defensive: never map a padding label onto a record
            continue
        record = data[idx]
        results.append({
            'text': record['text'],
            # IndexFlatL2 reports squared L2 distances. For unit-length vectors
            # d^2 = 2 - 2*cos, so 1 - d^2/2 is the cosine similarity. This
            # assumes the model emits normalized embeddings (presumably true
            # for the default all-MiniLM-L6-v2 pipeline; confirm if
            # model_name changes).
            'similarity_score': 1 - distance / 2,
            'entity_groups': get_entity_groups(record.get('entities', [])),
        })
    return results
def search_and_format(query):
    """Run ``semantic_search`` on *query* and render the hits as plain text.

    Each hit becomes a numbered entry with its similarity score, the
    intelligence text and its entity groups. A blank query is answered with
    a prompt instead of a search, and an empty result set gets a notice.

    Args:
        query: Text from the query box (may be empty or ``None``).

    Returns:
        The formatted text for the output box.
    """
    if not query or not query.strip():
        return "請輸入查詢內容。"

    results = semantic_search(query)
    if not results:
        return "找不到相關的威脅情報。"

    entries = [
        f"{rank}. 相似度分數: {hit['similarity_score']:.4f}\n"
        f" 情資: {hit['text']}\n"
        f" 命名實體: {', '.join(hit['entity_groups'])}\n\n"
        for rank, hit in enumerate(results, 1)
    ]
    return "".join(entries)
# Sample questions offered as one-click example buttons in the UI,
# displayed two per row in this order.
example_queries = [
    "Tell me about recent cyber attacks from Russia",
    "What APT groups are targeting Ukraine?",
    "Explain the Log4j vulnerability",
    "Chinese state-sponsored hacking activities",
    "How does Ransomware-as-a-Service work?",
    "Latest North Korean cryptocurrency thefts",
    "Describe the SolarWinds supply chain attack",
    "What is the Lazarus Group known for?",
    "Common attack vectors used against critical infrastructure",
    "Pegasus spyware capabilities and targets",
]
# Custom CSS handed to gr.Blocks(css=...), built one rule per line.
# NOTE(review): these rules only match elements carrying these class names;
# no component built in the visible code sets elem_classes, so confirm
# whether they still have any effect.
custom_css = "\n".join([
    "",
    ".container {display: flex; flex-direction: row;}",
    ".input-column {flex: 1; padding-right: 20px;}",
    ".output-column {flex: 2;}",
    ".examples-list {display: flex; flex-wrap: wrap; gap: 10px;}",
    ".examples-list > * {flex-basis: calc(50% - 5px);}",
    "",
])
# Build the Gradio interface: query box, submit button and example buttons
# on the left; the formatted search results on the right.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# AskCTI")
    gr.Markdown("輸入查詢以搜索相關威脅情報,將顯示前3個最相關的結果,包括實體組。")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            query_input = gr.Textbox(lines=3, label="")
            submit_btn = gr.Button("查詢")
            gr.Markdown("### 範例查詢")
            # Two example buttons per row; clicking one copies its text into
            # the query box. The text is bound as a lambda default argument
            # (evaluated now, not at click time), so no hidden Textbox per
            # button is needed just to carry the value.
            for start in range(0, len(example_queries), 2):
                with gr.Row():
                    for example in example_queries[start:start + 2]:
                        gr.Button(example).click(
                            lambda example=example: example,
                            outputs=[query_input],
                        )
        with gr.Column(scale=2):
            output = gr.Textbox(lines=20, label="")
    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])

# Launch the Gradio app.
iface.launch()