teticio committed
Commit 6425ca8 · Parent: f6c9818

use LogitWarpers and add typical_p

Files changed (1): app.py (+96, -47)
app.py CHANGED
@@ -3,7 +3,15 @@ import torch
 import numpy as np
 import gradio as gr
 from nltk import sent_tokenize
-from transformers import RobertaTokenizer, RobertaForMaskedLM
+
+from transformers import (
+    RobertaTokenizer,
+    RobertaForMaskedLM,
+    LogitsProcessorList,
+    TopKLogitsWarper,
+    TemperatureLogitsWarper,
+)
+from transformers.generation_logits_process import TypicalLogitsWarper
 
 nltk.download('punkt')
 
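The flat import path on the last added line matches the transformers release used at the time; newer releases expose the same class from transformers.generation.logits_process. A hedged compatibility sketch, assuming only the module path differs between versions:

# Compatibility sketch (assumption: only the module path changed across versions).
try:
    from transformers.generation_logits_process import TypicalLogitsWarper
except ImportError:  # newer transformers releases
    from transformers.generation.logits_process import TypicalLogitsWarper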
@@ -17,53 +25,74 @@ if cuda:
 max_len = 20
 top_k = 100
 temperature = 1
+typical_p = 0
 burnin = 250
 max_iter = 500
 
 
 # adapted from https://github.com/nyu-dl/bert-gen
-def generate_step(out,
-                  gen_idx,
-                  temperature=None,
-                  top_k=0,
-                  sample=False,
-                  return_list=True):
+def generate_step(out: object,
+                  gen_idx: int,
+                  top_k: int = top_k,
+                  temperature: float = temperature,
+                  typical_p: float = typical_p,
+                  sample: bool = False) -> list:
     """ Generate a word from out[gen_idx]
 
     args:
         - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
         - gen_idx (int): location for which to generate
         - top_k (int): if >0, only sample from the top k most probable words
-        - sample (Bool): if True, sample from full distribution. Overridden by top_k
+        - temperature (float): sampling temperature
+        - typical_p (float): if >0 use typical sampling
+        - sample (bool): if True, sample from full distribution.
+
+    returns:
+        - list: batch_size tokens
     """
     logits = out.logits[:, gen_idx]
-    if temperature is not None:
-        logits = logits / temperature
+    logit_warpers = []
     if top_k > 0:
-        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
-        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
-        idx = kth_idx.gather(dim=1,
-                             index=dist.sample().unsqueeze(-1)).squeeze(-1)
-    elif sample:
-        dist = torch.distributions.categorical.Categorical(logits=logits)
-        idx = dist.sample()  # removed superfluous squeeze(-1)
+        logit_warpers += [TopKLogitsWarper(top_k)]
+    if temperature:
+        logit_warpers += [TemperatureLogitsWarper(temperature)]
+    if typical_p > 0:
+        if typical_p >= 1:
+            typical_p = 0.999
+        logit_warpers += [TypicalLogitsWarper(typical_p)]
+    logits = LogitsProcessorList(logit_warpers)(None, logits)
+
+    if sample:
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
     else:
-        idx = torch.argmax(logits, dim=-1)
-    return idx.tolist() if return_list else idx
+        next_tokens = torch.argmax(logits, dim=-1)
+
+    return next_tokens.tolist()
 
 
 # adapted from https://github.com/nyu-dl/bert-gen
-def parallel_sequential_generation(seed_text,
-                                   seed_end_text,
-                                   max_len=max_len,
-                                   top_k=top_k,
-                                   temperature=temperature,
-                                   max_iter=max_iter,
-                                   burnin=burnin):
-    """ Generate for one random position at a timestep
+def parallel_sequential_generation(seed_text: str,
+                                   seed_end_text: str,
+                                   max_len: int = max_len,
+                                   top_k: int = top_k,
+                                   temperature: float = temperature,
+                                   typical_p: float = typical_p,
+                                   max_iter: int = max_iter,
+                                   burnin: int = burnin) -> str:
+    """ Generate text consistent with preceding and following text
 
-    args:
+    Args:
+        - seed_text (str): preceding text
+        - seed_end_text (str): following text
+        - top_k (int): if >0, only sample from the top k most probable words
+        - temperature (float): sampling temperature
+        - typical_p (float): if >0 use typical sampling
+        - max_iter (int): number of iterations in MCMC
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
+
+    Returns:
+        - string: generated text to insert between seed_text and seed_end_text
     """
     inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                     return_tensors='pt')
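For context, the warper chain that generate_step now builds composes like this. A minimal standalone sketch with illustrative values; the import path is the one pinned by this commit (see the compatibility note above):

import torch
from transformers import (LogitsProcessorList, TopKLogitsWarper,
                          TemperatureLogitsWarper)
from transformers.generation_logits_process import TypicalLogitsWarper

logits = torch.randn(1, 50265)       # batch_size x vocab_size (RoBERTa's vocab)
warpers = LogitsProcessorList([
    TopKLogitsWarper(100),           # mask all but the 100 most probable tokens
    TemperatureLogitsWarper(0.9),    # rescale logits by 1 / temperature
    TypicalLogitsWarper(0.95),       # keep the locally typical probability mass
])
warped = warpers(None, logits)       # these warpers never read input_ids
probs = torch.nn.functional.softmax(warped, dim=-1)
token = torch.multinomial(probs, num_samples=1)  # draw one token id per row

The order mirrors generate_step: top-k first, then temperature, then typical sampling.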
@@ -75,12 +104,11 @@ def parallel_sequential_generation(seed_text,
 
     for ii in range(max_iter):
         kk = np.random.randint(0, max_len)
-        out = model(**inp)
-        topk = top_k if (ii >= burnin) else 0
-        idxs = generate_step(out,
+        idxs = generate_step(model(**inp),
                              gen_idx=seed_len + kk,
-                             top_k=topk,
+                             top_k=top_k if (ii >= burnin) else 0,
                              temperature=temperature,
+                             typical_p=typical_p,
                              sample=(ii < burnin))
         inp['input_ids'][0][seed_len + kk] = idxs[0]
 
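The loop keeps bert-gen's burn-in schedule: for the first burnin iterations it samples from the full warped distribution (no top-k cut-off), and afterwards it switches to argmax with top-k enabled. Factored out as an illustrative helper, not code from the commit:

# Illustrative helper: the per-iteration schedule implied by the call above.
def sampling_schedule(ii: int, burnin: int = burnin, top_k: int = top_k) -> dict:
    return {
        'top_k': top_k if ii >= burnin else 0,  # no top-k cut-off during burn-in
        'sample': ii < burnin,                  # sample early, argmax afterwards
    }

With this helper, the call in the loop would read generate_step(model(**inp), gen_idx=seed_len + kk, temperature=temperature, typical_p=typical_p, **sampling_schedule(ii)).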
@@ -90,12 +118,27 @@ def parallel_sequential_generation(seed_text,
     return tokenizer.decode(tokens)
 
 
-def inbertolate(doc,
-                max_len=max_len,
-                top_k=top_k,
-                temperature=temperature,
-                max_iter=max_iter,
-                burnin=burnin):
+def inbertolate(doc: str,
+                max_len: int = max_len,
+                top_k: int = top_k,
+                temperature: float = temperature,
+                typical_p: float = typical_p,
+                max_iter: int = max_iter,
+                burnin: int = burnin):
+    """ Pad out document, generating every other sentence
+
+    Args:
+        - doc (str): document text
+        - max_len (int): number of tokens to insert between sentences
+        - top_k (int): if >0, only sample from the top k most probable words
+        - temperature (float): sampling temperature
+        - typical_p (float): if >0 use typical sampling
+        - max_iter (int): number of iterations in MCMC
+        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
+
+    Returns:
+        - string: expanded document text
+    """
     new_doc = ''
     paras = doc.split('\n')
 
@@ -108,13 +151,15 @@ def inbertolate(doc,
 
         for sentence in range(len(para) - 1):
             new_doc += para[sentence] + ' '
-            new_doc += parallel_sequential_generation(para[sentence],
-                                                      para[sentence + 1],
-                                                      max_len=max_len,
-                                                      top_k=top_k,
-                                                      temperature=temperature,
-                                                      burnin=burnin,
-                                                      max_iter=max_iter) + ' '
+            new_doc += parallel_sequential_generation(
+                para[sentence],
+                para[sentence + 1],
+                max_len=max_len,
+                top_k=top_k,
+                temperature=float(temperature),
+                typical_p=typical_p,
+                burnin=burnin,
+                max_iter=max_iter) + ' '
 
         new_doc += '\n'
     return new_doc
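A quick way to exercise inbertolate without the Gradio UI; the document text and parameter values below are illustrative only:

# Usage sketch (illustrative values): expand a two-sentence paragraph,
# inserting up to 10 generated tokens between consecutive sentences.
doc = "The storm broke at dawn. By noon the streets were dry again."
print(inbertolate(doc, max_len=10, top_k=50, temperature=1.0,
                  typical_p=0.95, max_iter=300, burnin=150))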
@@ -130,7 +175,7 @@ if __name__ == '__main__':
     gr.Interface(
         fn=inbertolate,
         inputs=[
-            gr.Textbox(label="Text", lines=7),
+            gr.Textbox(label="Text", lines=10),
             gr.Slider(label="Maximum length to insert between sentences",
                       minimum=1,
                       maximum=40,
@@ -141,6 +186,10 @@ if __name__ == '__main__':
                       minimum=0,
                       maximum=2,
                       value=temperature),
+            gr.Slider(label="Typical p",
+                      minimum=0,
+                      maximum=1,
+                      value=typical_p),
             gr.Slider(label="Maximum iterations",
                       minimum=0,
                       maximum=1000,
@@ -150,5 +199,5 @@ if __name__ == '__main__':
                       maximum=500,
                       value=burnin),
         ],
-        outputs=gr.Textbox(label="Expanded text", lines=24))
+        outputs=gr.Textbox(label="Expanded text", lines=30))
     block.launch(server_name='0.0.0.0')
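Note on the new slider: typical_p defaults to 0, and generate_step only appends TypicalLogitsWarper when typical_p > 0, so the default behaviour remains plain top-k/temperature sampling. A value at the slider maximum of 1 is clamped to 0.999 inside generate_step, since the warper expects a probability mass strictly below 1.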