ahabb commited on
Commit
3356732
·
verified ·
1 Parent(s): 7602acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -1,17 +1,27 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # Initialize the text generation pipeline
5
- generator = pipeline('text-generation', model='gpt2')
 
6
 
7
- def generate_blogpost(topic, max_length=500):
8
  prompt = f"Write a blog post about {topic}:\n\n"
9
-
10
- # Generate the blog post
11
- generated_text = generator(prompt, max_length=max_length, num_return_sequences=1)[0]['generated_text']
 
 
 
 
 
 
 
 
 
12
 
13
  # Remove the prompt from the generated text
14
- blog_post = generated_text[len(prompt):].strip()
15
 
16
  return blog_post
17
 
@@ -20,7 +30,8 @@ iface = gr.Interface(
20
  fn=generate_blogpost,
21
  inputs=[
22
  gr.Textbox(lines=1, placeholder="Enter the blog post topic here..."),
23
- gr.Slider(minimum=100, maximum=1000, step=50, label="Max Length", value=500)
 
24
  ],
25
  outputs="text",
26
  title="Blog Post Generator",
@@ -28,4 +39,4 @@ iface = gr.Interface(
28
  )
29
 
30
  # Launch the app
31
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ model_name = 'openai-community/gpt2-large'
5
+ model = AutoModelForCausalLM.from_pretrained(model_name)
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
7
 
8
+ def generate_blogpost(topic, max_length=500, temperature=0.7):
9
  prompt = f"Write a blog post about {topic}:\n\n"
10
+
11
+ # Encode input:
12
+ inputs_encoded = tokenizer(prompt, return_tensors='pt')
13
+ # Model Output:
14
+ model_output = model.generate(
15
+ inputs_encoded["input_ids"],
16
+ max_new_tokens=max_length,
17
+ do_sample=True,
18
+ temperature=temperature
19
+ )[0]
20
+ # Decode the output
21
+ output = tokenizer.decode(model_output, skip_special_tokens=True)
22
 
23
  # Remove the prompt from the generated text
24
+ blog_post = output[len(prompt):].strip()
25
 
26
  return blog_post
27
 
 
30
  fn=generate_blogpost,
31
  inputs=[
32
  gr.Textbox(lines=1, placeholder="Enter the blog post topic here..."),
33
+ gr.Slider(minimum=100, maximum=1000, step=50, label="Max Length", value=500),
34
+ gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Temperature", value=0.7)
35
  ],
36
  outputs="text",
37
  title="Blog Post Generator",
 
39
  )
40
 
41
  # Launch the app
42
+ iface.launch()