nileshhanotia committed
Commit a8d1617
1 Parent(s): 997991d

Update app.py

Files changed (1)
  1. app.py +14 -8
app.py CHANGED
@@ -2,13 +2,10 @@ import os
 import json
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
-from transformers import TextDataset, DataCollatorForLanguageModeling
+from transformers import DataCollatorForLanguageModeling
 import torch
 from tqdm import tqdm
 
-# Remove the datasets import as we won't be using it anymore
-# from datasets import load_dataset
-
 @st.cache_data
 def load_data(file_path):
     if not os.path.exists(file_path):
@@ -22,6 +19,16 @@ def load_data(file_path):
         st.error(f"Error loading dataset: {str(e)}")
         return None
 
+@st.cache_resource
+def initialize_model_and_tokenizer(model_name):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Error initializing model and tokenizer: {str(e)}")
+        return None, None
+
 def create_dataset(data, tokenizer, max_length):
     inputs = []
     for item in data:
@@ -65,13 +72,12 @@ def main():
             self.encodings = encodings
 
         def __getitem__(self, idx):
-            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
-            return item
+            return torch.tensor(self.encodings[idx])
 
         def __len__(self):
-            return len(self.encodings['input_ids'])
+            return len(self.encodings)
 
-    dataset = SimpleDataset({'input_ids': tokenized_dataset})
+    dataset = SimpleDataset(tokenized_dataset)
 
     # Define training arguments
     training_args = TrainingArguments(
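
For reference, a minimal sketch of how the pieces touched by this commit could be wired together end to end. The model name ("distilgpt2"), the sample texts, the max length of 128, and the TrainingArguments values are placeholder assumptions, not values taken from app.py.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Placeholder model; app.py loads its own model via initialize_model_and_tokenizer.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# GPT-2 style tokenizers have no pad token; reuse EOS so the collator can pad batches.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize sample texts into a plain list of input_id lists, which is the shape
# the simplified SimpleDataset now expects (a list, not a dict of tensors).
texts = ["example document one", "a second, slightly longer example document"]
tokenized_dataset = [
    tokenizer(t, truncation=True, max_length=128)["input_ids"] for t in texts
]

class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx])

    def __len__(self):
        return len(self.encodings)

dataset = SimpleDataset(tokenized_dataset)

# With mlm=False the collator pads each batch and copies input_ids into labels,
# which is what causal-LM fine-tuning with Trainer needs.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", num_train_epochs=1, per_device_train_batch_size=2),
    train_dataset=dataset,
    data_collator=collator,
)
# trainer.train() would then run the fine-tuning loop.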