sickcell committed on
Commit
c7bf51f
1 Parent(s): 390a2cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -8,39 +8,35 @@ import gradio as gr
8
  import openai
9
  import re
10
 
11
- # 設置OpenAI API密鑰
12
  openai.api_key = 'sk-zK6OrDxP5DvDdAQqnR_nEuUL3UrZf_4W7qvYj1uphjT3BlbkFJdmZAxlxUCFv92NnnMwSB15FhpmiDZSfG2QPueobSQA'
13
 
14
  def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
15
  model_path = os.path.join(output_dir, 'saved_model')
16
  embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
17
  if os.path.exists(model_path) and os.path.exists(embeddings_path):
18
- print("載入已保存的模型和嵌入...")
19
  model = SentenceTransformer(model_path)
20
  embeddings = torch.load(embeddings_path)
21
  with open(data_file, 'r', encoding='utf-8') as f:
22
  data = json.load(f)
23
  else:
24
- print("創建新的模型和嵌入...")
25
  model = SentenceTransformer(model_name)
26
  with open(data_file, 'r', encoding='utf-8') as f:
27
  data = json.load(f)
28
  texts = [item['text'] for item in data]
29
  embeddings = model.encode(texts, convert_to_tensor=True)
30
- print("保存模型和嵌入...")
31
  model.save(model_path)
32
  torch.save(embeddings, embeddings_path)
33
  return model, embeddings, data
34
 
35
- # 設置參數
36
  model_name = 'sentence-transformers/all-MiniLM-L6-v2'
37
  data_file = 'labeled_cti_data.json'
38
  output_dir = '.'
39
 
40
- # 載入或創建模型和嵌入
41
  model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)
42
 
43
- # 創建 Faiss 索引
44
  dimension = embeddings.shape[1]
45
  index = faiss.IndexFlatL2(dimension)
46
  index.add(embeddings.cpu().numpy().astype('float32'))
@@ -73,7 +69,7 @@ def semantic_search(query, top_k=5):
73
  results = []
74
  for distance, idx in zip(distances[0], indices[0]):
75
  similarity_score = 1 - distance / 2 # 將距離轉換為相似度分數
76
- if similarity_score >= 0.45: # 只添加相似度大於等於0.3的結果
77
  results.append({
78
  'text': data[idx]['text'],
79
  'entities': data[idx]['entities'],
@@ -177,7 +173,7 @@ def audio_to_search(audio):
177
  combined_output = f""
178
  return combined_output, transcription
179
 
180
- # 示例問題
181
  example_queries = [
182
  "Tell me about recent cyber attacks from Russia",
183
  "What APT groups are targeting Ukraine?",
@@ -205,7 +201,7 @@ footer {display:none !important}
205
  # 創建Gradio界面
206
  with gr.Blocks(css=custom_css) as iface:
207
  gr.Markdown("# AskCTI", elem_classes=["text-3xl"])
208
- gr.Markdown("使用文字或使用語音輸入問題或關鍵字查詢相關情資威脅情報,結果將顯示前5個最相關的結果。", elem_classes=["text-xl"])
209
 
210
  with gr.Row(equal_height=True):
211
  with gr.Column(scale=1, min_width=300):
 
8
  import openai
9
  import re
10
 
 
11
  openai.api_key = 'sk-zK6OrDxP5DvDdAQqnR_nEuUL3UrZf_4W7qvYj1uphjT3BlbkFJdmZAxlxUCFv92NnnMwSB15FhpmiDZSfG2QPueobSQA'
12
 
13
  def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
14
  model_path = os.path.join(output_dir, 'saved_model')
15
  embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
16
  if os.path.exists(model_path) and os.path.exists(embeddings_path):
 
17
  model = SentenceTransformer(model_path)
18
  embeddings = torch.load(embeddings_path)
19
  with open(data_file, 'r', encoding='utf-8') as f:
20
  data = json.load(f)
21
  else:
 
22
  model = SentenceTransformer(model_name)
23
  with open(data_file, 'r', encoding='utf-8') as f:
24
  data = json.load(f)
25
  texts = [item['text'] for item in data]
26
  embeddings = model.encode(texts, convert_to_tensor=True)
 
27
  model.save(model_path)
28
  torch.save(embeddings, embeddings_path)
29
  return model, embeddings, data
30
 
31
+
32
  model_name = 'sentence-transformers/all-MiniLM-L6-v2'
33
  data_file = 'labeled_cti_data.json'
34
  output_dir = '.'
35
 
36
+
37
  model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)
38
 
39
+
40
  dimension = embeddings.shape[1]
41
  index = faiss.IndexFlatL2(dimension)
42
  index.add(embeddings.cpu().numpy().astype('float32'))
 
69
  results = []
70
  for distance, idx in zip(distances[0], indices[0]):
71
  similarity_score = 1 - distance / 2 # 將距離轉換為相似度分數
72
+ if similarity_score >= 0.45: # 只添加相似度大於等於0.45的結果
73
  results.append({
74
  'text': data[idx]['text'],
75
  'entities': data[idx]['entities'],
 
173
  combined_output = f""
174
  return combined_output, transcription
175
 
176
+ # 範例問題
177
  example_queries = [
178
  "Tell me about recent cyber attacks from Russia",
179
  "What APT groups are targeting Ukraine?",
 
201
  # 創建Gradio界面
202
  with gr.Blocks(css=custom_css) as iface:
203
  gr.Markdown("# AskCTI", elem_classes=["text-3xl"])
204
+ gr.Markdown("使用文字或使用語音輸入問題或關鍵字查詢相關情資威脅情報,結果將顯示前 5 個最相關的結果。", elem_classes=["text-xl"])
205
 
206
  with gr.Row(equal_height=True):
207
  with gr.Column(scale=1, min_width=300):