ms1449 committed
Commit c51ba33 · verified · 1 Parent(s): 0263fca

Update app.py

Files changed (1)
  1. app.py +10 -16
app.py CHANGED
@@ -18,28 +18,22 @@ if st.button("Generate Blog Post"):
     if topic:
         # Prepare the prompt
         prompt = f"Write a blog post about {topic}:\n\n"
-
+
+        # Generate text
+        generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=0.7)
+
         # Tokenize the input
         input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+        # Model output
+        model_output = model.generate(input_ids, generation_config=generation_config)[0]
 
-        # Generate text
-        with torch.no_grad():
-            output = model.generate(
-                input_ids,
-                max_length=500,
-                num_return_sequences=1,
-                no_repeat_ngram_size=2,
-                top_k=50,
-                top_p=0.95,
-                temperature=0.7
-            )
-
-        # Decode the generated text
-        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+        # Decode the output
+        output = tokenizer.decode(model_output, skip_special_tokens=True)
 
         # Display the generated blog post
         st.subheader("Generated Blog Post:")
-        st.write(generated_text)
+        st.write(output)
     else:
         st.warning("Please enter a topic.")
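
For context, the updated hunk depends on GenerationConfig and on a tokenizer and model that are set up earlier in app.py, outside this diff. Below is a minimal sketch of how the full script might look after this commit; the imports, the gpt2 checkpoint, and the Streamlit widget labels are assumptions, not part of the commit.

# Sketch of app.py after this commit (assumed context; only the hunk above is from the actual file)
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Assumed checkpoint; the real app may load a different model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

st.title("Blog Post Generator")          # assumed label
topic = st.text_input("Enter a topic")   # assumed label

if st.button("Generate Blog Post"):
    if topic:
        # Prepare the prompt
        prompt = f"Write a blog post about {topic}:\n\n"

        # Generate text: sample up to 50 new tokens at temperature 0.7
        generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=0.7)

        # Tokenize the input
        input_ids = tokenizer.encode(prompt, return_tensors="pt")

        # Model output
        model_output = model.generate(input_ids, generation_config=generation_config)[0]

        # Decode the output
        output = tokenizer.decode(model_output, skip_special_tokens=True)

        # Display the generated blog post
        st.subheader("Generated Blog Post:")
        st.write(output)
    else:
        st.warning("Please enter a topic.")

Besides moving the sampling settings into GenerationConfig, the commit drops the torch.no_grad() block and the no_repeat_ngram_size/top_k/top_p arguments, and replaces max_length=500 with max_new_tokens=50, so posts generated after this change will be considerably shorter.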