# LM-critic demo: judge a sentence's grammaticality by checking whether GPT-2
# assigns it a higher log-probability than its local (char/word-level) perturbations.
import sys
import torch
import random
import hashlib
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import nltk
nltk.download('punkt')

sys.path.insert(0, '.')
from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
from utils.spacy_tokenizer import spacy_tokenize_gec
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS for batch padding
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
#model.cuda()
model.cpu()
print(f'Loaded {model_name}')
def get_gpt2_loss(input_ids, attention_mask, labels):
    """Return the summed token-level negative log-likelihood for each sequence."""
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        lm_logits = outputs[1]  # [bsize, seqlen, vocab]
        if labels is not None:
            # Shift so that each position predicts the next token
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_mask = attention_mask[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            bsize, seqlen = input_ids.size()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(bsize, seqlen - 1)
            loss = (loss * shift_mask).sum(dim=1)  # [bsize]; the mask zeroes out padding positions
            return loss
MAX_LENGTH = 66  # reject batches whose tokenized length (incl. BOS) exceeds this
def run_gpt2(sents, cuda=False):
    """Score a list of sentences; returns log p(sent) per sentence, or None if too long."""
    assert isinstance(sents, list)
    _sents = [tokenizer.bos_token + s for s in sents]
    inputs = tokenizer(_sents, return_tensors="pt", padding=True)
    if inputs['input_ids'].size(1) > MAX_LENGTH:
        return None
    if cuda:
        inputs = {k: v.cuda() for k, v in inputs.items()}
    loss = get_gpt2_loss(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
    logps = -loss.detach().cpu()  # log-probability = negative summed NLL
    return logps
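# Hypothetical usage: a higher log(p) means GPT-2 finds the sentence more fluent.
#   logps = run_gpt2(['I like apples.', 'I likes apples.'])
#   # logps[0] > logps[1] is expected, though not guaranteed, for this pair.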
def gpt2_critic_char_level_only(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100):
    return_string = []
    if seed == 'auto':
        seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32)  # seed must be in [0, 2**32 - 1]
    if verbose > 1:
        print('seed', seed)
    np.random.seed(seed); random.seed(seed)
    is_good = True
    for _ in range(1):  # one-shot loop: 'break' exits as soon as a perturbation wins
        sent_perturbations = get_local_neighbors_char_level(sent, max_n_samples=n_samples)
        if verbose > 1:
            print('#sent_perturbations (char-level)', len(sent_perturbations))
            return_string.append(f'#sent_perturbations (char-level) {len(sent_perturbations)}\n')
        sents = [sent] + list(sent_perturbations)
        if fp16:
            with torch.cuda.amp.autocast():
                logps = run_gpt2(sents, cuda)
        else:
            logps = run_gpt2(sents, cuda)
        if logps is None:
            if verbose:
                print('Invalid input. Maybe the sentence is too long.')
                return_string.append('Invalid input. Maybe the sentence is too long.\n')
            return None
        best_idx = int(logps.argmax())
        if best_idx != 0:
            is_good = False
            break
    if verbose:
        if is_good:
            print('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Good! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
        else:
            print('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Bad! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
            print('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
            return_string.append('Neighbor sentence with highest log(p): {} (= {:.3f})\n'.format(sents[best_idx], float(logps[best_idx])))
    counter_example = None
    if not is_good:
        counter_example = [sents[best_idx], float(logps[best_idx])]
    return is_good, float(logps[0]), counter_example
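# Hypothetical usage of the char-level-only critic:
#   is_good, logp, counter_example = gpt2_critic_char_level_only('The cat sit on the mat.')
#   # is_good is False when some character-level edit of the input scores a higher
#   # GPT-2 log(p) than the input itself; counter_example then holds that edit.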
def gpt2_critic(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100, word_level_mode='refine'):
    return_string = []
    if seed == 'auto':
        seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32)  # seed must be in [0, 2**32 - 1]
    if verbose > 1:
        print('seed', seed)
        return_string.append(f'seed {seed}\n')
    np.random.seed(seed); random.seed(seed)
    sent_toked = spacy_tokenize_gec(sent)
    is_good = True
    for _ in range(1):  # one-shot loop: 'break' exits as soon as a perturbation wins
        # Split the perturbation budget between word-level and char-level neighbors
        sent_perturbations_w, orig_sent = get_local_neighbors_word_level(sent_toked, max_n_samples=n_samples//2, mode=word_level_mode)
        sent_perturbations_c = get_local_neighbors_char_level(orig_sent, max_n_samples=n_samples//2)
        if verbose > 1:
            print('#sent_perturbations (char-level)', len(sent_perturbations_c))
            return_string.append(f'#sent_perturbations (char-level) {len(sent_perturbations_c)}\n')
            print('#sent_perturbations (word-level)', len(sent_perturbations_w))
            return_string.append(f'#sent_perturbations (word-level) {len(sent_perturbations_w)}\n')
        sents = [orig_sent] + list(sent_perturbations_c.union(sent_perturbations_w))
        if fp16:
            with torch.cuda.amp.autocast():
                logps = run_gpt2(sents, cuda)
        else:
            logps = run_gpt2(sents, cuda)
        if logps is None:
            if verbose:
                print('Invalid input. Maybe the sentence is too long.')
                return_string.append('Invalid input. Maybe the sentence is too long.\n')
            return None
        best_idx = int(logps.argmax())
        if best_idx != 0:
            is_good = False
            break
    if verbose:
        if is_good:
            print('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Good! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
        else:
            print('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Bad! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
            print('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
            return_string.append('Neighbor sentence with highest log(p): {} (= {:.3f})\n'.format(sents[best_idx], float(logps[best_idx])))
    counter_example = None
    if not is_good:
        counter_example = [sents[best_idx], float(logps[best_idx])]
    return is_good, float(logps[0]), counter_example, return_string
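# Hypothetical usage of the full critic (word- and char-level perturbations):
#   result = gpt2_critic('She like apples.')
#   if result is not None:  # None means the input was too long to score
#       is_good, logp, counter_example, log_lines = result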
def main():
    import streamlit as st
    st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
    sent = st.text_input('Enter a sentence:', value="")
    if sent != '':
        st.markdown(f"**Sentence**: {sent}")
        result = gpt2_critic(sent)
        st.markdown("**Results:**")
        if result is None:  # gpt2_critic returns None for over-long inputs
            st.write('Invalid input. Maybe the sentence is too long.')
        else:
            _, _, _, return_string = result
            st.write('\n'.join(return_string))

if __name__ == '__main__':
    main()