pszemraj commited on
Commit
1721ded
·
1 Parent(s): d735a1b

Upload _ai_msgbot_gpt_j_6b_8bit_with_hub.py

Browse files
Files changed (1) hide show
  1. _ai_msgbot_gpt_j_6b_8bit_with_hub.py +710 -0
_ai_msgbot_gpt_j_6b_8bit_with_hub.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """-ai-msgbot-gpt-j-6b-8bit-with-hub.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/gist/pszemraj/e49c60aafe04acc52fcfdd1baefe12e4/-ai-msgbot-gpt-j-6b-8bit-with-hub.ipynb
8
+
9
+ # <center> ai-msgbot - conversational 6B GPT-J 8bit demo
10
+
11
+
12
+ > This notebook demos interaction with a 6B GPT-J finetuned for dialogue via methods in [ai-msgbot](https://github.com/pszemraj/ai-msgbot)
13
+
14
+
15
+ By [Peter](https://github.com/pszemraj). This notebook and `ai-msgbot` are [licensed under creative commons](https://github.com/pszemraj/ai-msgbot/blob/main/LICENSE). Models trained on given datasets are subject to those datasets' licenses.
16
+
17
+
18
+ ## usage
19
+
20
+ 1. select the checkpoint of the model to use for generation in the `model_checkpoint` dropdown
21
+ 2. Run all cells to load everything
22
+ 3. adjust the prompt fields at the bottom of the notebook to whatever you want, see how AI responds.
23
+
24
+
25
+ A fine-tuning example etc. will come _eventually_
26
+
27
+
28
+ ---
29
+
30
+ # setup
31
+ """
32
+
33
+ #@markdown setup logging
34
+ import logging
35
+ from pathlib import Path
36
+ for handler in logging.root.handlers[:]:
37
+ logging.root.removeHandler(handler)
38
+
39
+ das_logfile = Path.cwd() / "8bit_inference.log"
40
+
41
+ logging.basicConfig(
42
+ level=logging.INFO,
43
+ filename=das_logfile,
44
+ filemode='w',
45
+ format="%(asctime)s %(levelname)s %(message)s",
46
+ datefmt="%m/%d/%Y %I:%M:%S",
47
+ )
48
+
49
+ #@markdown add auto-Colab formatting with `IPython.display`
50
+ from IPython.display import HTML, display
51
+ # colab formatting
52
def set_css():
    """Emit a small ``<style>`` element that makes ``<pre>`` output wrap.

    Takes no arguments and returns nothing; the only effect is a call to
    ``display`` with an ``HTML`` object holding the CSS rule.
    """
    wrap_rule = HTML(
        """
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  """
    )
    display(wrap_rule)
64
+
65
+ get_ipython().events.register("pre_run_cell", set_css)
66
+
67
+ from pathlib import Path
68
+
69
+ """### GPU info"""
70
+
71
+ !nvidia-smi
72
+
73
+ """## install and import
74
+
75
+ _this notebook uses a specific version of `torch` which can take a while to install._
76
+ """
77
+
78
+ !pip install transformers==4.24.0 -q
79
+ !pip install bitsandbytes==0.32.2 -q
80
+ !pip install datasets==1.16.1 -q
81
+ !pip install torch==1.11 -q
82
+ !pip install accelerate==0.12.0 -q
83
+ !pip install pysbd==0.3.4 -q
84
+
85
+ # Commented out IPython magic to ensure Python compatibility.
86
+ # %%capture
87
+ # import transformers
88
+ #
89
+ # import pandas as pd
90
+ #
91
+ # import torch
92
+ # import torch.nn.functional as F
93
+ # from torch import nn
94
+ # from torch.cuda.amp import custom_fwd, custom_bwd
95
+ #
96
+ # import bitsandbytes as bnb
97
+ # from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
98
+ #
99
+ # from tqdm.auto import tqdm
100
+
101
+ #@markdown utils
102
+ from transformers.utils.logging import set_verbosity
103
+
104
+ set_verbosity(40)
105
+
106
+ import warnings
107
+ # ignore hf pipeline complaints
108
+ warnings.filterwarnings("ignore", category=UserWarning, module='transformers')
109
+
110
+ """## Converting the model to 8 bits
111
+
112
+ """
113
+
114
+ #@title define 8bit classes
115
+
116
+ #@markdown - bitsandbytes lib
117
class FrozenBNBLinear(nn.Module):
    """Linear layer that keeps its weight in quantized form and never trains it.

    The quantized ``weight`` together with its ``absmax`` and ``code``
    tensors are stored as buffers with gradients disabled. ``bias`` is kept
    as given (an ``nn.Parameter`` or ``None``). An optional ``adapter``
    module can be attached after construction; when it is truthy its output
    is added in place to the layer output.
    """

    def __init__(self, weight, absmax, code, bias=None):
        assert bias is None or isinstance(bias, nn.Parameter)
        super().__init__()
        # weight is laid out as (out_features, in_features), like nn.Linear
        self.out_features, self.in_features = weight.shape
        for buffer_name, tensor in (
            ("weight", weight),
            ("absmax", absmax),
            ("code", code),
        ):
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Dequantize the weight, apply the linear map, then add the adapter."""
        result = DequantizeAndLinear.apply(
            input, self.weight, self.absmax, self.code, self.bias
        )
        if self.adapter:
            result += self.adapter(input)
        return result

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen layer by quantizing the weight of ``linear``."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(linear.weight)
        return cls(quantized, absmax, code, linear.bias)

    def __repr__(self):
        return f"{type(self).__name__}({self.in_features}, {self.out_features})"
143
+
144
+
145
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function: dequantize the weight, then apply ``F.linear``.

    The backward pass returns gradients for ``input`` and (when a bias was
    supplied) for ``bias``; the quantized weight, ``absmax`` and ``code``
    always receive ``None``.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input: torch.Tensor,
        weights_quantized: torch.ByteTensor,
        absmax: torch.FloatTensor,
        code: torch.FloatTensor,
        bias: torch.FloatTensor,
    ):
        # Only the quantized tensors are saved; the dense weight is rebuilt
        # again in backward instead of being kept alive.
        ctx._has_bias = bias is not None
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        dense_weight = dequantize_blockwise(
            weights_quantized, absmax=absmax, code=code
        )
        return F.linear(input, dense_weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # The three quantization tensors (positions 1..3) must not need grads.
        assert not any(ctx.needs_input_grad[1:4])
        _, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        dense_weight = dequantize_blockwise(
            weights_quantized, absmax=absmax, code=code
        )
        grad_input = grad_output @ dense_weight
        if ctx._has_bias:
            # collapse all leading dims and sum over them
            grad_bias = grad_output.flatten(0, -2).sum(dim=0)
        else:
            grad_bias = None
        return grad_input, None, None, None, grad_bias
175
+
176
+
177
class FrozenBNBEmbedding(nn.Module):
    """Embedding table stored in quantized form and never trained.

    ``weight``, ``absmax`` and ``code`` are registered as buffers with
    gradients disabled. An optional ``adapter`` module can be attached after
    construction; when it is truthy its output is added in place to the
    embedding output.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        for buffer_name, tensor in (
            ("weight", weight),
            ("absmax", absmax),
            ("code", code),
        ):
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table; ``kwargs`` go to ``F.embedding``."""
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            dense_table = dequantize_blockwise(
                self.weight, absmax=self.absmax, code=self.code
            )
            result = F.embedding(input, dense_table, **kwargs)
        if self.adapter:
            result += self.adapter(input)
        return result

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen table by quantizing the weight of ``embedding``."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(embedding.weight)
        return cls(quantized, absmax, code)

    def __repr__(self):
        return f"{type(self).__name__}({self.num_embeddings}, {self.embedding_dim})"
204
+
205
+
206
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2**20):
    """Quantize ``matrix`` with ``quantize_blockwise`` one chunk at a time.

    The tensor is flattened and processed in slices of ``chunk_size``
    elements; each slice is cloned before quantization. The ``code`` returned
    by one call is passed into the next call.

    Args:
        matrix: tensor to quantize; must support ``view(-1)``.
        chunk_size: elements per chunk; must be a multiple of 4096
            (presumably the quantizer's block size - TODO confirm).

    Returns:
        ``(matrix_i8, (absmax, code))`` where ``matrix_i8`` has the shape of
        ``matrix`` and ``absmax`` is the concatenation of the per-chunk
        ``absmax`` tensors.
    """
    assert chunk_size % 4096 == 0
    code = None
    quantized_parts, absmax_parts = [], []
    flat = matrix.view(-1)
    n_chunks = (matrix.numel() - 1) // chunk_size + 1
    for index in range(n_chunks):
        start = index * chunk_size
        piece = flat[start : start + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)
    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (torch.cat(absmax_parts), code)
222
+
223
+
224
def convert_to_int8(model):
    """Swap every ``nn.Linear`` / ``nn.Embedding`` child for a frozen 8-bit stand-in.

    Walks all modules of ``model`` and replaces matching direct children in
    place. The replacements are created with all-zero ``uint8`` weights,
    zero ``absmax`` (one entry per 4096 weight elements, rounded up) and a
    zero 256-entry ``code``; real values are presumably loaded afterwards
    from a checkpoint - TODO confirm. A linear layer's existing ``bias`` is
    reused. Each replaced linear layer is printed.
    """
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                placeholder = torch.zeros(
                    child.out_features, child.in_features, dtype=torch.uint8
                )
                n_blocks = (child.weight.numel() - 1) // 4096 + 1
                replacement = FrozenBNBLinear(
                    weight=placeholder,
                    absmax=torch.zeros(n_blocks),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                placeholder = torch.zeros(
                    child.num_embeddings, child.embedding_dim, dtype=torch.uint8
                )
                n_blocks = (child.weight.numel() - 1) // 4096 + 1
                replacement = FrozenBNBEmbedding(
                    weight=placeholder,
                    absmax=torch.zeros(n_blocks),
                    code=torch.zeros(256),
                )
            else:
                continue
            setattr(parent, child_name, replacement)
254
+
255
+ #@markdown Patch GPT-J before loading:
256
+
257
+
258
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are converted to 8-bit."""

    def __init__(self, config):
        super().__init__(config)

        # build the stock block first, then swap its layers in place
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)
264
+
265
+
266
class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J base model that converts its own layers to 8-bit after construction."""

    def __init__(self, config):
        super().__init__(config)
        # the stock model is built first; conversion then replaces layers in place
        convert_to_int8(model=self)
270
+
271
+
272
class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal-LM model that converts its own layers to 8-bit after construction."""

    def __init__(self, config):
        super().__init__(config)
        # the stock model is built first; conversion then replaces layers in place
        convert_to_int8(model=self)
276
+
277
+
278
+ transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
279
+
280
+ # Commented out IPython magic to ensure Python compatibility.
281
+ # %%capture
282
+ # #@markdown `add_adapters()`
283
+ #
284
+ # def add_adapters(model, adapter_dim=4, p = 0.1):
285
+ # assert adapter_dim > 0
286
+ #
287
+ # for name, module in model.named_modules():
288
+ # if isinstance(module, FrozenBNBLinear):
289
+ # if "attn" in name or "mlp" in name or "head" in name:
290
+ # print("Adding adapter to", name)
291
+ # module.adapter = nn.Sequential(
292
+ # nn.Linear(module.in_features, adapter_dim, bias=False),
293
+ # nn.Dropout(p=p),
294
+ # nn.Linear(adapter_dim, module.out_features, bias=False),
295
+ # )
296
+ # print("Initializing", name)
297
+ # nn.init.zeros_(module.adapter[2].weight)
298
+ #
299
+ # else:
300
+ # print("Not adding adapter to", name)
301
+ # elif isinstance(module, FrozenBNBEmbedding):
302
+ # print("Adding adapter to", name)
303
+ # module.adapter = nn.Sequential(
304
+ # nn.Embedding(module.num_embeddings, adapter_dim),
305
+ # nn.Dropout(p=p),
306
+ # nn.Linear(adapter_dim, module.embedding_dim, bias=False),
307
+ # )
308
+ # print("Initializing", name)
309
+ # nn.init.zeros_(module.adapter[2].weight)
310
+ #
311
+
312
+ #@markdown set up config
313
# Model config comes from the 8-bit checkpoint repo, the tokenizer from the
# original GPT-J repo.
config = transformers.GPTJConfig.from_pretrained("hivemind/gpt-j-6B-8bit")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Reuse the end-of-sequence token for padding on both the config and tokenizer.
config.pad_token_id = config.eos_token_id
# `pad_token` must be the token *string*; assigning the integer id here
# (as the previous code did) leaves the tokenizer with a non-string pad token.
tokenizer.pad_token = tokenizer.eos_token
317
+
318
+ """# load model
319
+
320
+ """
321
+
322
+ from contextlib import contextmanager
323
+ import sys, os, gc
324
+ import logging
325
+ from tqdm.auto import tqdm
326
+ #@markdown define `load_8bit_from_hub()`
327
+
328
@contextmanager
def suppress_stdout():
    """Context manager that points ``sys.stdout`` at ``os.devnull`` for its body.

    The previous ``sys.stdout`` object is put back on exit, including when
    the body raises.
    """
    with open(os.devnull, "w") as sink:
        saved_stdout, sys.stdout = sys.stdout, sink
        try:
            yield
        finally:
            sys.stdout = saved_stdout
337
+
338
def load_8bit_from_hub(model_id: str, **kwargs):
    """Load an 8-bit GPT-J checkpoint, attach adapters and move it to a device.

    Steps (one progress-bar tick each):
      1. ``GPTJForCausalLM.from_pretrained(model_id, device_map="auto",
         low_cpu_mem_usage=True, **kwargs)``
      2. ``add_adapters(model)``
      3. ``model.to("cuda")`` when CUDA is available, otherwise ``"cpu"``

    Args:
        model_id: checkpoint identifier forwarded to ``from_pretrained``.
        **kwargs: passed through unchanged to ``from_pretrained``.

    Returns:
        The loaded model.

    NOTE(review): ``add_adapters`` must already be defined in the session; in
    this exported file its definition only appears inside a commented-out
    ``%%capture`` cell.
    """
    # the context manager guarantees the bar is closed, also on failure
    with tqdm(desc="instantiating model..", total=3) as pbar:
        # loading and adapter setup print per-layer messages; keep them quiet
        with suppress_stdout():
            gc.collect()
            model = GPTJForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
                **kwargs,
            )
            pbar.update()
            add_adapters(model)
            pbar.update()
        # `-1` is a pipeline-style device index, not a valid `.to()` target;
        # fall back to the "cpu" device string instead.
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(target_device)
        pbar.update()
    return model
353
+
354
+ #@title <font color="orange"> Select Model to Load </font>
355
+ model_name = "ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps" #@param ["ethzanalytics/gpt-j-8bit-KILT_WoW_10k_steps", "ethzanalytics/gpt-j-8bit-daily_dialogues", "ethzanalytics/gpt-j-6B-8bit-sharded"]
356
+
357
+ # load_8bit_from_hub() is a wrapper around GPTJForCausalLM.from_pretrained() and will
357
+ # pass all kwargs through to it
359
+ model = load_8bit_from_hub(model_name,)
360
+
361
+ """# generate text
362
+
363
+ ## standard generation
364
+ `
365
+
366
+ with torch:
367
+
368
+ > with "standard" generation it's recommended to put the **speaker token labels** at the end of your prompt so the model "knows" to respond.
369
+
370
+ i.e `Person Alpha:` or `Person Beta:` for these two models.
371
+ """
372
+
373
+ prompt = "Person Alpha: what is the theory of being \"woke\" all about?\\n Person Beta: " # @param {type:"string"}
374
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
375
+ with torch.no_grad():
376
+ prompt = tokenizer(prompt, return_tensors="pt")
377
+ prompt = {key: value.to(device) for key, value in prompt.items()}
378
+ out = model.generate(
379
+ **prompt,
380
+ min_length=24,
381
+ max_length=96,
382
+ top_k=30,
383
+ top_p=0.9,
384
+ temperature=0.4,
385
+ do_sample=True,
386
+ repetition_penalty=1.2,
387
+ no_repeat_ngram_size=3,
388
+ pad_token_id=tokenizer.eos_token_id,
389
+ )
390
+ result = tokenizer.decode(
391
+ out[0],
392
+ remove_invalid_values=True,
393
+ skip_special_tokens=True,
394
+ clean_up_tokenization_spaces=True,
395
+ )
396
+ result
397
+
398
+ """---
399
+
400
+ ## 'Extract' bot response
401
+ - transformers `pipeline` object
402
+ - generate with better params
403
+ - extract the bot's response with `get_bot_response()` - start to use [ai-msgbot](https://github.com/pszemraj/ai-msgbot) _like it was meant to be used_
404
+ """
405
+
406
+ from transformers import pipeline
407
+
408
+ generator = pipeline(
409
+ "text-generation",
410
+ model=model,
411
+ tokenizer="EleutherAI/gpt-j-6B",
412
+ device= 0 if torch.cuda.is_available() else -1,
413
+ )
414
+
415
+ """### generation functions
416
+
417
+ for extracting the response, beam search vs. sampling, etc
418
+ """
419
+
420
+ # @markdown `get_bot_response(name_resp: str, model_resp: list, name_spk: str, verbose: bool = False)`
421
+ # @markdown - this extracts the response from "Person Beta" from the total generation
422
+ import pysbd
423
+
424
+ seg = pysbd.Segmenter(language="en", clean=False)
425
+
426
+ import re
427
+
428
+
429
def split_sentences(text, use_regex=False, min_len=2):
    """Split ``text`` into sentences and drop very short fragments.

    Args:
        text: the string to split.
        use_regex: when True, split on runs of spaces that follow ``.``,
            ``!`` or ``?``; otherwise use the module-level pySBD segmenter
            ``seg``.
        min_len: a sentence is kept only if its stripped length is strictly
            greater than this value.

    Returns:
        list of stripped sentence strings.
    """
    if use_regex:
        # fixed: this branch previously referenced an undefined name `string`
        sentences = re.split(r'(?<=[.!?]) +', text)
    else:
        # https://github.com/nipunsadvilkar/pySBD
        sentences = seg.segment(text)
    stripped = (s.strip() for s in sentences)
    return [s for s in stripped if len(s) > min_len]
438
+
439
+
440
def validate_response(response_text):
    """Normalise a model response to a list of sentences.

    A ``str`` is split with ``split_sentences``; a ``list`` is returned
    unchanged.

    Raises:
        ValueError: if ``response_text`` is neither a list nor a str.
    """
    if isinstance(response_text, str):
        return split_sentences(response_text)
    if isinstance(response_text, list):
        return response_text
    raise ValueError(f"response input {response_text} not a list or str..")
450
+
451
+
452
def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - isolate the responder's text from a raw model generation.

    The generation is normalised with `validate_response` and scanned line by
    line (case-insensitive name matching):
      - a line containing the responder's name is skipped entirely,
      - a line containing ":" ends the scan,
      - a line containing the speaker's name ends the scan unless the
        responder's name has already been seen,
      - any other line is kept.

    Args:
        name_resp (str): the name of the responder
        model_resp (list): the model response (a str is also accepted and split into sentences)
        name_spk (str): the name of the speaker
        verbose (bool, optional): print the kept lines. Defaults to False.
    Returns:
        bot_response (str): the kept lines joined with single spaces ("" if nothing was kept).
    """

    model_resp = validate_response(model_resp)
    logging.info(f"isolating response from:\t{model_resp}")

    responder = name_resp.lower()
    speaker = name_spk.lower()
    responder_seen = False
    kept_lines = []
    for line in model_resp:
        lowered = line.lower()
        if responder in lowered:
            responder_seen = True
            continue
        if ":" in line:
            # some other "name:" label - the responder's turn is over
            break
        if speaker in lowered and not responder_seen:
            break
        kept_lines.append(line)

    if verbose:
        print("the full response is:\n")
        print("\n".join(kept_lines))
    return " ".join(kept_lines)
489
+
490
+ import pprint as pp
491
+
492
+ # @markdown define `generate_sampling(prompt: str, ...)`
493
+
494
+
495
def generate_sampling(
    prompt: str,
    suffix: str = None,
    temperature=0.4,
    top_k: int = 40,
    top_p=0.90,
    min_length: int = 16,
    max_length: int = 128,
    no_repeat_ngram_size: int = 3,
    repetition_penalty=1.5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:
    """Generate a reply with sampling and extract the bot's response.

    Uses the module-level ``generator`` pipeline, logs input and output, and
    pretty-prints the extracted response before returning it.

    Args:
        prompt: the text to respond to.
        suffix: optional text appended to ``prompt`` before generation.
        temperature, top_k, top_p: sampling parameters passed to the pipeline.
        min_length: minimum number of tokens to generate after the prompt
            (the prompt's token count is added before calling the pipeline).
        max_length: maximum number of *new* tokens to generate.
        no_repeat_ngram_size, repetition_penalty: passed to the pipeline.
        return_full_text: passed to the pipeline.
        verbose: also print the prompt and the raw model output.
        **kwargs: extra keyword arguments passed to the pipeline call.

    Returns:
        The response extracted by ``get_bot_response`` with speaker
        "Person Alpha" and responder "Person Beta".
    """

    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t...")
    prompt = f"{prompt}{suffix}" if suffix is not None else prompt

    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        # min_length counts prompt + generated tokens, so offset by the prompt
        min_length=min_length + _prompt_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        do_sample=True,
        return_full_text=return_full_text,
        # max_new_tokens already excludes the prompt; do not add its length
        max_new_tokens=max_length,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )

    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )

    logging.info(f"extracted bot response:\n\t{response}")

    pp.pprint(response)

    return response
549
+
550
+ import pprint as pp
551
+
552
+ #@markdown define `generate_beams(prompt: str, num_beams:int =4, ...)`
553
+
554
+
555
def generate_beams(
    prompt: str,
    suffix: str = None,
    num_beams=4,
    min_length: int = 32,
    max_length: int = 128,
    no_repeat_ngram_size: int = 3,
    repetition_penalty=2.5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:
    """Generate a reply with beam search and extract the bot's response.

    Uses the module-level ``generator`` pipeline with ``do_sample=False`` and
    ``early_stopping=True``, logs input and output, and pretty-prints the
    extracted response before returning it.

    Args:
        prompt: the text to respond to.
        suffix: optional text appended to ``prompt`` before generation.
        num_beams: number of beams passed to the pipeline.
        min_length: minimum number of tokens to generate after the prompt
            (the prompt's token count is added before calling the pipeline).
        max_length: maximum number of *new* tokens to generate.
        no_repeat_ngram_size, repetition_penalty: passed to the pipeline.
        return_full_text: passed to the pipeline.
        verbose: also print the prompt and the raw model output.
        **kwargs: extra keyword arguments passed to the pipeline call.

    Returns:
        The response extracted by ``get_bot_response`` with speaker
        "Person Alpha" and responder "Person Beta".
    """

    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t")

    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        # min_length counts prompt + generated tokens, so offset by the prompt
        min_length=min_length + _prompt_tokens,
        num_beams=num_beams,
        do_sample=False,
        early_stopping=True,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        return_full_text=return_full_text,
        # max_new_tokens already excludes the prompt; do not add its length
        max_new_tokens=max_length,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )

    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )

    logging.info(f"extracted bot response:\n\t{response}")

    pp.pprint(response)

    return response
607
+
608
+ import pprint as pp
609
+
610
+ #@markdown define `generate_csearch(prompt: str, penalty_alpha: float = 0.6, top_k: int = 5, ...)`
611
+
612
+
613
def generate_csearch(
    prompt: str,
    suffix: str = None,
    max_length: int = 96,
    min_length: int = 24,
    penalty_alpha: float = 0.6,
    top_k: int = 5,
    return_full_text=False,
    verbose=False,
    **kwargs,
) -> str:
    """Generate a reply using ``penalty_alpha`` / ``top_k`` and extract the bot's response.

    Uses the module-level ``generator`` pipeline, logs input and output, and
    pretty-prints the extracted response before returning it.

    Args:
        prompt: the text to respond to.
        suffix: optional text appended to ``prompt`` before generation.
        max_length: maximum number of *new* tokens to generate.
        min_length: minimum number of tokens to generate after the prompt
            (the prompt's token count is added before calling the pipeline).
        penalty_alpha, top_k: passed to the pipeline.
        return_full_text: passed to the pipeline.
        verbose: also print the prompt and the raw model output.
        **kwargs: extra keyword arguments passed to the pipeline call.

    Returns:
        The response extracted by ``get_bot_response`` with speaker
        "Person Alpha" and responder "Person Beta".
    """

    logging.info(f"generating results for input:\n\t{prompt}\n\t...")
    if verbose:
        print(f"generating results for input:\n\t{prompt}\n\t")

    prompt = f"{prompt}{suffix}" if suffix is not None else prompt
    _prompt_tokens = len(generator.tokenizer(prompt).input_ids)
    result = generator(
        prompt,
        # min_length counts prompt + generated tokens, so offset by the prompt
        min_length=min_length + _prompt_tokens,
        max_new_tokens=max_length,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        remove_invalid_values=True,
        clean_up_tokenization_spaces=True,
        return_full_text=return_full_text,
        pad_token_id=generator.tokenizer.eos_token_id,
        **kwargs,
    )

    output = result[0]["generated_text"]
    logging.info(f"model output:\n\t{output}")
    if verbose:
        print(f"model output:\n\t{output}")
    response = get_bot_response(
        model_resp=output,
        name_spk="Person Alpha",
        name_resp="Person Beta",
        verbose=False,
    )

    logging.info(f"extracted bot response:\n\t{response}")

    pp.pprint(response)

    return response
661
+
662
+ """### generate - sampling
663
+
664
+ > **NOTE:** here `suffix="\nPerson Beta: "` is passed, so the responder label does not need to be added to the prompt.
665
+ """
666
+
667
+ # Commented out IPython magic to ensure Python compatibility.
668
+ # %%time
669
+ #
670
+ # prompt = "How do we harness space energy?" #@param {type:"string"}
671
+ # temperature = 0.2 #@param {type:"slider", min:0.1, max:1, step:0.1}
672
+ # top_k = 30 #@param {type:"slider", min:10, max:60, step:10}
673
+ #
674
+ #
675
+ # result = generate_sampling(
676
+ # prompt,
677
+ # suffix="\nPerson Beta: ",
678
+ # max_length=128,
679
+ # min_length=32,
680
+ # temperature=temperature,
681
+ # top_k=top_k,
682
+ # )
683
+ #
684
+
685
+ prompt = "What is the purpose of life?" # @param {type:"string"}
686
+ temperature = 0.5 # @param {type:"slider", min:0.1, max:1, step:0.1}
687
+ top_k = 30 # @param {type:"slider", min:10, max:60, step:10}
688
+
689
+ generated_result = generate_sampling(
690
+ prompt,
691
+ temperature=temperature,
692
+ top_k=top_k,
693
+ min_length=32,
694
+ suffix="\nPerson Beta: ",
695
+ )
696
+
697
+ """### generate - beam search"""
698
+
699
+ # Commented out IPython magic to ensure Python compatibility.
700
+ # %%time
701
+ # prompt = "How was your day?" #@param {type:"string"}
702
+ # num_beams = 4 #@param {type:"slider", min:2, max:10, step:2}
703
+ # min_length = 16 #@param {type:"slider", min:8, max:128, step:8}
704
+ #
705
+ # generated_result = generate_beams(
706
+ # prompt,
707
+ # suffix="\nPerson Beta: ",
708
+ # min_length=min_length,
709
+ # num_beams=num_beams,
710
+ # )