File size: 7,167 Bytes
312011f
 
f40e7fa
 
7bb535d
95aad66
 
 
 
6425ca8
 
 
 
 
 
 
b0cd7be
6425ca8
95aad66
7bb535d
 
9529307
 
 
 
 
95aad66
 
 
 
6425ca8
95aad66
 
 
 
 
6425ca8
 
 
 
 
 
95aad66
 
 
 
3d84bdd
95aad66
6425ca8
 
 
 
 
 
95aad66
 
ec1ff5d
6425ca8
3d84bdd
ec1ff5d
3d84bdd
6425ca8
 
 
3d84bdd
ec1ff5d
6425ca8
 
 
 
95aad66
6425ca8
 
 
95aad66
 
 
6425ca8
 
 
 
 
 
 
 
 
95aad66
6425ca8
 
 
 
 
 
 
95aad66
6425ca8
 
 
95aad66
 
 
 
 
 
9529307
95aad66
 
 
6425ca8
95aad66
6425ca8
95aad66
6425ca8
95aad66
 
 
 
 
 
 
 
 
6425ca8
 
 
 
 
 
f049f5e
6425ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
95aad66
 
 
 
 
 
 
 
 
 
 
 
6425ca8
 
 
 
 
 
 
 
 
95aad66
 
 
 
d1cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95aad66
 
f40e7fa
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# force update

import argparse

import nltk
import torch
import numpy as np
import gradio as gr
from nltk import sent_tokenize

from transformers import (
    RobertaTokenizer,
    RobertaForMaskedLM,
    LogitsProcessorList,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    TypicalLogitsWarper,
)

# Sentence splitting uses NLTK's Punkt models. NLTK >= 3.8.2 loads the
# "punkt_tab" resource instead of the legacy pickled "punkt" one, so fetch
# both; the extra download is harmless on older NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# Use roberta-large when a GPU is available, otherwise fall back to
# roberta-base (presumably to keep CPU inference tractable).
device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained = "roberta-large" if device == "cuda" else "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(pretrained)
model = RobertaForMaskedLM.from_pretrained(pretrained)
model = model.to(device)

# Default generation settings; also used as the initial values of the UI sliders.
max_len = 20        # number of <mask> tokens inserted after each sentence
top_k = 100         # top-k cutoff; 0 disables it
temperature = 1.0   # must be a float: TemperatureLogitsWarper rejects ints
typical_p = 0       # typical-sampling mass; 0 disables it
burnin = 250        # iterations that sample before switching to argmax
max_iter = 500      # total number of resampling iterations


# adapted from https://github.com/nyu-dl/bert-gen
def generate_step(out: object,
                  gen_idx: int,
                  top_k: int = top_k,
                  temperature: float = temperature,
                  typical_p: float = typical_p,
                  sample: bool = False) -> list:
    """ Generate one token for position ``gen_idx`` of every sequence in the batch

    args:
        - out: masked-LM model output whose ``.logits`` are indexed as
          batch_size x seq_len x vocab_size
        - gen_idx (int): position for which to generate
        - top_k (int): if >0, keep only the top k most probable tokens
        - temperature (float): sampling temperature; 0 disables temperature scaling
        - typical_p (float): if >0 use typical sampling (values >= 1 are
          clamped to 0.999, since the warper needs a mass strictly below 1)
        - sample (bool): if True, sample from the warped distribution;
          otherwise take the argmax

    returns:
        - list: batch_size token ids
    """
    logits = out.logits[:, gen_idx]

    # Callers may pass ints or floats (module defaults, UI sliders), but the
    # transformers warpers type-check their arguments: TemperatureLogitsWarper
    # requires a float and TopKLogitsWarper requires an int.
    top_k = int(top_k)

    warpers = LogitsProcessorList()
    if temperature:
        warpers.append(TemperatureLogitsWarper(float(temperature)))
    if top_k > 0:
        warpers.append(TopKLogitsWarper(top_k))
    if typical_p > 0:
        if typical_p >= 1:
            typical_p = 0.999
        warpers.append(TypicalLogitsWarper(typical_p))
    # None is passed for input_ids: none of these warpers look at previous tokens.
    logits = warpers(None, logits)

    if sample:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        next_tokens = torch.argmax(logits, dim=-1)

    return next_tokens.tolist()


# adapted from https://github.com/nyu-dl/bert-gen
def parallel_sequential_generation(seed_text: str,
                                   seed_end_text: str,
                                   max_len: int = max_len,
                                   top_k: int = top_k,
                                   temperature: float = temperature,
                                   typical_p: float = typical_p,
                                   max_iter: int = max_iter,
                                   burnin: int = burnin) -> str:
    """ Generate text consistent with preceding and following text

    Inserts ``max_len`` mask tokens between the two texts. Then, for
    ``max_iter`` iterations, it re-predicts one randomly chosen inserted
    position at a time and writes the prediction back into the input.

    Args:
        - seed_text (str): preceding text
        - seed_end_text (str): following text
        - max_len (int): number of tokens to generate; <= 0 returns ''
        - top_k (int): if >0, keep only the top k most probable words
          (only passed on after burn-in; see the note in the loop)
        - temperature (float): sampling temperature
        - typical_p (float): if >0 use typical sampling
        - max_iter (int): number of resampling iterations
        - burnin (int): during the first ``burnin`` iterations, sample from
          the distribution; afterwards take the argmax

    Returns:
        - string: generated text to insert between seed_text and seed_end_text,
          with special tokens (leftover <mask>, <s>, </s>, <pad>, ...) removed
    """
    if max_len <= 0:
        # Nothing to insert (and there would be no mask positions to index).
        return ''

    inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                    return_tensors='pt')
    # Positions of the inserted masks; they are contiguous, starting at seed_len.
    # NOTE(review): a literal "<mask>" inside seed_text would also match here.
    masked_tokens = np.where(
        inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
    seed_len = masked_tokens[0]
    inp = inp.to(device)

    # Pure inference: no_grad avoids building an autograd graph on every pass.
    with torch.no_grad():
        for ii in range(max_iter):
            kk = np.random.randint(0, max_len)
            # NOTE: as in bert-gen, top-k is disabled while sampling during
            # burn-in and is only passed on once argmax decoding takes over.
            idxs = generate_step(model(**inp),
                                 gen_idx=seed_len + kk,
                                 top_k=top_k if (ii >= burnin) else 0,
                                 temperature=temperature,
                                 typical_p=typical_p,
                                 sample=(ii < burnin))
            inp['input_ids'][0][seed_len + kk] = idxs[0]

    tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens]
    # Drop every special token, not just <s>/</s>. Positions never revisited
    # (e.g. max_iter == 0) still hold <mask>, and full-distribution sampling
    # during burn-in can produce <pad>/<unk>.
    tokens = tokens[np.isin(tokens, tokenizer.all_special_ids, invert=True)]
    return tokenizer.decode(tokens)


def inbertolate(doc: str,
                max_len: int = max_len,
                top_k: int = top_k,
                temperature: float = temperature,
                typical_p: float = typical_p,
                max_iter: int = max_iter,
                burnin: int = burnin) -> str:
    """ Pad out a document, alternating original and generated sentences

    Each line of ``doc`` is treated as a paragraph and split into sentences.
    After every sentence, text generated to fit between it and the next
    sentence (or the end of the paragraph) is inserted.

    Args:
        - doc (str): document text
        - max_len (int): number of tokens to insert between sentences
        - top_k (int): if >0, only sample from the top k most probable words
        - temperature (float): sampling temperature
        - typical_p (float): if >0 use typical sampling
        - max_iter (int): number of iterations in MCMC
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax

    Returns:
        - string: the padded document; every paragraph (including the last)
          ends with a newline
    """
    # UI sliders may deliver floats. The generator needs ints for string
    # repetition, range() and top-k, and a float temperature.
    max_len = int(max_len)
    top_k = int(top_k)
    max_iter = int(max_iter)
    burnin = int(burnin)
    temperature = float(temperature)

    pieces = []
    for para in doc.split('\n'):
        sentences = sent_tokenize(para)
        if not sentences:
            # No sentences (e.g. a blank line): keep the paragraph break.
            pieces.append('\n')
            continue

        # Pair each sentence with its successor; the last one pairs with ''.
        for sentence, next_sentence in zip(sentences, sentences[1:] + ['']):
            pieces.append(sentence + ' ')
            pieces.append(parallel_sequential_generation(
                sentence,
                next_sentence,
                max_len=max_len,
                top_k=top_k,
                temperature=temperature,
                typical_p=typical_p,
                burnin=burnin,
                max_iter=max_iter) + ' ')

        pieces.append('\n')
    return ''.join(pieces)

# Gradio UI. The input components are listed in the same order as the
# parameters of `inbertolate` (doc, max_len, top_k, temperature, typical_p,
# max_iter, burnin).
demo = gr.Interface(
    fn=inbertolate,
    title="inBERTolate",
    description=f"Hit your word count by using BERT ({pretrained}) to pad out your essays!",
    inputs=[
        gr.Textbox(label="Text", lines=10),
        gr.Slider(label="Maximum length to insert between sentences",
                  minimum=1, maximum=40, step=1, value=max_len),
        gr.Slider(label="Top k",
                  minimum=0, maximum=200, value=top_k),
        gr.Slider(label="Temperature",
                  minimum=0, maximum=2, value=temperature),
        gr.Slider(label="Typical p",
                  minimum=0, maximum=1, value=typical_p),
        gr.Slider(label="Maximum iterations",
                  minimum=0, maximum=1000, value=max_iter),
        gr.Slider(label="Burn-in",
                  minimum=0, maximum=500, value=burnin),
    ],
    outputs=gr.Textbox(label="Expanded text", lines=30),
)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Launch the inBERTolate Gradio demo.")
    parser.add_argument('--port', type=int,
                        help="port to serve on (left to Gradio if omitted)")
    # The server name is a host/interface such as "0.0.0.0" or "127.0.0.1",
    # so it must be parsed as a string (type=int rejected every valid value).
    parser.add_argument('--server', type=str,
                        help="host/interface to bind (default: 0.0.0.0)")
    args = parser.parse_args()
    demo.launch(server_name=args.server or '0.0.0.0', server_port=args.port)