import streamlit as st
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)


# Streamlit caching functions
@st.cache_data
def load_data(file_path):
    try:
        return load_dataset('json', data_files=file_path)
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        return None


@st.cache_resource
def initialize_model_and_tokenizer(model_name):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-2-style tokenizers have no pad token; reuse EOS so padding works
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer, model
    except Exception as e:
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None


def preprocess_function(examples, tokenizer, max_length):
    return tokenizer(
        examples['prompt'],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )


def main():
    st.title("Model Training with Streamlit")

    # User inputs
    model_name = st.text_input("Enter model name", "distilgpt2")
    file_path = st.text_input("Enter path to training data JSON file", "training_data.json")
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=4)
    learning_rate = st.number_input(
        "Enter learning rate", min_value=1e-6, max_value=1e-3, value=2e-5, format="%.1e"
    )

    tokenizer, model = initialize_model_and_tokenizer(model_name)
    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Loading and processing dataset...")
    dataset = load_data(file_path)
    if dataset is None:
        st.warning("Failed to load dataset. Please check the file path and try again.")
        return

    st.write("Tokenizing dataset...")
    tokenized_dataset = dataset['train'].map(
        lambda x: preprocess_function(x, tokenizer, max_length),
        batched=True,
        remove_columns=dataset['train'].column_names,
    )

    # Define training arguments.
    # num_train_epochs is set to 1 because the loop below calls trainer.train()
    # once per epoch, so the progress bar and checkpoints can be updated between epochs.
    # No eval dataset is provided, so evaluation is left disabled (the default).
    training_args = TrainingArguments(
        output_dir='./results',
        learning_rate=learning_rate,
        per_device_train_batch_size=int(batch_size),
        num_train_epochs=1,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
    )

    # Initialize the Trainer with a causal-LM collator (mlm=False copies input_ids to labels)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    if st.button('Start Training'):
        st.write("Starting training...")
        progress_bar = st.progress(0)

        for epoch in range(int(num_epochs)):
            # Each call runs one epoch; note that the optimizer and LR schedule
            # are re-initialized on every call.
            trainer.train()
            progress = (epoch + 1) / num_epochs
            progress_bar.progress(progress)

            # Save the model after each epoch
            model_path = f"./results/model_epoch_{epoch + 1}"
            trainer.save_model(model_path)
            st.write(f"Model saved: {model_path}")

        st.write("Training complete.")
        st.write("You can now use the trained model for inference or further fine-tuning.")


if __name__ == "__main__":
    main()
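
# Usage note (assumption: this script is saved as app.py; any filename works):
#
#   streamlit run app.py
#
# Streamlit re-executes the whole script on every widget interaction; the
# @st.cache_data and @st.cache_resource decorators above keep the dataset and
# the model/tokenizer from being reloaded on each rerun. The training data is
# expected to be a JSON/JSONL file whose records contain a "prompt" field, as
# assumed by preprocess_function.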