sickcell69 committed on
Commit
f99b8f7
1 Parent(s): bbe45db

Upload gradio.py

Browse files
Files changed (1) hide show
  1. gradio.py +52 -0
gradio.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[7]:
5
+
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ from sentence_transformers import SentenceTransformer, util
10
+ import torch
11
+
12
# Load the semantic-search model.
model_checkpoint = "sickcell69/cti-semantic-search-minilm"
model = SentenceTransformer(model_checkpoint)

# Load the labelled corpus; semantic_search looks rows up by position.
data_path = 'labeled_cti_data.json'
data = pd.read_json(data_path)

# Load the precomputed corpus embeddings.
# map_location="cpu" lets a tensor that was saved from a GPU session load on a
# CPU-only host instead of raising while the module is being imported.
# NOTE: torch.load unpickles the file, so corpus_embeddings.pt must be trusted.
embeddings_path = 'corpus_embeddings.pt'
corpus_embeddings = torch.load(embeddings_path, map_location="cpu")
23
+
24
def semantic_search(query):
    """Return the five corpus entries most similar to *query*.

    Args:
        query: Free-text search string entered by the user.

    Returns:
        One line per hit, formatted as ``Score: <score> - Text: <text>`` and
        joined with newlines, or an empty string when *query* is blank.
    """
    # A blank (or missing) query carries nothing to search for; skip the
    # encoder call rather than returning arbitrary matches.
    if not query or not query.strip():
        return ""

    query_embedding = model.encode(query, convert_to_tensor=True)
    # A single query was submitted, so only the first hit list is relevant.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=5)[0]

    lines = []
    for hit in hits:
        # corpus_id is a positional index into the corpus, hence iloc.
        text = " ".join(data.iloc[hit['corpus_id']]['tokens'])
        lines.append(f"Score: {hit['score']:.4f} - Text: {text}")

    return "\n".join(lines)
34
+
35
# Gradio UI: one text input, one text output, backed by semantic_search.
iface = gr.Interface(
    semantic_search,
    "text",
    "text",
    title="語義搜索應用",
    description="輸入一個查詢,然後模型將返回最相似的結果。",
)

if __name__ == "__main__":
    # share=True requests a public share link in addition to the local server.
    # NOTE(review): the original author noted that the page does not come up.
    # The commit metadata names this file gradio.py; if that is the real file
    # name, `import gradio` resolves to this script instead of the library.
    # Confirm and rename the file if so.
    iface.launch(share=True)
46
+
47
+
48
+ # In[ ]:
49
+
50
+
51
+
52
+