nileshhanotia committed
Commit 997991d
Parent: b9365f3

Update app.py

Files changed (1):
  1. app.py +42 -28
app.py CHANGED
@@ -1,43 +1,47 @@
 import os
+import json
 import streamlit as st
-from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
 from transformers import TextDataset, DataCollatorForLanguageModeling
 import torch
 from tqdm import tqdm
 
-# Streamlit caching functions
+# Remove the datasets import as we won't be using it anymore
+# from datasets import load_dataset
+
 @st.cache_data
 def load_data(file_path):
+    if not os.path.exists(file_path):
+        st.error(f"File not found: {file_path}")
+        return None
     try:
-        return load_dataset('json', data_files=file_path)
+        with open(file_path, 'r') as f:
+            data = json.load(f)
+        return data
     except Exception as e:
         st.error(f"Error loading dataset: {str(e)}")
         return None
 
-@st.cache_resource
-def initialize_model_and_tokenizer(model_name):
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForCausalLM.from_pretrained(model_name)
-        return tokenizer, model
-    except Exception as e:
-        st.error(f"Error initializing model and tokenizer: {str(e)}")
-        return None, None
-
-def preprocess_function(examples, tokenizer, max_length):
-    return tokenizer(examples['prompt'], truncation=True, padding="max_length", max_length=max_length)
+def create_dataset(data, tokenizer, max_length):
+    inputs = []
+    for item in data:
+        prompt = item['prompt']
+        response = item['response']
+        full_text = f"Human: {prompt}\nAssistant: {response}"
+        encoded = tokenizer.encode(full_text, truncation=True, max_length=max_length, padding='max_length')
+        inputs.append(encoded)
+    return inputs
 
 def main():
     st.title("Model Training with Streamlit")
 
-    # User inputs
+    # User inputs with recommended values
     model_name = st.text_input("Enter model name", "distilgpt2")
-    file_path = st.text_input("Enter path to training data JSON file", "training_data.json")
-    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=128)
+    file_path = st.text_input("Enter path to training data JSON file", "appointment_training_data.json")
+    max_length = st.number_input("Enter max token length", min_value=32, max_value=512, value=256)
     num_epochs = st.number_input("Enter number of training epochs", min_value=1, max_value=10, value=3)
-    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=4)
-    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=2e-5, format="%.1e")
+    batch_size = st.number_input("Enter batch size", min_value=1, max_value=32, value=8)
+    learning_rate = st.number_input("Enter learning rate", min_value=1e-6, max_value=1e-3, value=5e-5, format="%.1e")
 
     tokenizer, model = initialize_model_and_tokenizer(model_name)
 
@@ -46,18 +50,28 @@ def main():
         return
 
     st.write("Loading and processing dataset...")
-    dataset = load_data(file_path)
+    data = load_data(file_path)
 
-    if dataset is None:
+    if data is None:
         st.warning("Failed to load dataset. Please check the file path and try again.")
         return
 
     st.write("Tokenizing dataset...")
-    tokenized_dataset = dataset['train'].map(
-        lambda x: preprocess_function(x, tokenizer, max_length),
-        batched=True,
-        remove_columns=dataset['train'].column_names
-    )
+    tokenized_dataset = create_dataset(data, tokenizer, max_length)
+
+    # Convert tokenized_dataset to a torch Dataset
+    class SimpleDataset(torch.utils.data.Dataset):
+        def __init__(self, encodings):
+            self.encodings = encodings
+
+        def __getitem__(self, idx):
+            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+            return item
+
+        def __len__(self):
+            return len(self.encodings['input_ids'])
+
+    dataset = SimpleDataset({'input_ids': tokenized_dataset})
 
     # Define training arguments
     training_args = TrainingArguments(
@@ -76,7 +90,7 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_dataset,
+        train_dataset=dataset,
         data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
     )
 
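Note: this commit deletes the @st.cache_resource helper initialize_model_and_tokenizer (along with preprocess_function), but main() still calls initialize_model_and_tokenizer(model_name), so the app now fails with a NameError at startup. The deleted definition, reproduced verbatim from parent b9365f3 (its imports are already at the top of app.py), would need to stay in the file:

@st.cache_resource
def initialize_model_and_tokenizer(model_name):
    # Cached so Streamlit reruns do not reload the weights on every interaction
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None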
 
 
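A second caveat: distilgpt2, like the rest of the GPT-2 family, ships a tokenizer with no pad token, so the padding='max_length' call inside create_dataset raises a ValueError as written. A minimal, self-contained sketch of the usual workaround, reusing the EOS token for padding (a suggested fix, not something this commit does):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the pad token

# With a pad token set, fixed-length encoding works as create_dataset expects.
encoded = tokenizer.encode("Human: hi\nAssistant: hello",
                           truncation=True, max_length=16, padding="max_length")
print(len(encoded))  # 16, right-padded with the EOS id

One side effect to keep in mind: DataCollatorForLanguageModeling(mlm=False) sets label positions equal to pad_token_id to -100, so genuine EOS tokens are also excluded from the loss when EOS doubles as the pad token.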
 
 
 
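Relatedly, SimpleDataset stores only input_ids, so no attention_mask reaches the model and padded positions are attended to during training. A variant that keeps the mask is straightforward; the sketch below is an alternative to what the commit ships, not part of it. MaskedDataset is a hypothetical name, and it assumes tokenizer.pad_token has been set as in the previous note:

import torch

class MaskedDataset(torch.utils.data.Dataset):
    """Variant of SimpleDataset that also keeps the attention mask."""

    def __init__(self, texts, tokenizer, max_length):
        # tokenizer(...) returns input_ids and attention_mask together,
        # unlike tokenizer.encode(), which returns only the token ids.
        self.encodings = tokenizer(texts, truncation=True,
                                   max_length=max_length, padding="max_length")

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings["input_ids"])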
96