Qifan Zhang
committed on
Commit
·
cf575f8
1
Parent(s):
f32101e
add feature 1
Browse files- .gitignore +3 -0
- app.py +53 -0
- output.csv +13 -0
- utils/models.py +16 -0
- utils/similarity.py +25 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
data
|
2 |
+
.idea
|
3 |
+
*.csv
|
app.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from utils.similarity import batch_cos_sim
|
7 |
+
|
8 |
+
|
9 |
+
def read_data(filepath: str) -> Optional[pd.DataFrame]:
    """Load a tabular data file into a DataFrame.

    Args:
        filepath: Path to a ``.csv``, ``.xlsx`` or ``.xls`` file. The
            extension is matched case-insensitively.

    Returns:
        The parsed DataFrame, or ``None`` when ``filepath`` is empty or
        ``None``.

    Raises:
        ValueError: If the file extension is not a supported type.
    """
    if not filepath:
        return None

    lowered = filepath.lower()
    if lowered.endswith(('.xlsx', '.xls')):
        # Excel workbooks are not delimited text, so ``read_csv`` cannot
        # parse them; they have to go through ``read_excel``.
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)

    # ValueError is still an ``Exception``, so broad handlers keep working.
    raise ValueError('File type not supported')
|
19 |
+
|
20 |
+
|
21 |
+
def process(model_name: str,
            prompt: str,
            file=None,
            ):
    """Score an uploaded prompt/response table and export the result.

    Args:
        model_name: Model identifier forwarded to ``batch_cos_sim``.
        prompt: Text from the prompt box. Not used by this function; kept
            so the signature lines up with the three interface inputs.
        file: The uploaded file, either an object exposing its path as
            ``.name`` or a plain path string.

    Returns:
        A ``(markdown_table, csv_path)`` tuple: the scored table rendered
        as Markdown and the path of the CSV file written to disk.

    Raises:
        ValueError: If no file was supplied or its path is empty.
    """
    if file is None:
        # Fail with a readable message instead of an AttributeError on
        # ``None.name`` when the form is submitted without an upload.
        raise ValueError('No input file provided: upload a .csv or .xlsx file')

    # Accept both file-like upload objects (``.name``) and plain path strings.
    filepath = getattr(file, 'name', file)
    df = read_data(filepath)
    if df is None:
        raise ValueError('No input file provided: the uploaded file has no path')

    df = batch_cos_sim(df, model_name)
    path = 'output.csv'
    # utf-8-sig prepends a BOM to the exported CSV.
    df.to_csv(path, index=False, encoding='utf-8-sig')
    return df.to_markdown(), path
|
30 |
+
|
31 |
+
|
32 |
+
# --- Gradio UI wiring -------------------------------------------------------

# Single-line text box holding the model name passed to ``process``.
model_name_input = gr.components.Textbox(
    value='paraphrase-multilingual-MiniLM-L12-v2', lines=1, type="text")

# Multi-line text box, pre-filled with the two column names the scorer needs.
prompt_input = gr.components.Textbox(
    value='prompt,response', lines=10, type="text")

# Output slot for the CSV file path returned by ``process``.
# NOTE(review): ``file_types`` looks like an upload filter, yet it is set on
# the output component while the input uses the bare "file" shortcut — confirm
# whether it was meant for the input.
file_output = gr.components.File(
    label="Output File",
    file_count="single",
    file_types=["", ".", ".csv", ".xls", ".xlsx"],
)

# Three inputs map positionally onto ``process(model_name, prompt, file)``;
# the two outputs map onto its ``(markdown, path)`` return value.
app = gr.Interface(
    fn=process,
    inputs=[model_name_input, prompt_input, "file"],
    outputs=["text", file_output],
)
app.launch()
|
output.csv
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompt,response,originality
|
2 |
+
床单,当空调被,0.6427325010299683
|
3 |
+
床单,保暖,0.5928247570991516
|
4 |
+
床单,绑在树上做成吊床,0.5714011490345001
|
5 |
+
床单,当燃料烧,0.7625655382871628
|
6 |
+
床单,包裹东西,0.41448450088500977
|
7 |
+
床单,裁剪成衣服,0.5791812241077423
|
8 |
+
牙刷,用来刷首饰,0.5138461589813232
|
9 |
+
牙刷,刷鞋,0.5954866111278534
|
10 |
+
牙刷,洗水果,0.6339634656906128
|
11 |
+
牙刷,捅人,0.5337955951690674
|
12 |
+
牙刷,洗马桶,0.5022678673267365
|
13 |
+
牙刷,刷桃子的毛,0.6439318358898163
|
utils/models.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
import numpy as np
|
6 |
+
# Use the GPU when PyTorch reports CUDA is available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
7 |
+
|
8 |
+
|
9 |
+
class SBert:
    """Callable wrapper around a ``SentenceTransformer`` with cached encoding.

    Calling an instance encodes its argument with the wrapped model. Results
    are memoised per instance, so encoding the same input again does not
    re-run the model.
    """

    def __init__(self, path):
        """Load the model identified by ``path`` onto ``DEVICE``.

        Args:
            path: Model name or path handed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(path, device=DEVICE)
        # Built lazily on first call; see ``__call__``.
        self._cached_encode = None

    def __call__(self, x) -> np.ndarray:
        """Return ``self.model.encode(x)``, reusing a cached result if any.

        Args:
            x: Input to encode. It must be hashable because it is used as
                the cache key.

        Returns:
            The value produced by ``self.model.encode(x)``. Repeated calls
            with an equal ``x`` return the same cached object, so callers
            should not mutate it in place.
        """
        cached_encode = getattr(self, '_cached_encode', None)
        if cached_encode is None:
            # The cache lives on the instance rather than on the class-level
            # method. A class-level ``lru_cache`` keys on ``self`` and would
            # keep every instance (and its loaded model) alive.
            cached_encode = lru_cache(maxsize=10000)(self.model.encode)
            self._cached_encode = cached_encode
        return cached_encode(x)
|
utils/similarity.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from sentence_transformers.util import cos_sim
|
3 |
+
|
4 |
+
from utils.models import SBert
|
5 |
+
|
6 |
+
|
7 |
+
def get_cos_sim(model, prompt: str, response: str) -> float:
    """Return the cosine similarity between two encoded texts.

    Args:
        model: Callable that maps a text to its embedding vector.
        prompt: First text to encode.
        response: Second text to encode.

    Returns:
        The similarity of the two embeddings as a plain Python number.
    """
    # Encode in the same order as the arguments: prompt first, then response.
    embeddings = [model(text) for text in (prompt, response)]
    similarity = cos_sim(*embeddings)
    return similarity.item()
|
12 |
+
|
13 |
+
|
14 |
+
def batch_cos_sim(df: pd.DataFrame, model_name) -> pd.DataFrame:
    """Add an ``originality`` column scoring each prompt/response pair.

    Originality is ``1 - cosine_similarity(prompt, response)``, computed row
    by row with an ``SBert`` model built from ``model_name``.

    Args:
        df: Table with ``prompt`` and ``response`` columns. It is modified
            in place: the ``originality`` column is added or overwritten.
        model_name: Model identifier passed to ``SBert``.

    Returns:
        The same ``df`` object, now carrying the ``originality`` column.

    Raises:
        ValueError: If ``prompt`` or ``response`` is missing from ``df``.
    """
    # Validate with a real exception: ``assert`` is stripped under ``-O``.
    missing = [col for col in ('prompt', 'response') if col not in df.columns]
    if missing:
        names = ', '.join(missing)
        raise ValueError(f'Missing required column(s): {names}')

    model = SBert(model_name)
    # A plain comprehension also handles an empty frame (yields an empty
    # column) and avoids building a Series per row as ``apply(axis=1)`` does.
    df['originality'] = [
        1 - get_cos_sim(model, prompt, response)
        for prompt, response in zip(df['prompt'], df['response'])
    ]
    return df
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == '__main__':
    # Manual smoke run: score the bundled example file with the default model.
    _df = pd.read_csv('data/example_1.csv')
    _df_o = batch_cos_sim(
        df=_df,
        model_name='paraphrase-multilingual-MiniLM-L12-v2',
    )
|