sickcell committed on
Commit
6767ad2
1 Parent(s): 10f1b9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -13
app.py CHANGED
@@ -6,20 +6,26 @@ import faiss
6
  import numpy as np
7
  import gradio as gr
8
 
9
- def load_or_create_model_and_embeddings(model_name, data_file):
10
  model_path = os.path.join(output_dir, 'saved_model')
11
  embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
12
  if os.path.exists(model_path) and os.path.exists(embeddings_path):
13
  print("載入已保存的模型和嵌入...")
14
  model = SentenceTransformer(model_path)
15
  embeddings = torch.load(embeddings_path)
 
 
16
  else:
 
17
  model = SentenceTransformer(model_name)
18
  with open(data_file, 'r', encoding='utf-8') as f:
19
  data = json.load(f)
20
  texts = [item['text'] for item in data]
21
  embeddings = model.encode(texts, convert_to_tensor=True)
22
- return model, embeddings
 
 
 
23
 
24
  # 設置參數
25
  model_name = 'sentence-transformers/all-MiniLM-L6-v2'
@@ -27,21 +33,25 @@ data_file = 'labeled_cti_data.json'
27
  output_dir = '.'
28
 
29
  # 載入或創建模型和嵌入
30
- model, embeddings= load_or_create_model_and_embeddings(model_name, data_file)
31
 
32
  # 創建 Faiss 索引
33
  dimension = embeddings.shape[1]
34
  index = faiss.IndexFlatL2(dimension)
35
  index.add(embeddings.cpu().numpy().astype('float32'))
36
 
 
 
 
37
  def semantic_search(query, top_k=3):
38
  query_vector = model.encode([query], convert_to_tensor=True)
39
  distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), top_k)
40
  results = []
41
  for i, idx in enumerate(indices[0]):
42
  results.append({
43
- 'text': texts[idx],
44
- 'similarity_score': 1 - distances[0][i] / 2
 
45
  })
46
  return results
47
 
@@ -50,17 +60,57 @@ def search_and_format(query):
50
  formatted_results = ""
51
  for i, result in enumerate(results, 1):
52
  formatted_results += f"{i}. 相似度分數: {result['similarity_score']:.4f}\n"
53
- formatted_results += f" 情一: {result['text']}\n\n"
 
54
  return formatted_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # 創建Gradio界面
57
- iface = gr.Interface(
58
- fn=search_and_format,
59
- inputs=gr.Textbox(lines=2, placeholder="輸入您的搜索查詢..."),
60
- outputs=gr.Textbox(lines=10),
61
- title="語義搜索",
62
- description="輸入查詢以搜索相關文本。將顯示前3個最相關的結果。"
63
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # 啟動Gradio界面
66
  iface.launch()
 
6
  import numpy as np
7
  import gradio as gr
8
 
9
def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
    """Load the cached model and corpus embeddings, or build and cache them.

    Args:
        model_name: Name or path passed to ``SentenceTransformer`` when no
            saved model exists under ``output_dir``.
        data_file: Path to a UTF-8 JSON file containing a list of records;
            every record must provide a ``'text'`` key.
        output_dir: Directory that holds (or will receive) ``saved_model``
            and ``corpus_embeddings.pt``.

    Returns:
        A ``(model, embeddings, data)`` tuple in which row ``i`` of
        ``embeddings`` is the encoding of ``data[i]['text']``. Embeddings
        read from the cache are placed on the CPU.
    """
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')

    # The corpus is needed on every path, so read it exactly once.
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    model = None
    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        print("載入已保存的模型和嵌入...")
        model = SentenceTransformer(model_path)
        # map_location='cpu' lets a cache written on a GPU machine load on a
        # CPU-only host.
        # NOTE(review): torch.load unpickles its input; only point
        # output_dir at files you trust.
        embeddings = torch.load(embeddings_path, map_location='cpu')
        if embeddings.shape[0] == len(data):
            return model, embeddings, data
        # Row count differs from the corpus: the cache is stale and would
        # map search hits to the wrong records, so rebuild it below.

    print("創建新的模型和嵌入...")
    if model is None:
        model = SentenceTransformer(model_name)
    texts = [item['text'] for item in data]
    embeddings = model.encode(texts, convert_to_tensor=True)

    print("保存模型和嵌入...")
    model.save(model_path)
    torch.save(embeddings, embeddings_path)
    return model, embeddings, data
29
 
30
  # 設置參數
31
  model_name = 'sentence-transformers/all-MiniLM-L6-v2'
 
33
  output_dir = '.'
34
 
35
  # 載入或創建模型和嵌入
36
+ model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)
37
 
38
  # 創建 Faiss 索引
39
  dimension = embeddings.shape[1]
40
  index = faiss.IndexFlatL2(dimension)
41
  index.add(embeddings.cpu().numpy().astype('float32'))
42
 
43
def get_entity_groups(entities):
    """Return the distinct ``entity_group`` values of ``entities``.

    Args:
        entities: Iterable of mappings that each provide an
            ``'entity_group'`` key. ``None`` is treated as empty.

    Returns:
        A list of unique group names in first-seen order.

    Raises:
        KeyError: If an entity has no ``'entity_group'`` key.
    """
    # dict.fromkeys de-duplicates while keeping insertion order; iterating a
    # set of strings yields a different order from run to run, which made the
    # displayed groups shuffle between launches.
    return list(
        dict.fromkeys(entity['entity_group'] for entity in (entities or ()))
    )
45
+
46
  def semantic_search(query, top_k=3):
47
  query_vector = model.encode([query], convert_to_tensor=True)
48
  distances, indices = index.search(query_vector.cpu().numpy().astype('float32'), top_k)
49
  results = []
50
  for i, idx in enumerate(indices[0]):
51
  results.append({
52
+ 'text': data[idx]['text'],
53
+ 'similarity_score': 1 - distances[0][i] / 2,
54
+ 'entity_groups': get_entity_groups(data[idx]['entities'])
55
  })
56
  return results
57
 
 
60
  formatted_results = ""
61
  for i, result in enumerate(results, 1):
62
  formatted_results += f"{i}. 相似度分數: {result['similarity_score']:.4f}\n"
63
+ formatted_results += f" 情資: {result['text']}\n"
64
+ formatted_results += f" 實體組: {', '.join(result['entity_groups'])}\n\n"
65
  return formatted_results
66
# Example queries offered as one-click buttons in the UI.
example_queries = [
    "Tell me about recent cyber attacks from Russia",
    "What APT groups are targeting Ukraine?",
    "Explain the Log4j vulnerability",
    "Chinese state-sponsored hacking activities",
    "How does Ransomware-as-a-Service work?",
    "Latest North Korean cryptocurrency thefts",
    "Describe the SolarWinds supply chain attack",
    "What is the Lazarus Group known for?",
    "Common attack vectors used against critical infrastructure",
    "Pegasus spyware capabilities and targets"
]

# Custom CSS handed to gr.Blocks.
# NOTE(review): none of these class names is attached to a component in the
# layout below - confirm whether they are still needed.
custom_css = """
.container {display: flex; flex-direction: row;}
.input-column {flex: 1; padding-right: 20px;}
.output-column {flex: 2;}
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
.examples-list > * {flex-basis: calc(50% - 5px);}
"""

# Build the Gradio interface: query box and example buttons on the left,
# search results on the right.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# AskCTI")
    gr.Markdown("輸入查詢以搜索相關威脅情報,將顯示前3個最相關的結果,包括實體組。")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            query_input = gr.Textbox(lines=3, label="")
            submit_btn = gr.Button("查詢")

            gr.Markdown("### 範例查詢")
            # Lay the example buttons out two per row.
            for row_start in range(0, len(example_queries), 2):
                with gr.Row():
                    for example in example_queries[row_start:row_start + 2]:
                        # Binding the text as a default argument captures this
                        # iteration's value, so the handler needs no input
                        # component (previously a hidden Textbox per button).
                        gr.Button(example).click(
                            lambda text=example: text,
                            inputs=None,
                            outputs=[query_input],
                        )

        with gr.Column(scale=2):
            output = gr.Textbox(lines=20, label="")

    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])

# Launch the Gradio interface.
iface.launch()