Gbssreejith commited on
Commit
c074598
1 Parent(s): 3f16a66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +484 -0
app.py CHANGED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import math
4
+ import torch.nn as nn
5
+ from torch.nn.parameter import Parameter
6
+ import random
7
+ import numpy as np
8
+ from load_weights import load_weight
9
+ from sklearn.model_selection import train_test_split
10
+ from transformers import GPT2TokenizerFast
11
+ import pandas as pd
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from transformers import AdamW, get_linear_schedule_with_warmup
14
+ torch.manual_seed(42)
15
+ import nltk
16
+ nltk.download('punkt')
17
+
18
+ from transformers import GPT2Tokenizer
19
+ from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
20
+ import datetime
21
+ import time
22
+ import os
23
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
24
+ from tqdm import trange
25
+ import gradio as gr
26
+ import re
27
+
28
+
29
+
30
+
31
+ def gelu(x):
32
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
33
+
34
+ class Conv1D(nn.Module):
35
+ def __init__(self, nf, nx):
36
+ super(Conv1D, self).__init__()
37
+ self.nf = nf
38
+ w = torch.empty(nx, nf)
39
+ nn.init.normal_(w, std=0.02)
40
+ self.weight = Parameter(w)
41
+ self.bias = Parameter(torch.zeros(nf))
42
+
43
+ def forward(self, x):
44
+ size_out = x.size()[:-1] + (self.nf,)
45
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
46
+ x = x.view(*size_out)
47
+ return x
48
+
49
+ class LayerNorm(nn.Module):
50
+ def __init__(self, hidden_size, eps=1e-12):
51
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
52
+ """
53
+ super(LayerNorm, self).__init__()
54
+ self.weight = nn.Parameter(torch.ones(hidden_size))
55
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
56
+ self.variance_epsilon = eps
57
+
58
+ def forward(self, x):
59
+ u = x.mean(-1, keepdim=True)
60
+ s = (x - u).pow(2).mean(-1, keepdim=True)
61
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
62
+ return self.weight * x + self.bias
63
+
64
+
65
+
66
+ class Attention(nn.Module):
67
+ def __init__(self, nx, n_ctx, config, scale=False):
68
+ super(Attention, self).__init__()
69
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
70
+ # [switch nx => n_state from Block to Attention to keep identical to TF implem]
71
+ assert n_state % config.n_head == 0
72
+ self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
73
+ self.n_head = config.n_head
74
+ self.split_size = n_state
75
+ self.scale = scale
76
+ self.c_attn = Conv1D(n_state * 3, nx)
77
+ self.c_proj = Conv1D(n_state, nx)
78
+
79
+ def _attn(self, q, k, v):
80
+ w = torch.matmul(q, k)
81
+ if self.scale:
82
+ w = w / math.sqrt(v.size(-1))
83
+ nd, ns = w.size(-2), w.size(-1)
84
+ b = self.bias[:, :, ns-nd:ns, :ns]
85
+ w = w * b - 1e10 * (1 - b)
86
+ w = nn.Softmax(dim=-1)(w)
87
+ return torch.matmul(w, v)
88
+
89
+ def merge_heads(self, x):
90
+ x = x.permute(0, 2, 1, 3).contiguous()
91
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
92
+ return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
93
+
94
+ def split_heads(self, x, k=False):
95
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
96
+ x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
97
+ if k:
98
+ return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
99
+ else:
100
+ return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
101
+
102
+ def forward(self, x, layer_past=None):
103
+ x = self.c_attn(x)
104
+ query, key, value = x.split(self.split_size, dim=2)
105
+ query = self.split_heads(query)
106
+ key = self.split_heads(key, k=True)
107
+ value = self.split_heads(value)
108
+ if layer_past is not None:
109
+ past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
110
+ key = torch.cat((past_key, key), dim=-1)
111
+ value = torch.cat((past_value, value), dim=-2)
112
+ present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
113
+ a = self._attn(query, key, value)
114
+ a = self.merge_heads(a)
115
+ a = self.c_proj(a)
116
+ return a, present
117
+
118
+
119
+ class MLP(nn.Module):
120
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
121
+ super(MLP, self).__init__()
122
+ nx = config.n_embd
123
+ self.c_fc = Conv1D(n_state, nx)
124
+ self.c_proj = Conv1D(nx, n_state)
125
+ self.act = gelu
126
+
127
+ def forward(self, x):
128
+ h = self.act(self.c_fc(x))
129
+ h2 = self.c_proj(h)
130
+ return h2
131
+
132
+
133
+ class Block(nn.Module):
134
+ def __init__(self, n_ctx, config, scale=False):
135
+ super(Block, self).__init__()
136
+ nx = config.n_embd
137
+ self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
138
+ self.attn = Attention(nx, n_ctx, config, scale)
139
+ self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
140
+ self.mlp = MLP(4 * nx, config)
141
+
142
+ def forward(self, x, layer_past=None):
143
+ a, present = self.attn(self.ln_1(x), layer_past=layer_past)
144
+ x = x + a
145
+ m = self.mlp(self.ln_2(x))
146
+ x = x + m
147
+ return x, present
148
+
149
+
150
+
151
+ class GPT2Model(nn.Module):
152
+ def __init__(self, config):
153
+ super(GPT2Model, self).__init__()
154
+ self.n_layer = config.n_layer
155
+ self.n_embd = config.n_embd
156
+ self.n_vocab = config.vocab_size
157
+
158
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
159
+ self.wpe = nn.Embedding(config.n_positions, config.n_embd)
160
+ block = Block(config.n_ctx, config, scale=True)
161
+ self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
162
+ self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
163
+
164
+ def set_embeddings_weights(self, model_embeddings_weights):
165
+ embed_shape = model_embeddings_weights.shape
166
+ self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
167
+ self.decoder.weight = model_embeddings_weights # Tied weights
168
+
169
+
170
+
171
+ def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
172
+
173
+ if (input_ids >= self.n_vocab).any():
174
+ raise ValueError(f"Invalid token ID found in input_ids: {input_ids}")
175
+
176
+ # print(f"input_ids: {input_ids}") # Debugging statement
177
+ # print(f"Max input_id: {input_ids.max().item()}") # Debugging statement
178
+ # print(f"Min input_id: {input_ids.min().item()}") # Debugging statement
179
+
180
+ if past is None:
181
+ past_length = 0
182
+ past = [None] * len(self.h)
183
+ else:
184
+ past_length = past[0][0].size(-2)
185
+ if position_ids is None:
186
+ position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
187
+ device=input_ids.device)
188
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
189
+
190
+ input_shape = input_ids.size()
191
+ input_ids = input_ids.view(-1, input_ids.size(-1))
192
+ position_ids = position_ids.view(-1, position_ids.size(-1))
193
+
194
+ inputs_embeds = self.wte(input_ids)
195
+ position_embeds = self.wpe(position_ids)
196
+
197
+ # print(f"inputs_embeds shape: {inputs_embeds.shape}")
198
+ # print(f"position_embeds shape: {position_embeds.shape}")
199
+
200
+
201
+ if token_type_ids is not None:
202
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
203
+ token_type_embeds = self.wte(token_type_ids)
204
+ else:
205
+ token_type_embeds = 0
206
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
207
+ presents = []
208
+ for block, layer_past in zip(self.h, past):
209
+ hidden_states, present = block(hidden_states, layer_past)
210
+ presents.append(present)
211
+ hidden_states = self.ln_f(hidden_states)
212
+ output_shape = input_shape + (hidden_states.size(-1),)
213
+ return hidden_states.view(*output_shape), presents
214
+
215
+ class GPT2LMHead(nn.Module):
216
+ def __init__(self, model_embeddings_weights, config):
217
+ super(GPT2LMHead, self).__init__()
218
+ self.n_embd = config.n_embd
219
+ self.set_embeddings_weights(model_embeddings_weights)
220
+
221
+ def set_embeddings_weights(self, model_embeddings_weights):
222
+ embed_shape = model_embeddings_weights.shape
223
+ self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
224
+ self.decoder.weight = model_embeddings_weights # Tied weights
225
+
226
+ def forward(self, hidden_state):
227
+ # Truncated Language modeling logits (we remove the last token)
228
+ # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
229
+ lm_logits = self.decoder(hidden_state)
230
+ return lm_logits
231
+
232
+ import torch.nn.functional as F
233
+
234
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
235
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
236
+ Args:
237
+ logits: logits distribution shape (batch size, vocabulary size)
238
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
239
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
240
+ filter_value: value to replace filtered logits.
241
+ """
242
+ assert logits.dim() == 2 # batch size x vocabulary size
243
+ top_k = min(top_k, logits.size(-1)) # Safety check
244
+ if top_k > 0:
245
+ # Remove all tokens with a probability less than the last token of the top-k
246
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
247
+ logits[indices_to_remove] = filter_value
248
+
249
+ if top_p > 0.0:
250
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
251
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
252
+
253
+ # Remove tokens with cumulative probability above the threshold
254
+ sorted_indices_to_remove = cumulative_probs > top_p
255
+ # Shift the indices to the right to keep also the first token above the threshold
256
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
257
+ sorted_indices_to_remove[..., 0] = 0
258
+
259
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
260
+ logits[indices_to_remove] = filter_value
261
+ return logits
262
+
263
+
264
+ class GPT2LMHeadModel(nn.Module):
265
+ def __init__(self, config):
266
+ super(GPT2LMHeadModel, self).__init__()
267
+ self.transformer = GPT2Model(config)
268
+ self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
269
+
270
+ def set_tied(self):
271
+ """ Make sure we are sharing the embeddings
272
+ """
273
+ self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
274
+
275
+ def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
276
+ hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
277
+ lm_logits = self.lm_head(hidden_states)
278
+
279
+ outputs = (lm_logits,presents)
280
+
281
+ if lm_labels is not None:
282
+ shift_logits = lm_logits[..., :-1, :].contiguous()
283
+ shift_labels = lm_labels[..., 1:].contiguous()
284
+ loss_fct = nn.CrossEntropyLoss()
285
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
286
+ outputs = (loss,) + outputs
287
+ return outputs
288
+
289
+ import torch.nn.functional as F
290
+
291
+
292
+
293
+ def generate(
294
+ self, input_ids, max_length, temperature=1.0, top_k=0, top_p=0.9, repetition_penalty=1.0, device='cuda'
295
+ ):
296
+ self.eval()
297
+ input_ids = input_ids.to(device)
298
+ batch_size = input_ids.shape[0]
299
+ past = None
300
+
301
+ generated = input_ids
302
+ with torch.no_grad():
303
+ for _ in range(max_length):
304
+ outputs = self(input_ids, past=past)
305
+ next_token_logits = outputs[0][:, -1, :]
306
+ past = outputs[1]
307
+
308
+ for i in range(batch_size):
309
+ for token_id in set(generated[i].tolist()):
310
+ next_token_logits[i, token_id] /= repetition_penalty
311
+
312
+ next_token_logits = next_token_logits / temperature
313
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
314
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
315
+ generated = torch.cat((generated, next_token), dim=1)
316
+
317
+ if (next_token == self.config.eos_token_id).all():
318
+ break
319
+
320
+ input_ids = next_token
321
+
322
+ return generated
323
+
324
+
325
+ class GPT2Config(object):
326
+ def __init__(
327
+ self,
328
+ vocab_size_or_config_json_file=50257,
329
+ n_positions=1024,
330
+ n_ctx=1024,
331
+ n_embd=768,
332
+ n_layer=12,
333
+ n_head=12,
334
+ layer_norm_epsilon=1e-5,
335
+ initializer_range=0.02,
336
+ ):
337
+ self.vocab_size = vocab_size_or_config_json_file
338
+ self.n_ctx = n_ctx
339
+ self.n_positions = n_positions
340
+ self.n_embd = n_embd
341
+ self.n_layer = n_layer
342
+ self.n_head = n_head
343
+ self.layer_norm_epsilon = layer_norm_epsilon
344
+ self.initializer_range = initializer_range
345
+
346
+
347
+
348
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
349
+ config = GPT2Config()
350
+ model = GPT2LMHeadModel(config)
351
+ state_dict = torch.load(r'C:\vision_model\gpt-2-Pytorch\test\gpt_today\weights\epoch_1.pth', map_location='cpu' if not torch.cuda.is_available() else None)
352
+ model = load_weight(model, state_dict)
353
+ model.to(device)
354
+ print(model)
355
+ model.eval()
356
+
357
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
358
+ tokenizer.pad_token = tokenizer.eos_token
359
+
360
+
361
+
362
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
363
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
364
+ Args:
365
+ logits: logits distribution shape (batch size x vocabulary size)
366
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
367
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
368
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
369
+ """
370
+ assert logits.dim() == 2, "Expected logits dimension to be 2 (batch size x vocabulary size)"
371
+ top_k = min(top_k, logits.size(-1)) # Safety check
372
+ if top_k > 0:
373
+ # Remove all tokens with a probability less than the last token of the top-k
374
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
375
+ logits[indices_to_remove] = filter_value
376
+
377
+ if top_p > 0.0:
378
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
379
+ cumulative_probs = torch.cumsum(nn.Softmax(dim=-1)(sorted_logits), dim=-1)
380
+
381
+ # Remove tokens with cumulative probability above the threshold
382
+ sorted_indices_to_remove = cumulative_probs > top_p
383
+ # Shift the indices to the right to keep also the first token above the threshold
384
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
385
+ sorted_indices_to_remove[..., 0] = 0
386
+
387
+ # Ensure that the dimensions match
388
+ if sorted_indices_to_remove.size() != sorted_indices.size():
389
+ raise ValueError(f"Size mismatch: {sorted_indices_to_remove.size()} vs {sorted_indices.size()}")
390
+
391
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
392
+
393
+ # Expand dimensions to match logits tensor and use scatter_
394
+ for batch_idx in range(logits.size(0)):
395
+ logits[batch_idx, indices_to_remove[batch_idx]] = filter_value
396
+
397
+ return logits
398
+
399
+ # prompt_text = "What is the classical conceptualisation of oxidation and reduction in redox reactions?"
400
+ # prompt = f"\n<|startoftext|>[WP] {prompt_text} \n[RESPONSE]"
401
+ # input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
402
+
403
+
404
+ # max_length = 50
405
+ # temperature = 0.7
406
+ # top_k = 50
407
+ # top_p = 0.95
408
+ # repetition_penalty = 1.0
409
+
410
+ # with torch.no_grad():
411
+ # for _ in range(max_length):
412
+ # outputs = model(input_ids)
413
+ # logits = outputs[0]
414
+ # next_token_logits = logits[:, -1, :] / temperature
415
+
416
+ # # Apply repetition penalty
417
+ # for i in range(input_ids.size(0)):
418
+ # for token_id in set(input_ids[i].tolist()):
419
+ # next_token_logits[0, token_id] /= repetition_penalty
420
+
421
+ # # Filter logits using top-k and/or top-p filtering
422
+ # filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
423
+ # next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
424
+ # input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
425
+
426
+
427
+ # import re
428
+ # # generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
429
+ # # wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
430
+ # print(input_ids[0])
431
+
432
+ # generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
433
+ # wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
434
+ # print(wp_responses)
435
+
436
+
437
+ # Define the generation function
438
+ def generate_text(prompt_text, max_length=50, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0):
439
+ prompt = f"\n[WP] {prompt_text} \n[RESPONSE]"
440
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
441
+
442
+ with torch.no_grad():
443
+ for _ in range(max_length):
444
+ outputs = model(input_ids)
445
+ logits = outputs[0]
446
+ next_token_logits = logits[:, -1, :] / temperature
447
+
448
+ # Apply repetition penalty
449
+ for i in range(input_ids.size(0)):
450
+ for token_id in set(input_ids[i].tolist()):
451
+ next_token_logits[0, token_id] /= repetition_penalty
452
+
453
+ # Filter logits using top-k and/or top-p filtering
454
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
455
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
456
+ input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
457
+
458
+ generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
459
+ wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
460
+ return wp_responses[1]
461
+
462
+ # Define the Gradio interface using Blocks
463
+ with gr.Blocks() as demo:
464
+ with gr.Row():
465
+ gr.Markdown("<h1 style='text-align: center'>GPT-2 Text Generator</h1>")
466
+ with gr.Row():
467
+ with gr.Column():
468
+ prompt = gr.Textbox(lines=2, placeholder="Enter prompt here...", label="Prompt")
469
+ max_length = gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Max Length")
470
+ temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
471
+ top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top K")
472
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top P")
473
+ repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
474
+ generate_button = gr.Button("Generate")
475
+ with gr.Column():
476
+ output_text = gr.Textbox(lines=20, label="Generated Text")
477
+
478
+ generate_button.click(
479
+ fn=generate_text,
480
+ inputs=[prompt, max_length, temperature, top_k, top_p, repetition_penalty],
481
+ outputs=output_text
482
+ )
483
+
484
+ demo.launch()