"""Streamlit app that fine-tunes a sequence-classification model and pushes it to the Hugging Face Hub.

The training data is a JSON list of records shaped like ``{"prompt": <str>, "label": <int>}``,
where ``label`` is a class index in ``[0, num_labels)`` (validated by ``_validate_records``).
"""
import os
import json
import random
import inspect
import subprocess

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
import torch
from huggingface_hub import Repository, HfFolder

# Authenticate Hugging Face Hub
# NOTE: Streamlit re-executes this whole script on every widget interaction, so this
# (and set_git_config below) run on every rerun, not just once.
hf_token = st.secrets["HF_TOKEN"]
HfFolder.save_token(hf_token)


def set_git_config(email="nilesh.hanotia@outlook.com", name="Nilesh"):
    """Set the global git user identity, reporting success or failure in the UI.

    Args:
        email: Value written to ``git config --global user.email``.
        name: Value written to ``git config --global user.name``.
    """
    try:
        subprocess.run(['git', 'config', '--global', 'user.email', email], check=True)
        subprocess.run(['git', 'config', '--global', 'user.name', name], check=True)
        st.success("Git configuration set successfully.")
    except subprocess.CalledProcessError as e:
        st.error(f"Git configuration error: {str(e)}")
    except FileNotFoundError:
        # subprocess raises this (not CalledProcessError) when the git binary is missing.
        st.error("Git configuration error: git executable not found")


# Call set_git_config at the start of the script
set_git_config()


@st.cache_data
def load_data(file_path):
    """Load and return the parsed JSON content of ``file_path``.

    Returns None (after showing an error in the UI) if the file does not exist,
    cannot be read, or is not valid JSON.
    """
    if not os.path.exists(file_path):
        st.error(f"File not found: {file_path}")
        return None
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        st.error(f"Error loading dataset: {str(e)}")
        return None


@st.cache_resource
def initialize_model_and_tokenizer(model_name, num_labels):
    """Load a tokenizer and a sequence-classification model with ``num_labels`` outputs.

    Returns ``(tokenizer, model)``, or ``(None, None)`` after showing an error if loading fails.
    NOTE: st.cache_resource hands back the same model object on every rerun, so training it
    mutates the cached instance (a second training run continues from the fine-tuned weights).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        if tokenizer.pad_token is None:
            # GPT-2 style tokenizers have no pad token; padding='max_length' needs one.
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id
        return tokenizer, model
    except Exception as e:
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None


def create_dataset(data, tokenizer, max_length):
    """Tokenize the ``prompt`` of each record and build a ``datasets.Dataset``.

    Every sequence is truncated/padded to exactly ``max_length`` tokens; the record
    ``label`` values become the ``labels`` column consumed by the Trainer.
    """
    texts = [item['prompt'] for item in data]
    labels = [item['label'] for item in data]
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=max_length)
    return Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': labels,
    })


def split_data(data, test_size=0.2, seed=None):
    """Shuffle ``data`` and split it into ``(train, eval)`` lists.

    The input sequence is not modified. When ``data`` has at least two records, both
    splits are guaranteed to be non-empty.

    Args:
        data: Sequence of records.
        test_size: Fraction of records assigned to the eval split; must be in (0, 1).
        seed: Optional seed for a reproducible shuffle (None = non-deterministic).

    Raises:
        ValueError: If ``test_size`` is not strictly between 0 and 1.
    """
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be in (0, 1), got {test_size!r}")
    shuffled = list(data)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    split_index = int(len(shuffled) * (1 - test_size))
    if len(shuffled) >= 2:
        # Keep at least one record on each side so neither Trainer split is empty.
        split_index = min(max(split_index, 1), len(shuffled) - 1)
    return shuffled[:split_index], shuffled[split_index:]


def _validate_records(data, num_labels):
    """Return a description of the first problem in ``data``, or None if it is usable for training."""
    if not isinstance(data, list):
        return "expected a JSON list of records"
    if len(data) < 2:
        return "need at least 2 records to create train and evaluation splits"
    for i, item in enumerate(data):
        if not isinstance(item, dict) or 'prompt' not in item or 'label' not in item:
            return f"record {i} must be an object with 'prompt' and 'label' keys"
        if not isinstance(item['prompt'], str):
            return f"record {i} has a non-string 'prompt'"
        label = item['label']
        # bool is a subclass of int, but True/False are almost certainly a data mistake here.
        if isinstance(label, bool) or not isinstance(label, int) or not 0 <= label < num_labels:
            return f"record {i} has label {label!r}; expected an integer in [0, {num_labels})"
    return None


def _eval_strategy_kwarg():
    """Return the TrainingArguments keyword name for the evaluation schedule.

    transformers renamed ``evaluation_strategy`` to ``eval_strategy``, and newer releases
    no longer accept the old name, so use whichever the installed version supports.
    """
    params = inspect.signature(TrainingArguments).parameters
    return 'eval_strategy' if 'eval_strategy' in params else 'evaluation_strategy'


def main():
    """Render the training UI; fine-tune and push the model when 'Start Training' is clicked."""
    st.title("Appointment Classification Model Training")

    model_name = st.text_input("Enter model name", "distilgpt2")
    file_path = st.text_input("Enter path to training data JSON file", "training_data.json")
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=8)
    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=5e-5, format="%.1e")
    # We have 3 classes: schedule, reschedule, cancel
    # NOTE(review): the label-index -> class-name mapping is not defined here; confirm it matches the data.
    num_labels = 3
    repo_id = st.text_input("Enter Hugging Face repository ID", "nileshhanotia/PeVe")

    tokenizer, model = initialize_model_and_tokenizer(model_name, num_labels)
    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Loading and processing dataset...")
    data = load_data(file_path)
    if data is None:
        st.warning("Failed to load dataset. Please check the file path and try again.")
        return

    problem = _validate_records(data, num_labels)
    if problem is not None:
        st.error(f"Invalid dataset: {problem}")
        return

    # Everything below has side effects (Trainer setup with push_to_hub initializes the Hub
    # repo, then training and uploading), so it only runs on an explicit click rather than
    # on every Streamlit rerun.
    if st.button('Start Training'):
        st.write("Preparing dataset...")
        # Split the data into train and evaluation sets
        train_data, eval_data = split_data(data)
        train_dataset = create_dataset(train_data, tokenizer, max_length)
        eval_dataset = create_dataset(eval_data, tokenizer, max_length)

        training_args = TrainingArguments(
            output_dir='./results',
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=10,
            push_to_hub=True,
            hub_model_id=repo_id,
            **{_eval_strategy_kwarg(): 'epoch'},
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

        st.write("Starting training...")
        trainer.train()
        trainer.push_to_hub()
        st.write(f"Training complete. Model is available on the Hugging Face Hub: {repo_id}")


if __name__ == "__main__":
    main()