# Spaces: Sleeping  (page-capture residue; not part of the program)
import os | |
import json | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments | |
from transformers import DataCollatorForLanguageModeling | |
import torch | |
from tqdm import tqdm | |
def load_data(file_path):
    """Load the training dataset from a JSON file.

    Failures are reported in the Streamlit UI via ``st.error`` instead of
    being raised, so callers only need to check the result for ``None``.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content, or ``None`` when the file does not exist or
        cannot be read or parsed.
    """
    try:
        # Explicit UTF-8: without it the platform default encoding is used,
        # which can corrupt or reject non-ASCII text in the dataset.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        # Handle the missing file here rather than pre-checking with
        # os.path.exists, which is racy and raises on non-path arguments.
        st.error(f"File not found: {file_path}")
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any read/parse
        # failure should surface as a message rather than crash the app.
        st.error(f"Error loading dataset: {str(e)}")
    return None
def initialize_model_and_tokenizer(model_name):
    """Load a pretrained causal language model and its tokenizer.

    If the tokenizer defines no pad token, its EOS token is reused for
    padding and the model config is updated to the same pad id.

    Args:
        model_name: Model identifier or path passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, or ``(None, None)`` if anything
        fails; the failure is reported in the UI via ``st.error``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # Set the pad token to the eos token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            # Take the id from the tokenizer that actually pads the inputs,
            # so the model config cannot disagree with it (the model config's
            # own eos_token_id may be unset or differ).
            model.config.pad_token_id = tokenizer.pad_token_id
        return tokenizer, model
    except Exception as e:
        # Deliberately broad: any load failure is shown in the UI instead of
        # crashing the app.
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None
def create_dataset(data, tokenizer, max_length):
    """Tokenize prompt/response records into fixed-length examples.

    Each record is rendered as ``Human: <prompt>``, a newline, then
    ``Assistant: <response>``, and is padded or truncated to ``max_length``
    tokens.

    Args:
        data: Iterable of mappings, each with ``'prompt'`` and ``'response'``
            keys.
        tokenizer: Callable tokenizer accepting ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` and returning a mapping
            with ``'input_ids'`` and ``'attention_mask'`` tensors that have a
            leading batch dimension of size 1.
        max_length: Number of tokens every example is padded/truncated to.

    Returns:
        A list with one dict per record, holding 1-D ``'input_ids'`` and
        ``'attention_mask'`` tensors.

    Raises:
        KeyError: If a record lacks ``'prompt'`` or ``'response'``.
    """
    examples = []
    for item in data:
        full_text = f"Human: {item['prompt']}\nAssistant: {item['response']}"
        # Call the tokenizer directly; ``encode_plus`` is deprecated in favour
        # of ``__call__``.
        encoded = tokenizer(
            full_text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # squeeze(0) drops only the batch dimension. A bare squeeze() would
        # also collapse the sequence dimension when max_length == 1.
        examples.append({
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
        })
    return examples
class SimpleDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of already-tokenized example dicts."""

    def __init__(self, encodings):
        # encodings: sequence of per-example dicts (one dict per example).
        self.encodings = encodings

    def __getitem__(self, idx):
        # Each element is already one complete example. Indexing its tensors
        # by ``idx`` again would return a single scalar token per field and
        # raise IndexError once idx exceeds the sequence length.
        return dict(self.encodings[idx])

    def __len__(self):
        return len(self.encodings)


def main():
    """Render the Streamlit training page and run fine-tuning on demand.

    Reads the model name, dataset path and hyper-parameters from input
    widgets, loads the model and dataset, tokenizes the data, and builds a
    ``Trainer``. When the 'Start Training' button is pressed, training runs
    once for the requested number of epochs; after every epoch the progress
    bar is advanced and the model is saved to
    ``./results/model_epoch_<n>``.

    Any setup failure is reported in the UI and ends the run early.
    """
    # Imported locally because the file-level transformers import does not
    # bring TrainerCallback into scope.
    from transformers import TrainerCallback

    st.title("Model Training with Streamlit")

    # User inputs with recommended values
    model_name = st.text_input("Enter model name", "distilgpt2")
    file_path = st.text_input("Enter path to training data JSON file", "appointment_training_data.json")
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=256)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=8)
    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=5e-5, format="%.1e")

    tokenizer, model = initialize_model_and_tokenizer(model_name)
    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Loading and processing dataset...")
    data = load_data(file_path)
    if data is None:
        st.warning("Failed to load dataset. Please check the file path and try again.")
        return

    st.write("Tokenizing dataset...")
    try:
        tokenized_dataset = create_dataset(data, tokenizer, max_length)
    except (KeyError, TypeError) as e:
        # Records that are not mappings with 'prompt'/'response' keys would
        # otherwise crash the page with a raw traceback.
        st.error(f"Error tokenizing dataset: {str(e)}")
        return
    if not tokenized_dataset:
        st.warning("The dataset contains no training examples.")
        return

    dataset = SimpleDataset(tokenized_dataset)

    # Define training arguments. No evaluation strategy is passed: 'no' is
    # the default, and the 'evaluation_strategy' keyword has been renamed in
    # recent transformers releases.
    training_args = TrainingArguments(
        output_dir='./results',
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    if st.button('Start Training'):
        st.write("Starting training...")
        progress_bar = st.progress(0)
        total_epochs = int(num_epochs)

        class EpochReporter(TrainerCallback):
            """Advance the progress bar and save the model after each epoch."""

            def __init__(self):
                self.completed = 0

            def on_epoch_end(self, args, state, control, **kwargs):
                self.completed += 1
                progress_bar.progress(min(self.completed / total_epochs, 1.0))
                # Save the model after each epoch
                model_path = f"./results/model_epoch_{self.completed}"
                trainer.save_model(model_path)
                st.write(f"Model saved: {model_path}")

        trainer.add_callback(EpochReporter())
        # A single train() call already runs all ``num_train_epochs`` epochs.
        # Calling it once per epoch would train for num_epochs ** 2 epochs
        # and restart the optimizer and LR schedule on every call.
        trainer.train()

        st.write("Training complete.")
        st.write("You can now use the trained model for inference or further fine-tuning.")
# Run the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()