File size: 1,863 Bytes
edd4815 d654474 dd2409d d654474 dd2409d 8d09ff7 d654474 dd2409d d654474 0e97d35 dd2409d d654474 0e97d35 d654474 dd2409d 8d09ff7 0e97d35 d654474 dd2409d 0e97d35 8d09ff7 dd2409d 0e97d35 613e689 0e97d35 d654474 8d09ff7 0e97d35 d654474 edd4815 d654474 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from typing import List
import pandas as pd
from sentence_transformers.util import cos_sim
from utils.models import ModelWithPooling
def p0_originality(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
    """
    Score each row's originality as ``1 - cos_sim(prompt, response)``.

    The score is written to a new ``originality`` column on ``df`` itself
    (the input frame is modified in place) and that same frame is returned.

    :param df: frame that must contain ``prompt`` and ``response`` columns
    :param model_name: name handed to ``ModelWithPooling``
    :param pooling: pooling option forwarded to every model call
    :return: ``df`` with the added ``originality`` column
    """
    for required in ('prompt', 'response'):
        assert required in df.columns

    # Build the model once and reuse it for every row.
    encoder = ModelWithPooling(model_name)

    def row_originality(row) -> float:
        # Prompt and response are encoded by two separate model calls.
        prompt_emb = encoder(text=row['prompt'], pooling=pooling)
        response_emb = encoder(text=row['response'], pooling=pooling)
        return 1 - cos_sim(prompt_emb, response_emb).item()

    df['originality'] = df.apply(row_originality, axis=1)
    return df
def p1_flexibility(df: pd.DataFrame, model_name: str, pooling: str) -> pd.DataFrame:
    """
    Score each ``(id, prompt)`` group's flexibility.

    Within a group, the responses are encoded in row order and the score is
    the sum of ``1 - cos_sim`` over every pair of consecutive responses.
    A group with a single response therefore scores ``0``.

    :param df: frame that must contain ``prompt``, ``response`` and ``id`` columns
    :param model_name: name handed to ``ModelWithPooling``
    :param pooling: pooling option forwarded to every model call
    :return: new frame with one row per group and columns ``id``, ``prompt``,
        ``flexibility``
    """
    for required in ('prompt', 'response', 'id'):
        assert required in df.columns

    # Build the model once and reuse it for every group.
    encoder = ModelWithPooling(model_name)

    def consecutive_distance(group_responses: List[str]) -> float:
        embeddings = [encoder(text=text, pooling=pooling) for text in group_responses]
        total = 0
        # Pair each response with its immediate successor.
        for left, right in zip(embeddings, embeddings[1:]):
            total += 1 - cos_sim(left, right).item()
        return total

    grouped = df.groupby(by=['id', 'prompt'])
    aggregated = grouped.agg(
        {'id': 'first', 'prompt': 'first', 'response': consecutive_distance}
    )
    renamed = aggregated.rename(columns={'response': 'flexibility'})
    # The group keys are already present as columns, so the index is dropped.
    return renamed.reset_index(drop=True)
if __name__ == '__main__':
    # Manual smoke run: score the example CSV with both metrics.
    _model_name = 'paraphrase-multilingual-MiniLM-L12-v2'
    # `pooling` is a required argument of both scoring functions.
    # NOTE(review): 'mean' is an assumed pooling option — confirm it against
    # the values ModelWithPooling actually accepts.
    _pooling = 'mean'
    _df_input = pd.read_csv('data/tmp/example_3.csv')
    # p0_originality adds an `originality` column to _df_input in place.
    _df_0 = p0_originality(_df_input, _model_name, _pooling)
    _df_1 = p1_flexibility(_df_input, _model_name, _pooling)
|