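"""
Evaluate sentence-level critics on paired (good, bad) sentences from
eval_critic/eval_data.jsonl.

Two critics are compared:
  1. A baseline that thresholds each sentence's GPT-2 log probability at the
     mean log probability over the whole evaluation set.
  2. LM-Critic (gpt2_critic from critic.critic), which samples perturbations
     of each sentence (n_samples per call).

For each critic, precision / recall / F0.5 are reported separately for
recognizing good and for recognizing bad sentences.
"""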
import re
import os
import sys
import json
import numpy as np
import editdistance
from tqdm import tqdm
from collections import Counter

sys.path.insert(0, '.')
from utils.text_utils import detokenize_sent
from critic.critic import run_gpt2, gpt2_critic
def load_data():
    data_path = 'eval_critic/eval_data.jsonl'
    good_sents, bad_sents = [], []
    for line in open(data_path):
        obj = json.loads(line)
        good_sents.append(obj['good'])
        bad_sents.append(obj['bad'])
    return good_sents, bad_sents
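# Each line of eval_critic/eval_data.jsonl is a JSON object holding a 'good' and
# a 'bad' version of the same sentence, e.g. (illustrative, not a real entry):
#   {"good": "She likes the book .", "bad": "She like the book ."}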
good_sents, bad_sents = load_data()
def get_logps(sents):
    final = []
    for start in tqdm(range(0, len(sents), 100)):
        sents_sub = sents[start: start+100]
        sents_sub_detok = [detokenize_sent(sent) for sent in sents_sub]
        logps = run_gpt2(sents_sub_detok)
        assert logps is not None
        for i in range(len(sents_sub)):
            final.append({'sent': sents_sub[i], 'sent_detok': sents_sub_detok[i], 'logp': float(logps[i])})
    return final
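# Sanity check before the actual critics: for each (good, bad) pair, is
# log p(bad) < log p(good) under GPT-2? This is a paired comparison and needs
# no threshold; the critics below must make an absolute call for each sentence.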
def evaluate_logp():
    """
    Check whether log p(bad_sent) < log p(good_sent)
    """
    good_logps = get_logps(good_sents)
    bad_logps = get_logps(bad_sents)
    accs = []
    for good, bad in zip(good_logps, bad_logps):
        accs.append(int(bad['logp'] < good['logp']))
    avg_acc = float(sum(accs)) / len(accs)
    print(f'log p(bad) < log p(good)? {sum(accs)} / {len(accs)} = {avg_acc:.3f}')
    return good_logps, bad_logps
good_logps, bad_logps = evaluate_logp()
# log p(bad) < log p(good)? 555 / 586 = 0.947
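# compute_metrics reports precision / recall / F0.5 for each class:
#   good precision = (good sentences labeled good) / (all sentences labeled good)
#   good recall    = (good sentences labeled good) / (all good sentences)
# and symmetrically for the bad class; F0.5 weights precision twice as much as recall.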
def compute_metrics(good_accs, bad_accs):
    goodP = float(sum(good_accs)) / (len(bad_accs) - sum(bad_accs) + sum(good_accs))
    goodR = float(sum(good_accs)) / len(good_accs)
    goodF05 = (1 + 0.5**2) * float(goodP * goodR) / ((0.5**2 * goodP) + goodR)
    badP = float(sum(bad_accs)) / (len(good_accs) - sum(good_accs) + sum(bad_accs))
    badR = float(sum(bad_accs)) / len(bad_accs)
    badF05 = (1 + 0.5**2) * float(badP * badR) / ((0.5**2 * badP) + badR)
    print(f' Good precision = {len(bad_accs) - sum(bad_accs) + sum(good_accs) and sum(good_accs)} / {len(bad_accs) - sum(bad_accs) + sum(good_accs)} = {goodP:.3f}')
    print(f' Good recall = {sum(good_accs)} / {len(good_accs)} = {goodR:.3f}')
    print(f' Good F0.5 = {goodF05:.3f}')
    print(f' Bad precision = {sum(bad_accs)} / {len(good_accs) - sum(good_accs) + sum(bad_accs)} = {badP:.3f}')
    print(f' Bad recall = {sum(bad_accs)} / {len(bad_accs)} = {badR:.3f}')
    print(f' Bad F0.5 = {badF05:.3f}')
    return {'goodP': goodP, 'goodR': goodR, 'goodF05': goodF05, 'badP': badP, 'badR': badR, 'badF05': badF05}
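# Worked F0.5 example, using the baseline-critic numbers reported below:
#   good precision P = 365/668 ≈ 0.546, good recall R = 365/586 ≈ 0.623
#   F0.5 = 1.25 * P * R / (0.25 * P + R) ≈ 0.560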
def evaluate_baseline_critic():
    # Baseline: a single global threshold at the mean GPT-2 log probability
    # over all (good + bad) sentences; above the threshold => predicted good.
    threshold = np.mean([elm['logp'] for elm in good_logps + bad_logps])
    good_accs, bad_accs = [], []
    for obj in good_logps:
        pred = int(obj['logp'] > threshold)
        good_accs.append(pred == 1)
    for obj in bad_logps:
        pred = int(obj['logp'] > threshold)
        bad_accs.append(pred == 0)
    print('\nBaseline critic:')
    stats = compute_metrics(good_accs, bad_accs)
    json.dump(stats, open('baseline_critic.stats.json', 'w'), indent=2)
evaluate_baseline_critic()
# Baseline critic:
#  Good precision = 365 / 668 = 0.546
#  Good recall = 365 / 586 = 0.623
#  Good F0.5 = 0.560
#  Bad precision = 283 / 504 = 0.562
#  Bad recall = 283 / 586 = 0.483
#  Bad F0.5 = 0.544
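# LM-Critic below differs from the baseline in that it uses no global threshold:
# gpt2_critic (critic/critic.py) samples perturbations of the input sentence
# (n_samples of them) and, roughly, accepts the sentence only if it is a local
# optimum of the LM probability among those perturbations.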
def evaluate_LM_Critic():
    good_accs, bad_accs = [], []
    for obj in tqdm(good_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        pred = int(res[0])  # res[0] is the critic's verdict: 1 = judged good
        good_accs.append(pred == 1)
    for obj in tqdm(bad_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        pred = int(res[0])
        bad_accs.append(pred == 0)
    print('\nLM-Critic:')
    stats = compute_metrics(good_accs, bad_accs)
    json.dump(stats, open('lm_critic.stats.json', 'w'), indent=2)
evaluate_LM_Critic()
# LM-Critic (results vary slightly across runs due to the randomness of
# sampling and small variations in the GPT-2 scores):
#  Good precision = 446 / 654 = 0.682
#  Good recall = 446 / 586 = 0.761
#  Good F0.5 = 0.696
#  Bad precision = 378 / 518 = 0.730
#  Bad recall = 378 / 586 = 0.645
#  Bad F0.5 = 0.711
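# Example: judging a single sentence with the critic (same call as above;
# the input sentence here is illustrative):
#   res = gpt2_critic('She likes the book.', verbose=0, seed=1, n_samples=100, word_level_mode='refine')
#   print('good' if res[0] else 'bad')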