File size: 4,721 Bytes
0a83766
997991d
de7d627
0a83766
a8d1617
b0226fd
 
4270cfb
b2c4316
b0226fd
997991d
 
 
b0226fd
997991d
 
 
b0226fd
 
 
de7d627
a8d1617
 
 
 
 
155ecd2
 
 
 
 
 
a8d1617
 
 
 
 
997991d
 
 
 
 
 
155ecd2
 
 
 
 
 
 
 
 
 
 
997991d
4270cfb
b2c4316
 
b0226fd
997991d
b0226fd
997991d
 
b0226fd
997991d
 
b0226fd
b2c4316
b0226fd
 
 
 
b2c4316
 
997991d
b2c4316
997991d
b0226fd
 
 
b2c4316
997991d
 
 
 
 
 
 
 
155ecd2
997991d
 
a8d1617
997991d
a8d1617
b2c4316
 
 
b0226fd
3c13618
b0226fd
 
 
 
 
 
 
b2c4316
b0226fd
b2c4316
0a83766
b0226fd
 
997991d
b0226fd
0a83766
b0226fd
b2c4316
 
b0226fd
 
 
 
 
 
 
 
 
 
 
 
b2c4316
b0226fd
b2c4316
 
b0226fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
import torch
from tqdm import tqdm

@st.cache_data
def load_data(file_path):
    """Read and parse the JSON training file at ``file_path``.

    The file is always decoded as UTF-8 (the encoding JSON documents are
    expected to use) instead of the platform's locale encoding, so the same
    dataset loads identically on every operating system.

    Args:
        file_path: Path of the JSON file to load.

    Returns:
        The parsed JSON content, or ``None`` when the file is missing or
        cannot be read/parsed. Failures are reported to the page through
        ``st.error`` rather than raised.
    """
    try:
        # Open directly instead of checking os.path.exists() first: the file
        # could disappear between the check and the open.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    except Exception as e:
        # UI boundary: any other read/decode/parse failure is shown to the
        # user instead of crashing the page.
        st.error(f"Error loading dataset: {str(e)}")
    return None

@st.cache_resource
def initialize_model_and_tokenizer(model_name):
    """Load the tokenizer and causal language model named ``model_name``.

    When the tokenizer has no pad token, its EOS token is reused as the pad
    token and the model config's ``pad_token_id`` is set to the config's
    ``eos_token_id``.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple, or ``(None, None)`` if anything
        raised while loading; the error is reported through ``st.error``.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        lm = AutoModelForCausalLM.from_pretrained(model_name)

        missing_pad_token = tok.pad_token is None
        if missing_pad_token:
            # Fall back to EOS so fixed-length padding has a token to use.
            tok.pad_token = tok.eos_token
            lm.config.pad_token_id = lm.config.eos_token_id
    except Exception as err:
        st.error(f"Error initializing model and tokenizer: {str(err)}")
        return None, None

    return tok, lm

def create_dataset(data, tokenizer, max_length):
    """Tokenize prompt/response records into fixed-length training examples.

    Each record is rendered as a ``Human: <prompt>`` line followed by an
    ``Assistant: <response>`` line, then padded or truncated to
    ``max_length`` tokens.

    Args:
        data: Iterable of mappings that each provide ``'prompt'`` and
            ``'response'`` entries.
        tokenizer: Callable tokenizer accepting ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments and
            returning a mapping with ``'input_ids'`` and ``'attention_mask'``
            tensors that carry a leading batch axis.
        max_length: Number of tokens every example is padded/truncated to.

    Returns:
        A list with one dict per record holding the ``'input_ids'`` and
        ``'attention_mask'`` tensors with the batch axis removed.

    Raises:
        KeyError: If a record lacks ``'prompt'`` or ``'response'``.
    """
    examples = []
    for item in data:
        prompt = item['prompt']
        response = item['response']
        full_text = f"Human: {prompt}\nAssistant: {response}"
        # Call the tokenizer directly; ``encode_plus`` is the deprecated
        # spelling of the same operation.
        encoded = tokenizer(
            full_text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        examples.append({
            # squeeze(0) removes only the batch axis. A bare squeeze() would
            # also drop the sequence axis whenever max_length == 1.
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
        })
    return examples

def main():
    """Render the Streamlit page that fine-tunes a causal language model.

    Steps: collect hyperparameters from input widgets, load the model and
    tokenizer, load and tokenize the JSON dataset, build a ``Trainer``, and
    when the 'Start Training' button is pressed, run training once for the
    requested number of epochs. After every epoch the progress bar is
    advanced and the model is saved to ``./results/model_epoch_<n>``.

    Returns early (after showing a warning) when the model, tokenizer or
    dataset cannot be loaded, or when the dataset is empty.
    """
    # Imported here so this function is self-contained with respect to the
    # module-level import block.
    from transformers import TrainerCallback

    st.title("Model Training with Streamlit")

    # User inputs with recommended values
    model_name = st.text_input("Enter model name", "distilgpt2")
    file_path = st.text_input("Enter path to training data JSON file", "appointment_training_data.json")
    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=256)
    num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=8)
    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=5e-5, format="%.1e")

    tokenizer, model = initialize_model_and_tokenizer(model_name)

    if tokenizer is None or model is None:
        st.warning("Failed to initialize model and tokenizer. Please check the model name and try again.")
        return

    st.write("Loading and processing dataset...")
    data = load_data(file_path)

    if data is None:
        st.warning("Failed to load dataset. Please check the file path and try again.")
        return

    if not data:
        # Nothing to train on; stop before building a Trainer on no samples.
        st.warning("The dataset is empty. Please provide at least one training example.")
        return

    st.write("Tokenizing dataset...")
    tokenized_dataset = create_dataset(data, tokenizer, int(max_length))

    class SimpleDataset(torch.utils.data.Dataset):
        """Expose the list of per-sample encoding dicts as a torch Dataset."""

        def __init__(self, encodings):
            self.encodings = encodings

        def __getitem__(self, idx):
            # Each list entry is already one complete sample. Indexing the
            # tensors inside it by the sample index would return a single
            # token (or raise IndexError) instead of the sample.
            return self.encodings[idx]

        def __len__(self):
            return len(self.encodings)

    dataset = SimpleDataset(tokenized_dataset)

    total_epochs = int(num_epochs)

    # Define training arguments. The evaluation strategy is left at its
    # default ('no'), which avoids passing a keyword whose name differs
    # between transformers releases.
    training_args = TrainingArguments(
        output_dir='./results',
        learning_rate=learning_rate,
        per_device_train_batch_size=int(batch_size),
        per_device_eval_batch_size=int(batch_size),
        num_train_epochs=total_epochs,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    if st.button('Start Training'):
        st.write("Starting training...")
        progress_bar = st.progress(0)

        class EpochCheckpointCallback(TrainerCallback):
            """Update the progress bar and save the model after each epoch."""

            def __init__(self):
                self.finished_epochs = 0

            def on_epoch_end(self, args, state, control, **kwargs):
                self.finished_epochs += 1
                progress_bar.progress(min(self.finished_epochs / total_epochs, 1.0))

                # Save the model after each epoch
                model_path = f"./results/model_epoch_{self.finished_epochs}"
                trainer.save_model(model_path)
                st.write(f"Model saved: {model_path}")

        trainer.add_callback(EpochCheckpointCallback())

        # A single train() call already runs ``num_train_epochs`` epochs.
        # Calling it once per epoch in a loop would train for
        # num_epochs * num_epochs epochs and restart the optimizer and
        # learning-rate schedule on every call.
        trainer.train()

        st.write("Training complete.")
        st.write("You can now use the trained model for inference or further fine-tuning.")

# Call main() only when this file is executed as the main module, not when
# it is imported from another module.
if __name__ == "__main__":
    main()