wop committed on
Commit 48b3788 · verified · 1 Parent(s): 2d0f9fd

Update app.py

Files changed (1): app.py +25 -21
app.py CHANGED
@@ -13,34 +13,38 @@ model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
 # Set model to evaluation mode
 model.eval()

-# Function to generate text based on input prompt
+# Function to generate text in a stream-based manner
 def generate_text(prompt):
     # Tokenize and encode the input prompt
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+    max_length = 50  # Maximum length of generated text

-    # Generate continuation
+    # Generate continuation with streaming tokens
     with torch.no_grad():
-        generated_ids = model.generate(
+        for generated_ids in model.generate(
             input_ids,
-            max_length=50,                        # Maximum length of generated text
-            num_return_sequences=1,               # Generate 1 sequence
-            pad_token_id=tokenizer.eos_token_id,  # Use EOS token for padding
-            do_sample=True,                       # Enable sampling
-            top_k=50,                             # Top-k sampling
-            top_p=0.95                            # Nucleus sampling
-        )
-
-    # Decode the generated text
-    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-    return generated_text
-
-# Create a Gradio interface
+            max_length=max_length,
+            num_return_sequences=1,
+            pad_token_id=tokenizer.eos_token_id,
+            do_sample=True,
+            top_k=50,
+            top_p=0.95,
+            output_scores=True,  # Include scores for sampling
+            return_dict_in_generate=True,
+            use_cache=True
+        ).sequences:
+
+            # Decode each step incrementally
+            decoded_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+            yield decoded_text  # Stream the partial text back to the UI
+
+# Create a Gradio interface with streaming enabled
 interface = gr.Interface(
-    fn=generate_text,               # Function to call when interacting with the UI
-    inputs="text",                  # Input type: Single-line text
-    outputs="text",                 # Output type: Text (the generated output)
-    title="Quble Text Generation",  # Title of the UI
-    description="Enter a prompt to generate text using Quble."  # Simple description
+    fn=generate_text,               # Function to call when interacting with the UI
+    inputs="text",                  # Input type: Single-line text
+    outputs=gr.Markdown(),          # Stream output using Markdown
+    title="Quble Text Generation",  # Title of the UI
+    description="Enter a prompt to generate text using Quble with live streaming."  # Simple description
 )

 # Launch the Gradio app
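
One caveat on the new loop: `model.generate()` runs the whole decoding pass before it returns, and with `return_dict_in_generate=True` the `.sequences` field holds one finished sequence per `num_return_sequences`. Iterating over it therefore yields the complete text once, not token by token. For genuinely incremental output, `transformers.TextIteratorStreamer` is the usual route. Below is a minimal sketch under that assumption, reusing the same GPT-2 setup as the app; the `"gpt2"` checkpoint is an assumed stand-in for the app's `model_path`:

```python
# Hypothetical sketch: token-by-token streaming via TextIteratorStreamer.
# Mirrors app.py's setup; "gpt2" is an assumed stand-in for model_path.
from threading import Thread

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextIteratorStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
model.eval()

def generate_text(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        max_length=50,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks until done, so run it on a worker thread
    # and consume decoded chunks from the streamer as they arrive
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # each yield updates the Gradio output in place
```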
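
On the Gradio side, a generator `fn` is all that streaming requires: each `yield` re-renders the output component, whether it is `gr.Markdown()` as in this commit or a plain `gr.Textbox()`. One version-dependent detail worth checking: on Gradio 3.x, generator functions stream only when the queue is enabled (e.g. `interface.queue().launch()`), while Gradio 4.x enables the queue by default.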