import json
import os
import torch
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
import openai

# Set the OpenAI API key (read from the environment rather than hardcoding a secret)
openai.api_key = os.environ.get("OPENAI_API_KEY")

def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        print("Loading saved model and embeddings...")
        model = SentenceTransformer(model_path)
        embeddings = torch.load(embeddings_path)
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        print("Creating new model and embeddings...")
        model = SentenceTransformer(model_name)
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        texts = [item['text'] for item in data]
        embeddings = model.encode(texts, convert_to_tensor=True)
        print("Saving model and embeddings...")
        model.save(model_path)
        torch.save(embeddings, embeddings_path)
    return model, embeddings, data

# Configuration
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
data_file = 'labeled_cti_data.json'
output_dir = '.'

# Load or create the model and embeddings
model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)

# Build the FAISS index over the corpus embeddings
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.cpu().numpy().astype('float32'))
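
# Note: faiss.IndexFlatL2 returns squared L2 distances. Assuming the embeddings
# are unit-normalized (as the all-MiniLM-L6-v2 pipeline does), a squared
# distance d maps to cosine similarity as cos = 1 - d / 2, which is the score
# computed in semantic_search() below.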

def get_entity_groups(entities):
    return list(set(entity['entity_group'] for entity in entities))

def semantic_search(query, top_k=3):
    query_vector = model.encode([query], convert_to_tensor=True)
    distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), top_k)
    results = []
    for i, idx in enumerate(indices[0]):
        results.append({
            'text': data[idx]['text'],
            'similarity_score': 1 - distances[0][i] / 2,
            'entity_groups': get_entity_groups(data[idx]['entities'])
        })
    return results
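
# Example (hypothetical query): semantic_search("APT29 phishing campaign")
# returns up to three dicts, each with 'text', 'similarity_score', and
# 'entity_groups' keys, ordered from most to least similar.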

def search_and_format(query):
    results = semantic_search(query)
    formatted_results = ""
    for i, result in enumerate(results, 1):
        formatted_results += f"{i}. Similarity score: {result['similarity_score']:.4f}\n"
        formatted_results += f"   Intel: {result['text']}\n"
        formatted_results += f"   Named entities: {', '.join(result['entity_groups'])}\n\n"
    return formatted_results

def transcribe_audio(audio):
    try:
        # Send the audio file to the Whisper API for transcription
        with open(audio, "rb") as audio_file:
            transcript = openai.Audio.transcribe("whisper-1", audio_file)
        return transcript.text
    except Exception as e:
        return f"Error during transcription: {str(e)}"

def audio_to_search(audio):
    transcription = transcribe_audio(audio)
    search_results = search_and_format(transcription)
    # Return the transcription twice: once for display, once to fill the query box
    return search_results, transcription, transcription

# Example queries
example_queries = [
    "Tell me about recent cyber attacks from Russia",
    "What APT groups are targeting Ukraine?",
    "Explain the Log4j vulnerability",
    "Chinese state-sponsored hacking activities",
    "How does Ransomware-as-a-Service work?",
    "Latest North Korean cryptocurrency thefts",
    "Describe the SolarWinds supply chain attack",
    "What is the Lazarus Group known for?",
    "Common attack vectors used against critical infrastructure",
    "Pegasus spyware capabilities and targets"
]

# Custom CSS
custom_css = """
.container {display: flex; flex-direction: row;}
.input-column {flex: 1; padding-right: 20px;}
.output-column {flex: 2;}
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
.examples-list > * {flex-basis: calc(50% - 5px);}
"""

# Build the Gradio interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# AskCTI")
    gr.Markdown("Type a query or use voice input to search the threat intelligence corpus; the top 3 most relevant results will be shown.")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            query_input = gr.Textbox(lines=3, label="Text query")
            with gr.Row():
                submit_btn = gr.Button("Search")
            audio_input = gr.Audio(type="filepath", label="Voice input")
            gr.Markdown("### Example queries")
            for i in range(0, len(example_queries), 2):
                with gr.Row():
                    for j in range(2):
                        if i + j < len(example_queries):
                            # Bind the example string as a default argument so each
                            # button fills the query box with its own text
                            gr.Button(example_queries[i + j]).click(
                                lambda q=example_queries[i + j]: q,
                                inputs=None,
                                outputs=[query_input]
                            )
        with gr.Column(scale=2):
            output = gr.Textbox(lines=20, label="Search results")
            transcription_output = gr.Textbox(lines=3, label="Voice transcription")
    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])
    audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output, transcription_output, query_input])

# Launch the Gradio interface
iface.launch()