teticio committed on
Commit
95aad66
·
1 Parent(s): 1bcd89b

first commit

Browse files
Files changed (2) hide show
  1. .gitignore +5 -0
  2. app.py +151 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .ipynb_checkpoints
2
+ .vscode
3
+ .venv
4
+ poetry.lock
5
+ pyproject.toml
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from nltk import sent_tokenize
5
+ from transformers import RobertaTokenizer, RobertaForMaskedLM
6
+
7
# True when torch can see a CUDA device; read again later to decide whether
# the tokenized inputs have to be moved to the GPU as well.
cuda = torch.cuda.is_available()

# Tokenizer and masked-language model are loaded once at import time and
# shared by every generation call.
tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
model = RobertaForMaskedLM.from_pretrained("roberta-large")
model = model.cuda() if cuda else model

# Default sampling settings: used as keyword defaults by
# parallel_sequential_generation and as initial slider values in the UI.
max_len, top_k, temperature = 20, 100, 1
burnin, max_iter = 250, 500
19
+
20
+
21
+ # adapted from https://github.com/nyu-dl/bert-gen
22
def generate_step(out,
                  gen_idx,
                  temperature=None,
                  top_k=0,
                  sample=False,
                  return_list=True):
    """ Generate one token id per batch row for position gen_idx

    args:
        - out: model output object whose ``.logits`` attribute is indexed as
          ``[:, gen_idx]`` (i.e. batch_size x seq_len x vocab_size logits)
        - gen_idx (int): sequence position to generate for
        - temperature (float or None): if not None, logits are divided by it.
          A temperature of 0 is treated as greedy decoding (argmax) instead of
          dividing by zero
        - top_k (int): if >0, sample only from the k highest logits; clamped to
          the vocabulary size
        - sample (Bool): if True, sample from full distribution. Overridden by top_k
        - return_list (Bool): if True return a Python list, else the index tensor

    returns:
        - list of ints (or 1-D tensor) holding the chosen id for each batch row
    """
    logits = out.logits[:, gen_idx]

    if temperature is not None:
        if temperature == 0:
            # Zero temperature is the greedy limit; dividing would yield
            # inf/nan logits and break the categorical sampler.
            top_k, sample = 0, False
        else:
            logits = logits / temperature

    # Callers may hand over a float, or a k larger than the vocabulary;
    # topk() needs an int that does not exceed the last dimension.
    top_k = min(int(top_k), logits.size(-1))

    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        # dist.sample() indexes into the top-k slice; map it back to vocab ids.
        idx = kth_idx.gather(dim=1,
                             index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample()
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx
50
+
51
+
52
+ # adapted from https://github.com/nyu-dl/bert-gen
53
def parallel_sequential_generation(seed_text,
                                   seed_end_text,
                                   max_len=max_len,
                                   top_k=top_k,
                                   temperature=temperature,
                                   max_iter=max_iter,
                                   burnin=burnin):
    """ Fill max_len mask tokens between two texts, one random position per step

    The input is seed_text + max_len mask tokens + seed_end_text. At each of
    the max_iter steps one of the masked positions is chosen uniformly at
    random and re-generated from the model's logits for that position.

    args:
        - seed_text (str): text placed before the masked span
        - seed_end_text (str): text placed after the masked span
        - max_len (int): number of mask tokens to insert; values below 1
          return an empty string
        - top_k (int): passed to generate_step once burn-in is over
        - temperature: passed to generate_step on every step
        - max_iter (int): number of re-generation steps
        - burnin (int): during the first burnin steps, sample from the full
          distribution; afterwards sample from the top_k logits if top_k > 0,
          otherwise take the argmax

    returns:
        - str: the decoded tokens at the masked positions, with bos/eos
          tokens removed
    """
    # Defensive casts: string repetition, range() and topk() all require ints,
    # and callers may pass numeric values that arrive as floats.
    max_len = int(max_len)
    max_iter = int(max_iter)
    top_k = int(top_k)
    if max_len < 1:
        # No masked positions means there is nothing to generate.
        return ''

    inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                    return_tensors='pt')
    masked_tokens = np.where(
        inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
    # Index of the first mask token; the masked span is addressed relative
    # to it below.
    seed_len = masked_tokens[0]
    if cuda:
        inp = inp.to('cuda')

    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for step in range(max_iter):
            offset = np.random.randint(0, max_len)
            out = model(**inp)
            in_burnin = step < burnin
            idxs = generate_step(out,
                                 gen_idx=seed_len + offset,
                                 top_k=0 if in_burnin else top_k,
                                 temperature=temperature,
                                 sample=in_burnin)
            # Write the chosen id back so later steps condition on it.
            inp['input_ids'][0][seed_len + offset] = idxs[0]

    tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens]
    keep = ((tokens != tokenizer.eos_token_id)
            & (tokens != tokenizer.bos_token_id))
    return tokenizer.decode(tokens[keep])
88
+
89
+
90
def inbertolate(doc,
                max_len=15,
                top_k=0,
                temperature=None,
                max_iter=300,
                burnin=200):
    """ Pad out a document by generating text after every sentence

    The document is split into paragraphs on newlines and each paragraph into
    sentences with sent_tokenize. After every sentence, text generated by
    parallel_sequential_generation is inserted, conditioned on that sentence
    and the one following it (an empty string after the last sentence of a
    paragraph).

    args:
        - doc (str): text to expand
        - max_len, top_k, temperature, max_iter, burnin: forwarded unchanged
          to parallel_sequential_generation

    returns:
        - str: the expanded text. Each sentence and each generated insert is
          followed by a space, and every paragraph (including blank ones) is
          terminated by a newline
    """
    pieces = []

    for para in doc.split('\n'):
        # Blank paragraphs contain no sentences: keep the line break and skip
        # tokenization. (The check must be made on the raw string; the
        # tokenized paragraph is a list and never equals ''.)
        if para.strip():
            sentences = sent_tokenize(para)
            # Pair each sentence with its successor; the last one is paired
            # with an empty string.
            for current, following in zip(sentences, sentences[1:] + ['']):
                inserted = parallel_sequential_generation(current,
                                                          following,
                                                          max_len=max_len,
                                                          top_k=top_k,
                                                          temperature=temperature,
                                                          burnin=burnin,
                                                          max_iter=max_iter)
                pieces.append(current + ' ' + inserted + ' ')
        pieces.append('\n')

    return ''.join(pieces)
118
+
119
+
120
if __name__ == '__main__':
    # Build the demo page: a title, a tagline and a single Interface that maps
    # the text box plus five sliders onto inbertolate's parameters, in order.
    with gr.Blocks(css='.container') as block:
        gr.Markdown("<h1><center>inBERTolate</center></h1>")
        gr.Markdown(
            "<center>Hit your word count by using BERT to pad out your essays!</center>"
        )

        text_input = gr.Textbox(label="Text", lines=7)
        max_len_slider = gr.Slider(
            label="Maximum length to insert between sentences",
            minimum=1,
            maximum=40,
            step=1,
            value=max_len)
        top_k_slider = gr.Slider(label="Top k",
                                 minimum=0,
                                 maximum=200,
                                 value=top_k)
        temperature_slider = gr.Slider(label="Temperature",
                                       minimum=0,
                                       maximum=2,
                                       value=temperature)
        max_iter_slider = gr.Slider(label="Maximum iterations",
                                    minimum=0,
                                    maximum=1000,
                                    value=max_iter)
        burnin_slider = gr.Slider(label="Burn-in",
                                  minimum=0,
                                  maximum=500,
                                  value=burnin)
        expanded_output = gr.Textbox(label="Expanded text", lines=24)

        gr.Interface(fn=inbertolate,
                     inputs=[
                         text_input,
                         max_len_slider,
                         top_k_slider,
                         temperature_slider,
                         max_iter_slider,
                         burnin_slider,
                     ],
                     outputs=expanded_output)

    # Listen on all interfaces so the app is reachable from outside the host.
    block.launch(server_name='0.0.0.0')