import json
import os
import torch
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
import openai
import re
# Read the OpenAI API key from the environment (e.g. a Spaces secret)
# instead of hardcoding a credential in the source.
openai.api_key = os.getenv('OPENAI_API_KEY')
def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        # Reuse the cached model and corpus embeddings from a previous run.
        model = SentenceTransformer(model_path)
        embeddings = torch.load(embeddings_path)
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        # First run: download the model, embed the corpus, and cache both.
        model = SentenceTransformer(model_name)
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        texts = [item['text'] for item in data]
        embeddings = model.encode(texts, convert_to_tensor=True)
        model.save(model_path)
        torch.save(embeddings, embeddings_path)
    return model, embeddings, data
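# Note: the cache check above only tests that the files exist, so after
# editing labeled_cti_data.json you would need to delete saved_model/ and
# corpus_embeddings.pt to force the embeddings to be rebuilt.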
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
data_file = 'labeled_cti_data.json'
output_dir = '.'
model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)

# Build an exact (brute-force) L2 index over the corpus embeddings.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.cpu().numpy().astype('float32'))
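# IndexFlatL2 performs an exhaustive scan, which is fine at this corpus
# size; for a much larger corpus, an approximate index such as
# faiss.IndexIVFFlat would be the usual swap-in.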
def get_entity_groups(entities):
    return list(set(entity['entity_group'] for entity in entities))
def get_color_for_entity(entity_group):
    colors = {
        'SamFile': '#EE8434',   # Orange (wheel)
        'Way': '#C95D63',       # Indian red
        'Idus': '#AE8799',      # Mountbatten pink
        'Tool': '#9083AE',      # African violet
        'Features': '#8181B9',  # Tropical indigo
        'HackOrg': '#496DDB',   # Royal blue (web color)
        'Purp': '#BCD8C1',      # Celadon
        'OffAct': '#D6DBB2',    # Vanilla
        'Org': '#E3D985',       # Flax
        'SecTeam': '#E57A44',   # Orange (Crayola)
        'Time': '#E3D985',      # Flax (same hex as Org)
        'Exp': '#5D76CF',       # Glaucous
        'Area': '#757FC1',      # Another shade of blue
    }
    return colors.get(entity_group, '#000000')  # Default to black for unknown groups
def semantic_search(query, top_k=5):
    query_embedding = model.encode([query], convert_to_tensor=True)
    distances, indices = index.search(query_embedding.cpu().numpy().astype('float32'), top_k)
    results = []
    for distance, idx in zip(distances[0], indices[0]):
        # IndexFlatL2 returns squared L2 distances; for unit-normalized
        # embeddings (all-MiniLM-L6-v2 normalizes its output), d^2 = 2 - 2*cos,
        # so this recovers the cosine similarity.
        similarity_score = 1 - distance / 2
        if similarity_score >= 0.45:  # keep only results with similarity >= 0.45
            results.append({
                'text': data[idx]['text'],
                'entities': data[idx]['entities'],
                'similarity_score': similarity_score,
                'entity_groups': get_entity_groups(data[idx]['entities'])
            })
    return results
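# Minimal usage sketch (hypothetical query, assuming the corpus above is
# loaded; run in a REPL rather than at import time):
#   for hit in semantic_search("APT29 phishing campaign", top_k=3):
#       print(f"{hit['similarity_score']:.2f}  {hit['text'][:80]}")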
def search_and_format(query):
    results = semantic_search(query)
    if not results:
        return "<div class='search-result'><p>No relevant results found.</p></div>"
    formatted_results = """
    <style>
    .search-result {
        font-size: 24px;
        line-height: 1.6;
    }
    .search-result h2 {
        font-size: 24px;
        color: #333;
    }
    .search-result h3 {
        font-size: 24px;
        color: #444;
    }
    .search-result p {
        margin-bottom: 24px;
    }
    .result-separator {
        border-top: 2px solid #ccc;
        margin: 20px 0;
    }
    </style>
    <div class="search-result">
    """
    for i, result in enumerate(results, 1):
        if i > 1:
            formatted_results += '<div class="result-separator"></div>'
        formatted_results += f"<p><strong>Similarity score:</strong> {result['similarity_score']:.4f}</p>"
        formatted_results += f"<p><strong>Intelligence:</strong> {format_text_with_entities_markdown(result['text'], result['entities'])}</p>"
        formatted_results += f"<p><strong>Named entities:</strong> {', '.join(result['entity_groups'])}</p>"
    formatted_results += "</div>"
    return formatted_results
def format_text_with_entities_markdown(text, entities):
    # Sort the entities by start offset
    entity_spans = sorted(entities, key=lambda x: x['start'])
    # Split the text into words and whitespace runs
    words = re.findall(r'\S+|\s+', text)
    # Collect entity groups per token position (keyed by index rather than
    # by word text, so repeated words do not share annotations)
    word_entities = []
    current_pos = 0
    for word in words:
        word_start = current_pos
        word_end = current_pos + len(word)
        groups = []
        # Record every entity span that overlaps the current word
        for entity in entity_spans:
            if entity['start'] < word_end and entity['end'] > word_start:
                groups.append(entity['entity_group'])
        word_entities.append(groups)
        current_pos = word_end
    # Render each word, appending superscript entity tags where needed
    formatted_text = []
    for word, groups in zip(words, word_entities):
        if groups:
            unique_entity_groups = list(dict.fromkeys(groups))  # drop duplicate groups, keep order
            entity_tags = []
            for i, group in enumerate(unique_entity_groups):
                entity_tag = f'<sup style="color: {get_color_for_entity(group)}; font-size: 14px;">{group}</sup>'
                if i > 0:  # separate multiple tags with a comma
                    entity_tags.append('<sup style="font-size: 14px;">, </sup>')
                entity_tags.append(entity_tag)
            formatted_word = f'<strong>{word}</strong>{"".join(entity_tags)}'
        else:
            formatted_word = word
        formatted_text.append(formatted_word)
    return ''.join(formatted_text)
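# Worked example (hypothetical span): for text "APT29 used Cobalt Strike"
# with entities [{'start': 0, 'end': 5, 'entity_group': 'HackOrg'}], the
# token "APT29" is emitted as
#   <strong>APT29</strong><sup style="color: #496DDB; font-size: 14px;">HackOrg</sup>
# and the remaining tokens pass through unchanged.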
def transcribe_audio(audio):
    try:
        with open(audio, "rb") as audio_file:
            transcript = openai.Audio.transcribe("whisper-1", audio_file)
        return transcript.text
    except Exception as e:
        return f"Error during transcription: {str(e)}"
def audio_to_search(audio):
    # Transcribe the audio, then run the transcription through the search.
    transcription = transcribe_audio(audio)
    search_results = search_and_format(transcription)
    return search_results, transcription
# Example queries
example_queries = [
    "Tell me about recent cyber attacks from Russia",
    "What APT groups are targeting Ukraine?",
    "Explain the Log4j vulnerability",
    "Chinese state-sponsored hacking activities",
    "Toilet?",
    "Latest North Korean hacker",
    "Describe the SolarWinds supply chain attack",
    "What is the Lazarus Group known for?",
    "Common attack vectors used against critical infrastructure",
    "pls rick roll me"
]
# Custom CSS
custom_css = """
.container {display: flex; flex-direction: row;}
.input-column {flex: 1; padding-right: 20px;}
.output-column {flex: 2;}
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
.examples-list > * {flex-basis: calc(50% - 5px);}
footer {display:none !important}
.gradio-container {font-size: 16px;}
"""
# Build the Gradio interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# AskCTI", elem_classes=["text-3xl"])
    gr.Markdown("Enter a question or keywords by text or voice to query related threat intelligence; the top 5 most relevant results are shown.", elem_classes=["text-xl"])
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            query_input = gr.Textbox(lines=3, label="", elem_classes=["text-lg"])
            with gr.Row():
                submit_btn = gr.Button("Search", elem_classes=["text-lg"])
            audio_input = gr.Audio(type="filepath", label="Voice input")
            # audio_input = gr.Audio(sources="microphone", label="Record", elem_classes="small-button")
            gr.Markdown("### Example queries", elem_classes=["text-xl"])
            for i in range(0, len(example_queries), 2):
                with gr.Row():
                    for j in range(2):
                        if i + j < len(example_queries):
                            # Bind each example to its button via a default
                            # argument so the closure captures the right query.
                            gr.Button(example_queries[i + j], elem_classes=["text-lg"]).click(
                                lambda q=example_queries[i + j]: q, inputs=None, outputs=[query_input]
                            )
        with gr.Column(scale=2):
            output = gr.HTML(elem_classes=["text-lg"])
    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])
    audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output, query_input])
# Launch the Gradio interface
iface.launch()
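# Locally this serves on http://127.0.0.1:7860 by default; on Hugging Face
# Spaces the platform handles hosting. iface.launch(share=True) would create
# a temporary public link if needed.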