sickcell69 committed on
Commit
cf22cb1
1 Parent(s): a6a1e12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -29
app.py CHANGED
@@ -1,39 +1,66 @@
1
- import gradio as gr
2
- import pandas as pd
3
- from sentence_transformers import SentenceTransformer, util
4
  import torch
 
 
 
 
5
 
6
- # 載入語義搜索模型
7
- model_checkpoint = "sickcell69/cti-semantic-search-minilm"
8
- model = SentenceTransformer(model_checkpoint)
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # 載入數據
11
- data_path = 'labeled_cti_data.json'
12
- data = pd.read_json(data_path)
 
13
 
14
- # 載入嵌入文件
15
- embeddings_path = 'corpus_embeddings.pt'
16
- corpus_embeddings = torch.load(embeddings_path, map_location=torch.device('cpu'))
17
 
18
- def semantic_search(query):
19
- query_embedding = model.encode(query, convert_to_tensor=True)
20
- search_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=5)
21
-
 
 
 
 
22
  results = []
23
- for hit in search_hits[0]:
24
- text = " ".join(data.iloc[hit['corpus_id']]['tokens'])
25
- results.append(f"Score: {hit['score']:.4f} - Text: {text}")
26
-
27
- return "\n".join(results)
 
 
 
 
 
 
 
 
 
28
 
 
29
  iface = gr.Interface(
30
- fn=semantic_search,
31
- inputs="text",
32
- outputs="text",
33
- title="語義搜索應用",
34
- description="輸入一個查詢,然後模型將返回最相似的結果。"
35
  )
36
 
37
- if __name__ == "__main__":
38
- #iface.launch()
39
- iface.launch(share=True) #網頁跑不出來
 
1
+ import json
2
+ import os
 
3
  import torch
4
+ from sentence_transformers import SentenceTransformer
5
+ import faiss
6
+ import numpy as np
7
+ import gradio as gr
8
 
9
def load_or_create_model_and_embeddings(model_name, data_file):
    """Return a sentence-embedding model together with the corpus embeddings.

    If both ``saved_model`` and ``corpus_embeddings.pt`` exist under the
    module-level ``output_dir``, they are loaded from disk. Otherwise the
    model named ``model_name`` is instantiated and the ``'text'`` field of
    every item in the JSON file ``data_file`` is encoded.

    Args:
        model_name: Name or path handed to ``SentenceTransformer`` when no
            saved model is found on disk.
        data_file: Path to a UTF-8 JSON file. Assumed to contain a list of
            objects that each have a ``'text'`` key -- TODO confirm against
            the actual data file.

    Returns:
        A ``(model, embeddings)`` tuple.
    """
    # ``output_dir`` is a module-level setting; it must be defined before
    # this function is called.
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')

    # NOTE(review): nothing in this function writes these two paths, so the
    # cached branch only fires when the files were produced elsewhere.
    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        print("載入已保存的模型和嵌入...")
        model = SentenceTransformer(model_path)
        # Map storages to the CPU so embeddings that were saved on a GPU can
        # still be loaded on a CPU-only host.
        embeddings = torch.load(
            embeddings_path, map_location=torch.device('cpu')
        )
        return model, embeddings

    model = SentenceTransformer(model_name)
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    texts = [item['text'] for item in data]
    embeddings = model.encode(texts, convert_to_tensor=True)
    return model, embeddings
23
 
24
# Configuration
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
data_file = 'labeled_cti_data.json'
output_dir = '.'

# Load the saved model and embeddings, or create them from the data file
model, embeddings = load_or_create_model_and_embeddings(model_name, data_file)

# Build the Faiss index over the corpus embeddings (flat L2 index whose
# dimensionality is taken from the embeddings' second axis)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(
    embeddings.cpu().numpy().astype('float32')
)
36
+
37
_corpus_texts = None  # lazily populated cache of the corpus texts


def _get_corpus_texts():
    """Return the corpus texts, reading ``data_file`` on first use."""
    global _corpus_texts
    if _corpus_texts is None:
        with open(data_file, 'r', encoding='utf-8') as f:
            _corpus_texts = [item['text'] for item in json.load(f)]
    return _corpus_texts


def semantic_search(query, top_k=3):
    """Return the corpus entries closest to ``query`` in the Faiss index.

    Args:
        query: Query string to embed with the module-level ``model``.
        top_k: Maximum number of hits to return.

    Returns:
        A list of dicts with keys ``'text'`` and ``'similarity_score'``, in
        the order returned by ``index.search``.
    """
    # The texts are loaded here rather than read from a global ``texts``,
    # which this module never defines. Assumes the items in ``data_file``
    # are in the same order as the rows added to ``index`` -- TODO confirm
    # when the embeddings come from a saved file.
    corpus_texts = _get_corpus_texts()

    query_vector = model.encode([query], convert_to_tensor=True)
    distances, indices = index.search(
        query_vector.cpu().numpy().astype('float32'), top_k
    )

    results = []
    for distance, idx in zip(distances[0], indices[0]):
        # Faiss pads with -1 when fewer than ``top_k`` neighbours exist;
        # a negative index would otherwise silently pick the last text.
        if idx < 0:
            continue
        results.append({
            'text': corpus_texts[idx],
            # Converts a squared L2 distance to cosine similarity; this
            # presumes unit-normalised embeddings -- verify for the model.
            'similarity_score': 1 - distance / 2,
        })
    return results
47
+
48
def search_and_format(query):
    """Run ``semantic_search`` on ``query`` and render the hits as text.

    Each hit yields a numbered line with its similarity score (four
    decimals), then a line with its text, then a blank line. Returns an
    empty string when there are no hits.
    """
    hits = semantic_search(query)
    pieces = []
    for rank, hit in enumerate(hits, start=1):
        pieces.append(f"{rank}. 相似度分數: {hit['similarity_score']:.4f}\n")
        pieces.append(f" 情一: {hit['text']}\n\n")
    return "".join(pieces)
55
 
56
# Build the Gradio interface: a two-line query box in, a ten-line text box out
query_box = gr.Textbox(lines=2, placeholder="輸入您的搜索查詢...")
results_box = gr.Textbox(lines=10)

iface = gr.Interface(
    fn=search_and_format,
    inputs=query_box,
    outputs=results_box,
    title="語義搜索",
    description="輸入查詢以搜索相關文本。將顯示前3個最相關的結果。",
)

# Start the Gradio interface
iface.launch()