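# LM-Critic-style grammaticality critic served as a Streamlit app. The module
# names (critic.perturbations, gpt2_critic) suggest this follows the
# local-optimum criterion of LM-Critic (Yasunaga & Liang, 2021): a sentence is
# judged grammatical iff GPT-2 assigns it a log-probability at least as high
# as that of all its sampled local perturbations.
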
import sys
import torch
import random
import hashlib
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import nltk
nltk.download('punkt')

sys.path.insert(0, '.')
from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
from utils.spacy_tokenizer import spacy_tokenize_gec
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
# model.cuda()  # uncomment to run on GPU
model.cpu()
print(f'Loaded {model_name}')
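
# Per-sentence loss: sum of token-level cross-entropies, i.e. the negative
# log-probability of each sequence under GPT-2. Logits and labels are shifted
# by one position so each token is predicted from its prefix, and padded
# positions are zeroed out via the (shifted) attention mask before summing.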
def get_gpt2_loss(input_ids, attention_mask, labels):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        lm_logits = outputs.logits  # [bsize, seqlen, vocab]
        # Shift so that tokens < t predict token t.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_mask = attention_mask[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        bsize, seqlen = input_ids.size()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(bsize, seqlen - 1)
        loss = (loss * shift_mask).sum(dim=1)  # [bsize,]
        return loss

MAX_LENGTH = 66
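
# Scores a batch of sentences with GPT-2. Each sentence is prefixed with the
# BOS token so the first real token is also predicted; batches whose padded
# length exceeds MAX_LENGTH tokens are rejected (None) to bound compute.
# Returns the (unnormalized) log-probability of each sentence.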
def run_gpt2(sents, cuda=False):
    assert isinstance(sents, list)
    _sents = [tokenizer.bos_token + s for s in sents]
    inputs = tokenizer(_sents, return_tensors="pt", padding=True)
    if inputs['input_ids'].size(1) > MAX_LENGTH:
        return None
    if cuda:
        inputs = {k: v.cuda() for k, v in inputs.items()}
    loss = get_gpt2_loss(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
    logps = -loss.detach().cpu()
    return logps
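
# Character-level-only critic: samples up to n_samples character-edit
# neighbors of the input and accepts the sentence only if it scores at least
# as high as every neighbor. The MD5-derived seed makes the sampled
# perturbations deterministic for a given input sentence.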
def gpt2_critic_char_level_only(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100):
    return_string = []
    if seed == 'auto':
        seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32)  # seed must be in [0, 2**32 - 1]
    if verbose > 1:
        print('seed', seed)
    np.random.seed(seed); random.seed(seed)
    is_good = True
    for _ in range(1):  # single pass; `break` below short-circuits on a better neighbor
        sent_perturbations = get_local_neighbors_char_level(sent, max_n_samples=n_samples)
        if verbose > 1:
            print("#sent_perturbations (char-level)", len(sent_perturbations))
            return_string.append(f"#sent_perturbations (char-level) {len(sent_perturbations)}\n")
        sents = [sent] + list(sent_perturbations)
        if fp16 and cuda:  # mixed precision is only meaningful on GPU
            with torch.cuda.amp.autocast():
                logps = run_gpt2(sents, cuda)
        else:
            logps = run_gpt2(sents, cuda)
        if logps is None:
            if verbose:
                print('Invalid input. Maybe the sentence is too long.')
                return_string.append('Invalid input. Maybe the sentence is too long.\n')
            return None
        best_idx = int(logps.argmax())
        if best_idx != 0:
            is_good = False
            break
    if verbose:
        if is_good:
            print('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Good! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
        else:
            print('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Bad! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
            print('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
            return_string.append('Neighbor sentence with highest log(p): {} (= {:.3f})\n'.format(sents[best_idx], float(logps[best_idx])))
    counter_example = None
    if not is_good:
        counter_example = [sents[best_idx], float(logps[best_idx])]
    return is_good, float(logps[0]), counter_example, return_string
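
# Full critic: combines word-level perturbations (over the spaCy-tokenized
# sentence) with character-level ones, scores the original against all
# neighbors, and judges it grammatical only if it is a local optimum of
# GPT-2 log-probability. Also returns the accumulated log lines for the UI.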
def gpt2_critic(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100, word_level_mode='refine'):
    return_string = []
    if seed == 'auto':
        seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32)  # seed must be in [0, 2**32 - 1]
    if verbose > 1:
        print('seed', seed)
        return_string.append(f'seed {seed}\n')
    np.random.seed(seed); random.seed(seed)
    sent_toked = spacy_tokenize_gec(sent)
    is_good = True
    for _ in range(1):  # single pass; `break` below short-circuits on a better neighbor
        sent_perturbations_w, orig_sent = get_local_neighbors_word_level(sent_toked, max_n_samples=n_samples//2, mode=word_level_mode)
        sent_perturbations_c = get_local_neighbors_char_level(orig_sent, max_n_samples=n_samples//2)
        if verbose > 1:
            print("#sent_perturbations (char-level)", len(sent_perturbations_c))
            return_string.append(f"#sent_perturbations (char-level) {len(sent_perturbations_c)}\n")
            print("#sent_perturbations (word-level)", len(sent_perturbations_w))
            return_string.append(f"#sent_perturbations (word-level) {len(sent_perturbations_w)}\n")
        sents = [orig_sent] + list(sent_perturbations_c.union(sent_perturbations_w))
        if fp16 and cuda:  # mixed precision is only meaningful on GPU
            with torch.cuda.amp.autocast():
                logps = run_gpt2(sents, cuda)
        else:
            logps = run_gpt2(sents, cuda)
        if logps is None:
            if verbose:
                print('Invalid input. Maybe the sentence is too long.')
                return_string.append('Invalid input. Maybe the sentence is too long.\n')
            return None
        best_idx = int(logps.argmax())
        if best_idx != 0:
            is_good = False
            break
    if verbose:
        if is_good:
            print('Good! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Good! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
        else:
            print('Bad! Your sentence log(p) = {:.3f}'.format(float(logps[0])))
            return_string.append('Bad! Your sentence log(p) = {:.3f}\n'.format(float(logps[0])))
            print('Neighbor sentence with highest log(p): {} (= {:.3f})'.format(sents[best_idx], float(logps[best_idx])))
            return_string.append('Neighbor sentence with highest log(p): {} (= {:.3f})\n'.format(sents[best_idx], float(logps[best_idx])))
    counter_example = None
    if not is_good:
        counter_example = [sents[best_idx], float(logps[best_idx])]
    return is_good, float(logps[0]), counter_example, return_string
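
# Streamlit entry point. The critic can also be called directly; a minimal
# sketch with a hypothetical input sentence:
#   result = gpt2_critic("He go to school every day.")
#   if result is not None:
#       is_good, logp, counter_example, log_lines = result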
def main():
    import streamlit as st
    st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
    sent = st.text_input('Enter a sentence:', value="")
    if sent != '':
        st.markdown(f"**Sentence**: {sent}")
        result = gpt2_critic(sent)  # may be None if the input is too long
        st.markdown("**Results:**")
        if result is None:
            st.write('Invalid input. Maybe the sentence is too long.')
        else:
            _, _, _, return_string = result
            st.write(''.join(return_string))  # log lines already end in '\n'

if __name__ == '__main__':
    main()