Spaces:
Sleeping
Sleeping
File size: 5,780 Bytes
0a83766 997991d 59782fa de7d627 145afb1 1a4cbe3 b0226fd 8ad732e 3caf963 8ad732e c8ab462 c899f78 c8ab462 2882051 c8ab462 2882051 de7d627 59782fa 1a4cbe3 c4b5351 1a4cbe3 c4b5351 1a4cbe3 59782fa a8d1617 145afb1 a8d1617 145afb1 155ecd2 a8d1617 3caf963 145afb1 c4b5351 145afb1 ca9766d 145afb1 3caf963 ca9766d 59782fa c4b5351 59782fa b2c4316 48d786d b0226fd 1a4cbe3 48d786d b0226fd 997991d 48d786d 8ad732e b0226fd 145afb1 b0226fd b2c4316 1a4cbe3 b2c4316 997991d 1a4cbe3 b0226fd 145afb1 c8ab462 c4b5351 c8ab462 b2c4316 b0226fd 145afb1 b0226fd c8ab462 b0226fd 145afb1 b2c4316 b0226fd 0a83766 b0226fd c8ab462 0a83766 b0226fd b2c4316 145afb1 b2c4316 1a4cbe3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os
import json
import random
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset, concatenate_datasets
import torch
from huggingface_hub import Repository, HfFolder
import subprocess
# Authenticate Hugging Face Hub: read the access token stored under the
# "HF_TOKEN" key of the Streamlit secrets and pass it to HfFolder.save_token.
# This is module-level code, so it executes before main() is called.
# NOTE(review): presumably a missing "HF_TOKEN" secret raises right here - confirm.
hf_token = st.secrets["HF_TOKEN"]
HfFolder.save_token(hf_token)
# Set Git user identity
def set_git_config():
    """Set the global git ``user.email`` and ``user.name``.

    Runs ``git config --global`` once per setting and reports the outcome in
    the Streamlit UI. Failures are shown with ``st.error`` instead of being
    raised, so a broken git setup does not abort the script.
    """
    # NOTE(review): '[email protected]' looks like a scrape-redacted
    # placeholder rather than a real address - confirm the intended value.
    identity = (
        ('user.email', '[email protected]'),
        ('user.name', 'Nilesh'),
    )
    try:
        for key, value in identity:
            subprocess.run(['git', 'config', '--global', key, value], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: git exited with a non-zero status.
        # OSError (e.g. FileNotFoundError): the git executable could not be
        # started at all, which would otherwise crash the script at import.
        st.error(f"Git configuration error: {str(e)}")
    else:
        st.success("Git configuration set successfully.")
# Call set_git_config at the start of the script, at module level, so the git
# identity is configured before main() is called.
set_git_config()
@st.cache_data
def load_data(file_paths):
    """Load and merge the training examples from several intent JSON files.

    Each file is expected to contain a JSON object with an ``"intents"``
    entry; the ``"examples"`` list of every intent is appended to one combined
    list, in the order the files and intents are given.

    Args:
        file_paths: Iterable of path strings. Surrounding whitespace is
            stripped from each path before use.

    Returns:
        The combined list of examples, or ``None`` if any file is missing,
        lacks the ``"intents"`` key, or cannot be read/parsed. The problem is
        reported through ``st.error`` before returning.
    """
    combined_data = []
    for raw_path in file_paths:
        file_path = raw_path.strip()
        if not os.path.exists(file_path):
            st.error(f"File not found: {file_path}")
            return None
        try:
            # JSON text is UTF-8; be explicit so the result does not depend
            # on the platform's default locale encoding.
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if 'intents' not in data:
                st.error(f"Invalid format in file: {file_path}")
                return None
            for intent in data['intents']:
                combined_data.extend(intent['examples'])
        except Exception as e:
            # Deliberately broad: this is the UI boundary, and any read,
            # parse or schema problem should surface as a message, not a crash.
            st.error(f"Error loading dataset from {file_path}: {str(e)}")
            return None
    return combined_data
@st.cache_resource
def initialize_model_and_tokenizer(model_name, num_labels):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    When the tokenizer has no pad token, its EOS token is reused as the pad
    token and the model config's ``pad_token_id`` is set to the config's
    ``eos_token_id``.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.
        num_labels: Number of output classes passed to the model loader.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after reporting
        the failure with ``st.error``.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        classifier = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
        )
        lacks_pad_token = tok.pad_token is None
        if lacks_pad_token:
            # Fall back to EOS so that padding works during batching.
            tok.pad_token = tok.eos_token
            classifier.config.pad_token_id = classifier.config.eos_token_id
    except Exception as exc:
        st.error(f"Error initializing model and tokenizer: {exc!s}")
        return None, None
    return tok, classifier
def _label_in_range(label, num_labels):
    """Return True if ``label`` compares as ``0 <= label < num_labels``."""
    try:
        return 0 <= label < num_labels
    except TypeError:
        # Labels such as None or a string cannot be ordered against ints;
        # treat them as invalid instead of aborting dataset creation.
        return False


def create_dataset(data, tokenizer, max_length, num_labels=3):
    """Tokenize the examples and wrap them in a ``Dataset`` for training.

    Args:
        data: Sequence of dict-like examples. ``'prompt'`` (default ``''``)
            supplies the text and ``'label'`` (default ``-1``) the class id.
        tokenizer: Callable invoked as ``tokenizer(texts, truncation=True,
            padding='max_length', max_length=max_length)``; its result must
            provide ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length every example is truncated/padded to.
        num_labels: Number of valid classes. Previously this name was read
            from an undefined global, so every call raised ``NameError``; it
            is now a parameter whose default matches the three classes the
            script trains on.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``
        columns. Labels outside ``[0, num_labels)``, including the ``-1``
        placeholder for a missing label, are replaced by ``0``.
    """
    texts = [item.get('prompt', '') for item in data]
    labels = [item.get('label', -1) for item in data]
    # Debugging: Print out labels to check for invalid values
    print(f"Labels before adjustment: {labels}")
    # Ensure all labels are within the valid range
    labels = [
        label if _label_in_range(label, num_labels) else 0
        for label in labels
    ]
    # Debugging: Print out adjusted labels
    print(f"Labels after adjustment: {labels}")
    encodings = tokenizer(
        texts,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    dataset = Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': labels,
    })
    return dataset
def split_data(data, test_size=0.2):
    """Randomly split ``data`` into a training part and an evaluation part.

    The input sequence is left untouched: a shuffled copy is split, so callers
    that keep a reference to ``data`` (or receive it from a cache) do not see
    it reordered.

    Args:
        data: Non-empty sequence of examples.
        test_size: Fraction of the examples assigned to the evaluation part.

    Returns:
        ``(train, evaluation)`` as two lists. The first
        ``int(len(data) * (1 - test_size))`` shuffled items form the training
        part, so very small inputs can yield an empty training list.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("Data is empty, cannot split.")
    shuffled = list(data)
    random.shuffle(shuffled)
    split_index = int(len(shuffled) * (1 - test_size))
    return shuffled[:split_index], shuffled[split_index:]
def main():
    """Render the training UI and, on request, fine-tune and publish the model.

    The function collects the model name, data file paths and hyperparameters
    from Streamlit widgets, loads the model/tokenizer and the training data,
    and splits the data into train and evaluation parts. Tokenization and the
    construction of ``TrainingArguments``/``Trainer`` happen only after the
    'Start Training' button is pressed, so that ordinary widget interactions
    do not repeat that work (or touch the Hub repository) on every rerun.
    """
    st.title("Appointment Classification Model Training")
    model_name = st.text_input("Enter model name", "distilgpt2")
    raw_paths = st.text_area(
        "Enter paths to training data JSON files (comma-separated)",
        "training_data1.json,training_data2.json",
    )
    # Drop blank entries (e.g. from a trailing comma) so they are not
    # reported as a missing file with an empty name.
    file_paths = [path for path in raw_paths.split(',') if path.strip()]
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=8)
    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=5e-5, format="%.1e")
    num_labels = 3  # We have 3 classes: schedule, reschedule, cancel
    repo_id = st.text_input("Enter Hugging Face repository ID", "nileshhanotia/PeVe")

    tokenizer, model = initialize_model_and_tokenizer(model_name, num_labels)
    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Loading and processing dataset...")
    data = load_data(file_paths)
    if data is None:
        st.warning("Failed to load dataset. Please check the file paths and try again.")
        return

    st.write("Preparing dataset...")
    # Split the data into train and evaluation sets
    try:
        train_data, eval_data = split_data(data)
    except ValueError as e:
        st.error(f"Data splitting error: {str(e)}")
        return

    # Everything below is expensive (tokenization, Trainer setup with
    # push_to_hub=True), so do it only once training is actually requested.
    if not st.button('Start Training'):
        return

    train_dataset = create_dataset(train_data, tokenizer, max_length)
    eval_dataset = create_dataset(eval_data, tokenizer, max_length)
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` - confirm against the installed version.
    training_args = TrainingArguments(
        output_dir='./results',
        evaluation_strategy='epoch',
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        push_to_hub=True,
        hub_model_id=repo_id,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    st.write("Starting training...")
    trainer.train()
    trainer.push_to_hub()
    st.write(f"Training complete. Model is available on the Hugging Face Hub: {repo_id}")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|