talha2001 committed on
Commit
d137226
·
verified ·
1 Parent(s): a7717fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -63
app.py CHANGED
@@ -1,76 +1,40 @@
1
  import streamlit as st
2
- from datasets import load_dataset, DatasetDict
3
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling, Trainer, TrainingArguments
4
 
5
- @st.cache_resource
6
- def load_and_fine_tune_model():
7
- # Load the dataset
8
- dataset = load_dataset("blog_authorship_corpus")
9
 
10
- # Split the dataset into train and validation sets
11
- dataset = dataset['train'].train_test_split(test_size=0.1)
12
-
13
- # Load the tokenizer
14
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
15
-
16
- def tokenize_function(examples):
17
- return tokenizer(examples["text"], truncation=True)
18
-
19
- tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
20
-
21
- # Data collator
22
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
23
-
24
- # Load the model
25
- model = GPT2LMHeadModel.from_pretrained("gpt2")
26
-
27
- # Training arguments
28
- training_args = TrainingArguments(
29
- output_dir="./results",
30
- overwrite_output_dir=True,
31
- num_train_epochs=1,
32
- per_device_train_batch_size=2,
33
- save_steps=10_000,
34
- save_total_limit=2,
35
- )
36
 
37
- # Initialize the Trainer
38
- trainer = Trainer(
39
- model=model,
40
- args=training_args,
41
- data_collator=data_collator,
42
- train_dataset=tokenized_datasets['train'],
43
- eval_dataset=tokenized_datasets['test']
44
- )
45
-
46
- # Fine-tune the model
47
- trainer.train()
48
-
49
- # Save the fine-tuned model
50
- model.save_pretrained("./fine-tuned-gpt2")
51
- tokenizer.save_pretrained("./fine-tuned-gpt2")
52
- return model, tokenizer
53
-
54
- def generate_blog_post(prompt, model, tokenizer, max_length=500, temperature=0.7, top_k=50):
55
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
56
  output = model.generate(
57
  input_ids,
58
  max_length=max_length,
59
- temperature=temperature,
60
- top_k=top_k,
61
  no_repeat_ngram_size=2,
62
- num_return_sequences=1
63
  )
64
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
65
- return generated_text
66
 
67
- # Streamlit UI
 
 
 
 
 
68
  st.title("Blog Post Generator")
69
- prompt = st.text_input("Enter a prompt for the blog post:", "The future of artificial intelligence in daily life")
 
 
70
 
71
  if st.button("Generate Blog Post"):
72
- with st.spinner("Fine-tuning the model. This might take a few minutes..."):
73
- model, tokenizer = load_and_fine_tune_model()
74
- blog_post = generate_blog_post(prompt, model, tokenizer)
75
- st.subheader("Generated Blog Post")
76
- st.write(blog_post)
 
 
1
  import streamlit as st
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
3
 
4
+ # Load pre-trained GPT-2 model and tokenizer
5
+ model_name = 'gpt2'
6
+ model = GPT2LMHeadModel.from_pretrained(model_name)
7
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
8
 
9
+ # Function to generate blog post for a given topic
10
+ def generate_blog_post(topic, max_length=300):
11
+ # Encode the input topic into tokens
12
+ input_ids = tokenizer.encode(topic, return_tensors='pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Generate text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  output = model.generate(
16
  input_ids,
17
  max_length=max_length,
18
+ num_return_sequences=1,
 
19
  no_repeat_ngram_size=2,
20
+ early_stopping=True
21
  )
 
 
22
 
23
+ # Decode the output tokens into a string
24
+ blog_post = tokenizer.decode(output[0], skip_special_tokens=True)
25
+
26
+ return blog_post
27
+
28
+ # Streamlit app
29
  st.title("Blog Post Generator")
30
+
31
+ topic = st.text_input("Enter a topic for the blog post:")
32
+ max_length = st.slider("Maximum length of the blog post:", min_value=50, max_value=1000, value=300)
33
 
34
  if st.button("Generate Blog Post"):
35
+ if topic:
36
+ with st.spinner('Generating blog post...'):
37
+ blog_post = generate_blog_post(topic, max_length)
38
+ st.write(blog_post)
39
+ else:
40
+ st.warning("Please enter a topic.")