PeVe_mistral / app.py
import os
import streamlit as st
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
# Streamlit caching functions
@st.cache_data
def load_data(file_path):
    try:
        return load_dataset('json', data_files=file_path)
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        return None
@st.cache_resource
def initialize_model_and_tokenizer(model_name):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-2-style tokenizers ship without a pad token; reuse EOS so that
        # padding="max_length" in preprocess_function does not fail.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer, model
    except Exception as e:
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None
def preprocess_function(examples, tokenizer, max_length):
    return tokenizer(examples['prompt'], truncation=True, padding="max_length", max_length=max_length)
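
# A minimal sketch of the expected training data, assuming a JSON Lines file in
# which each record carries the "prompt" field read by preprocess_function.
# The actual schema of training_data.json is not documented here, so these
# records are purely illustrative:
#
#   {"prompt": "Question: What are your store hours?\nAnswer: We are open 9 to 5."}
#   {"prompt": "Question: Do you ship internationally?\nAnswer: Yes, worldwide."}
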
def main():
    st.title("Model Training with Streamlit")

    # User inputs
    model_name = st.text_input("Enter model name", "distilgpt2")
    file_path = st.text_input("Enter path to training data JSON file", "training_data.json")
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=4)
    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=2e-5, format="%.1e")

    tokenizer, model = initialize_model_and_tokenizer(model_name)
    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Loading and processing dataset...")
    dataset = load_data(file_path)
    if dataset is None:
        st.warning("Failed to load dataset. Please check the file path and try again.")
        return

    st.write("Tokenizing dataset...")
    tokenized_dataset = dataset['train'].map(
        lambda x: preprocess_function(x, tokenizer, max_length),
        batched=True,
        remove_columns=dataset['train'].column_names
    )

    # Define training arguments.
    # num_train_epochs is 1 because the loop below calls trainer.train() once
    # per user-selected epoch so a checkpoint can be saved after each pass.
    # evaluation_strategy is 'no' because no eval_dataset is supplied.
    training_args = TrainingArguments(
        output_dir='./results',
        evaluation_strategy='no',
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=1,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
    )

    # Initialize the Trainer; the causal-LM collator (mlm=False) copies
    # input_ids into labels for next-token prediction.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    if st.button('Start Training'):
        st.write("Starting training...")
        progress_bar = st.progress(0)
        for epoch in range(int(num_epochs)):
            # Each call trains for one epoch, continuing from the weights left
            # by the previous epoch.
            trainer.train()
            progress = (epoch + 1) / num_epochs
            progress_bar.progress(progress)

            # Save the model and tokenizer after each epoch
            model_path = f"./results/model_epoch_{epoch+1}"
            trainer.save_model(model_path)
            tokenizer.save_pretrained(model_path)
            st.write(f"Model saved: {model_path}")

        st.write("Training complete.")
        st.write("You can now use the trained model for inference or further fine-tuning.")
if __name__ == "__main__":
    main()
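
# A minimal, commented-out sketch of loading a saved checkpoint for inference
# (assumes at least one training epoch completed so ./results/model_epoch_1 exists):
#
#   from transformers import AutoTokenizer, AutoModelForCausalLM
#
#   ckpt = "./results/model_epoch_1"
#   tok = AutoTokenizer.from_pretrained(ckpt)
#   lm = AutoModelForCausalLM.from_pretrained(ckpt)
#   inputs = tok("Question: What are your store hours?", return_tensors="pt")
#   output_ids = lm.generate(**inputs, max_new_tokens=50)
#   print(tok.decode(output_ids[0], skip_special_tokens=True))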