AskCTI / app.py
sickcell69
Update app.py
2678b8b verified
raw
history blame
1.53 kB
#!/usr/bin/env python
# coding: utf-8
# In[7]:
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch
# Sentence-embedding model used to encode incoming queries.
# Earlier experiment, kept for reference: "sickcell69/bert-finetuned-ner".
model_checkpoint = "sickcell69/cti-semantic-search-minilm"
model = SentenceTransformer(model_checkpoint)

# Corpus records, read from JSON into a DataFrame.
# NOTE(review): search hits are mapped back to rows by position
# (data.iloc[corpus_id]), so this file presumably holds one row per
# embedding, in the same order as corpus_embeddings.pt -- confirm.
data_path = 'labeled_cti_data.json'
data = pd.read_json(data_path)

# Precomputed corpus embeddings, remapped onto the CPU at load time.
# NOTE(review): torch.load unpickles the file; if it contains only tensors,
# consider weights_only=True (requires torch >= 1.13) -- confirm format first.
embeddings_path = 'corpus_embeddings.pt'
corpus_embeddings = torch.load(
    embeddings_path,
    map_location=torch.device('cpu'),
)
def _row_to_text(row):
    """Render one corpus row as display text.

    Rows with a ``tokens`` field are shown as the tokens joined by single
    spaces. Rows without one, or whose ``tokens`` value is not iterable
    (e.g. NaN), fall back to ``str(row)`` so the hit is still displayed.

    Args:
        row: One record taken from ``data`` via ``data.iloc[...]``.

    Returns:
        The text to display for this hit.
    """
    # Series membership tests the index, i.e. the column names of the record.
    if 'tokens' not in row:
        return str(row)
    tokens = row['tokens']
    if isinstance(tokens, str):
        # Joining a plain string would space-separate its characters.
        return tokens
    try:
        # str() guards against non-string tokens, which " ".join rejects.
        return " ".join(str(token) for token in tokens)
    except TypeError:
        # Missing / non-iterable value (e.g. NaN): show the whole record.
        return str(row)


def semantic_search(query, top_k=3):
    """Return the corpus entries most similar to ``query``.

    Args:
        query: Free-text query entered in the Gradio textbox.
        top_k: Maximum number of hits to return (defaults to 3, the value
            previously hard-coded).

    Returns:
        A list of ``(score, text)`` tuples, one per hit, in the order
        produced by ``util.semantic_search`` (documented as decreasing
        similarity score).
    """
    query_embedding = model.encode(query, convert_to_tensor=True)
    # A single query is searched, so its hits are the first (only) list.
    hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=top_k
    )[0]
    return [
        (hit['score'], _row_to_text(data.iloc[hit['corpus_id']]))
        for hit in hits
    ]
# Single-textbox Gradio UI around semantic_search; its return value is
# shown in a "text" output component.
# Title: "Semantic search app".
# Description: "Enter a query, and the model will return the most similar results."
iface = gr.Interface(
    title="語義搜索應用",
    description="輸入一個查詢,然後模型將返回最相似的結果。",
    fn=semantic_search,
    inputs="text",
    outputs="text",
)
if __name__ == "__main__":
    # A plain local launch (iface.launch()) was used previously.
    # Author's note on this call: "the web page won't come up".
    # NOTE(review): share=True requests a public share link; verify whether
    # it has any effect where this app is hosted (e.g. Hugging Face Spaces).
    public_link = True
    iface.launch(share=public_link)
# In[ ]: