Spaces:
Sleeping
Sleeping
import os | |
import json | |
import random | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments | |
from datasets import Dataset | |
import torch | |
from huggingface_hub import Repository, HfFolder | |
import subprocess | |
# Authenticate Hugging Face Hub
# Read the Hub access token from Streamlit secrets (dict-style lookup, so a
# missing HF_TOKEN secret raises while the script is loading) and persist it
# with HfFolder.save_token, presumably so that later Hub calls in main()
# (Trainer.push_to_hub) can authenticate without an explicit token.
# NOTE(review): Streamlit re-executes this module on every rerun, so the token
# is re-saved each time. HfFolder is deprecated in recent huggingface_hub
# releases in favour of huggingface_hub.login; confirm against the pinned version.
hf_token = st.secrets["HF_TOKEN"]
HfFolder.save_token(hf_token)
# Set Git user identity
# NOTE(review): the default email below looks like a web-scrape obfuscation
# artifact rather than a real address; confirm the intended value.
def set_git_config(email="[email protected]", name="Nilesh"):
    """Write a git user identity into the *global* git config.

    Runs ``git config --global user.email`` and ``git config --global
    user.name``. This changes the git configuration of the account running
    the app, not just this repository. Failures are shown in the Streamlit UI
    instead of being raised, so a broken or missing git install does not stop
    the app from loading.

    Args:
        email: Value stored as ``user.email``.
        name: Value stored as ``user.name``.
    """
    try:
        subprocess.run(['git', 'config', '--global', 'user.email', email], check=True)
        subprocess.run(['git', 'config', '--global', 'user.name', name], check=True)
        st.success("Git configuration set successfully.")
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # CalledProcessError: git exited non-zero (check=True).
        # FileNotFoundError: the `git` executable is not on PATH; previously
        # this escaped and crashed the app at import time.
        st.error(f"Git configuration error: {str(e)}")
# Call set_git_config at the start of the script
# This runs at module level, so under Streamlit it executes on every rerun of
# the script: the git config commands and the success/error message repeat
# each time the user interacts with a widget.
set_git_config()
def load_data(file_path):
    """Load the training records from a JSON file.

    The file must contain a JSON array; its items are later read for
    ``'prompt'`` and ``'label'`` keys when the dataset is built.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed list of records. Returns ``None`` when the file is missing,
        unreadable, not valid UTF-8 JSON, or does not hold a JSON array. In
        every failure case an error message is shown in the Streamlit UI.
    """
    try:
        # Explicit UTF-8: the default text encoding is locale-dependent and can
        # garble or reject non-ASCII prompts on some hosts.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # EAFP instead of os.path.exists(): no race between check and open.
        st.error(f"File not found: {file_path}")
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers permission errors, directories, etc.
        st.error(f"Error loading dataset: {str(e)}")
        return None
    if not isinstance(data, list):
        # Downstream code shuffles and slices the records, which needs a list.
        st.error(
            f"Error loading dataset: expected a JSON list of records, "
            f"got {type(data).__name__}"
        )
        return None
    return data
def initialize_model_and_tokenizer(model_name, num_labels):
    """Load a tokenizer and a sequence-classification model for ``model_name``.

    If the tokenizer has no padding token (e.g. GPT-2 style checkpoints such
    as the default ``distilgpt2``), its EOS token is reused for padding. The
    model config's ``pad_token_id`` is then aligned with the tokenizer.
    Sequence-classification heads use it to find the last real token of each
    padded sequence.

    Args:
        model_name: Hub model id or local path passed to ``from_pretrained``.
        num_labels: Number of output classes for the classification head.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after showing
        an error in the Streamlit UI.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        added_pad_token = tokenizer.pad_token is None
        if added_pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        # Sync the model with the tokenizer both when we just assigned a pad
        # token and when the checkpoint config simply has none. Without this,
        # a config with pad_token_id=None breaks batched classification even
        # though the tokenizer pads.
        if added_pad_token or model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id
        return tokenizer, model
    except Exception as e:
        # UI boundary: report any loading failure instead of crashing the app.
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None
def create_dataset(data, tokenizer, max_length):
    """Tokenize labelled prompts into a Hugging Face ``Dataset``.

    Args:
        data: Iterable of records, each a mapping with ``'prompt'`` (text) and
            ``'label'`` keys. Labels are passed through unchanged; presumably
            integer class ids in ``[0, num_labels)`` (TODO confirm against the
            training data file).
        tokenizer: Callable tokenizer whose output exposes ``input_ids`` and
            ``attention_mask``.
        max_length: Each prompt is truncated and padded to this many tokens.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``
        columns.
    """
    prompts = [record['prompt'] for record in data]
    targets = [record['label'] for record in data]
    encoded = tokenizer(
        prompts,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    columns = {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': targets,
    }
    return Dataset.from_dict(columns)
def split_data(data, test_size=0.2, seed=None):
    """Shuffle records and split them into train and evaluation lists.

    Args:
        data: Sequence of records. It is not modified; a shuffled copy is split.
        test_size: Fraction of records reserved for evaluation, in ``[0, 1]``.
        seed: Optional seed for a reproducible shuffle. ``None`` (default) uses
            the global ``random`` state, as before.

    Returns:
        ``(train, eval)`` lists. ``train`` holds the first
        ``int(len(data) * (1 - test_size))`` shuffled records; ``eval`` holds
        the rest.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]``.
    """
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size}")
    # Shuffle a copy: random.shuffle works in place and used to reorder the
    # caller's list as a hidden side effect.
    shuffled = list(data)
    if seed is None:
        random.shuffle(shuffled)
    else:
        random.Random(seed).shuffle(shuffled)
    split_index = int(len(shuffled) * (1 - test_size))
    return shuffled[:split_index], shuffled[split_index:]
def main():
    """Render the training UI and fine-tune / publish the classifier on demand.

    Widgets collect the model name, data path, hyper-parameters and target Hub
    repository. Nothing expensive happens until "Start Training" is clicked.
    Streamlit re-executes the whole script on every widget interaction, so
    eager model loading or tokenization would repeat that work, including
    model downloads, on each edit.

    When the button is clicked, the app:
    1. loads and validates the JSON data;
    2. initialises the tokenizer and model;
    3. splits and tokenizes the data;
    4. trains, evaluating every epoch;
    5. pushes the model and the tokenizer to ``repo_id`` on the Hugging Face Hub.
    """
    st.title("Appointment Classification Model Training")
    model_name = st.text_input("Enter model name", "distilgpt2")
    file_path = st.text_input("Enter path to training data JSON file", "training_data.json")
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=8)
    # Explicit step: Streamlit's default float step (0.01) is larger than the
    # whole allowed range, which made the +/- buttons unusable.
    learning_rate = st.number_input(
        "Enter learning rate",
        min_value=1e-6,
        max_value=1e-3,
        value=5e-5,
        step=1e-6,
        format="%.1e",
    )
    num_labels = 3  # We have 3 classes: schedule, reschedule, cancel
    repo_id = st.text_input("Enter Hugging Face repository ID", "nileshhanotia/PeVe")

    if not st.button('Start Training'):
        return

    st.write("Loading and processing dataset...")
    data = load_data(file_path)
    if data is None:
        st.warning("Failed to load dataset. Please check the file path and try again.")
        return
    # With the default 80/20 split, at least two records are needed for both
    # the train and the evaluation split to be non-empty. Evaluation runs
    # every epoch.
    if len(data) < 2:
        st.warning("The dataset needs at least two examples to create train and evaluation splits.")
        return

    tokenizer, model = initialize_model_and_tokenizer(model_name, num_labels)
    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Preparing dataset...")
    # Split the data into train and evaluation sets
    train_data, eval_data = split_data(data)
    train_dataset = create_dataset(train_data, tokenizer, max_length)
    eval_dataset = create_dataset(eval_data, tokenizer, max_length)

    st.write("Starting training...")
    try:
        training_args = TrainingArguments(
            output_dir='./results',
            # NOTE(review): newer transformers releases rename this argument to
            # `eval_strategy`; confirm against the installed version.
            evaluation_strategy='epoch',
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=10,
            push_to_hub=True,
            hub_model_id=repo_id,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        with st.spinner("Training in progress..."):
            trainer.train()
        trainer.push_to_hub()
        # The Trainer was not given the tokenizer, so it neither saves nor
        # uploads it. Push it explicitly so the Hub repo can be loaded with
        # AutoTokenizer.from_pretrained(repo_id).
        tokenizer.push_to_hub(repo_id)
    except Exception as e:
        # UI boundary: show the failure instead of a raw traceback.
        st.error(f"Training failed: {str(e)}")
        return
    st.write(f"Training complete. Model is available on the Hugging Face Hub: {repo_id}")
# Entry point. `streamlit run` executes this file as __main__, so main()
# renders the UI on every rerun of the script.
if __name__ == "__main__":
    main()