nileshhanotia committed
Commit • b0226fd
1 Parent(s): 9c83f07
Update app.py
app.py CHANGED
@@ -2,60 +2,100 @@ import os
 import streamlit as st
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+from transformers import TextDataset, DataCollatorForLanguageModeling
+import torch
+from tqdm import tqdm

 # Streamlit caching functions
 @st.cache_data
-def load_data():
-
-
+def load_data(file_path):
+    try:
+        return load_dataset('json', data_files=file_path)
+    except Exception as e:
+        st.error(f"Error loading dataset: {str(e)}")
+        return None

 @st.cache_resource
 def initialize_model_and_tokenizer(model_name):
-
-
-
-
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Error initializing model and tokenizer: {str(e)}")
+        return None, None

-def preprocess_function(examples, tokenizer):
-    return tokenizer(examples['prompt'], truncation=True, padding="max_length", max_length=
+def preprocess_function(examples, tokenizer, max_length):
+    return tokenizer(examples['prompt'], truncation=True, padding="max_length", max_length=max_length)

 def main():
     st.title("Model Training with Streamlit")
-
-
+
+    # User inputs
+    model_name = st.text_input("Enter model name", "distilgpt2")
+    file_path = st.text_input("Enter path to training data JSON file", "training_data.json")
+    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
+    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
+    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=4)
+    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=2e-5, format="%.1e")
+
     tokenizer, model = initialize_model_and_tokenizer(model_name)
+
+    if tokenizer is None or model is None:
+        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
+        return

     st.write("Loading and processing dataset...")
-    dataset = load_data()
+    dataset = load_data(file_path)

-
+    if dataset is None:
+        st.warning("Failed to load dataset. Please check the file path and try again.")
+        return
+
     st.write("Tokenizing dataset...")
-    tokenized_dataset = dataset['train'].map(
+    tokenized_dataset = dataset['train'].map(
+        lambda x: preprocess_function(x, tokenizer, max_length),
+        batched=True,
+        remove_columns=dataset['train'].column_names
+    )

     # Define training arguments
     training_args = TrainingArguments(
-        output_dir='./results',
-        evaluation_strategy='epoch',
-        learning_rate=
-        per_device_train_batch_size=
-        per_device_eval_batch_size=
-        num_train_epochs=
-        weight_decay=0.01,
-        logging_dir='./logs',
-        logging_steps=10,
+        output_dir='./results',
+        evaluation_strategy='epoch',
+        learning_rate=learning_rate,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        num_train_epochs=num_epochs,
+        weight_decay=0.01,
+        logging_dir='./logs',
+        logging_steps=10,
     )
-
+
     # Initialize the Trainer
     trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=tokenized_dataset,
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset,
+        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
     )
-
+
     if st.button('Start Training'):
         st.write("Starting training...")
-
+        progress_bar = st.progress(0)
+
+        for epoch in range(int(num_epochs)):
+            trainer.train()
+            progress = (epoch + 1) / num_epochs
+            progress_bar.progress(progress)
+
+            # Save the model after each epoch
+            model_path = f"./results/model_epoch_{epoch+1}"
+            trainer.save_model(model_path)
+            st.write(f"Model saved: {model_path}")
+
         st.write("Training complete.")
+        st.write("You can now use the trained model for inference or further fine-tuning.")

 if __name__ == "__main__":
-    main()
+    main()
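
For context on the data this Space now expects: load_data() reads the file with load_dataset('json', data_files=file_path), and preprocess_function() only tokenizes the 'prompt' column (all other columns are dropped by remove_columns). The snippet below is a minimal sketch, not part of the commit, that writes such a file in JSON Lines form; the example prompts are invented. With training_data.json in place, the app is launched the usual Streamlit way with `streamlit run app.py`.

import json

# Hypothetical training examples; only the 'prompt' field is read by preprocess_function.
examples = [
    {"prompt": "Translate to French: Hello, how are you?"},
    {"prompt": "Summarize: Streamlit makes it easy to build ML demos."},
]

# Write one JSON object per line (JSON Lines), which load_dataset('json', ...) accepts.
with open("training_data.json", "w") as f:
    for record in examples:
        f.write(json.dumps(record) + "\n")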
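
The new data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) turns the tokenized rows into causal-LM batches: input_ids are copied into labels and padded positions are set to -100 so the loss ignores them. Below is a standalone sketch of that behaviour, assuming the default distilgpt2 model (GPT-2-style tokenizers ship without a pad token, so one is assigned here purely for the demo).

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers define no pad token by default

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator([tokenizer("hello world"), tokenizer("a longer example sentence")])

# labels mirror input_ids, with padded positions replaced by -100 so they are
# ignored by the causal language-modeling loss
print(batch["input_ids"].shape, batch["labels"][0])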
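
The closing message notes the trained model can be used for inference or further fine-tuning. Since the training loop saves each epoch to ./results/model_epoch_{n} via trainer.save_model(), a checkpoint can be reloaded with from_pretrained(); the sketch below assumes the default distilgpt2 tokenizer and a hypothetical epoch-3 checkpoint path (the tokenizer itself is not saved by trainer.save_model in this commit).

from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical checkpoint path produced by the per-epoch save loop above
model_path = "./results/model_epoch_3"

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained(model_path)

inputs = tokenizer("Translate to French: Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))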