Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -57,6 +57,9 @@ def load_data():
|
|
57 |
return dataset
|
58 |
|
59 |
def encode_decode(texts, tok):
|
|
|
|
|
|
|
60 |
tokenized_texts = tok(
|
61 |
texts,
|
62 |
padding="max_length",
|
@@ -69,7 +72,7 @@ def encode_decode(texts, tok):
|
|
69 |
decoded_texts = tok.batch_decode(tokenized_texts)
|
70 |
else:
|
71 |
print('Found invalid entry in examples. Returning dummy..')
|
72 |
-
decoded_texts = [
|
73 |
|
74 |
islist = not len(decoded_texts) == 1
|
75 |
|
@@ -97,7 +100,7 @@ def get_training_corpus(dataset):
|
|
97 |
def format_prompts(examples, tokenizer, isinst):
|
98 |
texts = []
|
99 |
for text in examples['text']:
|
100 |
-
if text:
|
101 |
if isinst:
|
102 |
conversation = []
|
103 |
parts = text.split('<|end|>')
|
@@ -115,6 +118,9 @@ def format_prompts(examples, tokenizer, isinst):
|
|
115 |
print('Found empty entry in examples. Moving on..')
|
116 |
continue
|
117 |
|
|
|
|
|
|
|
118 |
coded_texts = tokenizer.code(texts)
|
119 |
return {'text': coded_texts}
|
120 |
|
@@ -208,7 +214,24 @@ def train_model(model, tokenizer, dataset, push, isinst):
|
|
208 |
)
|
209 |
|
210 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
|
|
|
|
|
|
|
|
|
211 |
print("Mapped dataset sample length:", len(dataset[0]['text']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
trainer = trl.SFTTrainer(
|
214 |
model=model,
|
@@ -270,8 +293,14 @@ def main(push_to_hub=True, is_inst_finetune=False):
|
|
270 |
model = create_model(tokenizer)
|
271 |
print("Created Model.")
|
272 |
|
|
|
|
|
|
|
273 |
print("Resizing Token Embeddings..")
|
274 |
-
|
|
|
|
|
|
|
275 |
print("Resized Embeddings.")
|
276 |
|
277 |
print("Training Model..")
|
|
|
57 |
return dataset
|
58 |
|
59 |
def encode_decode(texts, tok):
|
60 |
+
if tok.pad_token is None:
|
61 |
+
tok.pad_token = tok.eos_token
|
62 |
+
|
63 |
tokenized_texts = tok(
|
64 |
texts,
|
65 |
padding="max_length",
|
|
|
72 |
decoded_texts = tok.batch_decode(tokenized_texts)
|
73 |
else:
|
74 |
print('Found invalid entry in examples. Returning dummy..')
|
75 |
+
decoded_texts = [tokenizer.pad_token * MAX_SEQ_LENGTH]
|
76 |
|
77 |
islist = not len(decoded_texts) == 1
|
78 |
|
|
|
100 |
def format_prompts(examples, tokenizer, isinst):
|
101 |
texts = []
|
102 |
for text in examples['text']:
|
103 |
+
if text and len(text.strip()) > 0:
|
104 |
if isinst:
|
105 |
conversation = []
|
106 |
parts = text.split('<|end|>')
|
|
|
118 |
print('Found empty entry in examples. Moving on..')
|
119 |
continue
|
120 |
|
121 |
+
if len(texts) == 0:
|
122 |
+
raise ValueError("No valid texts found in examples for formatting.")
|
123 |
+
|
124 |
coded_texts = tokenizer.code(texts)
|
125 |
return {'text': coded_texts}
|
126 |
|
|
|
214 |
)
|
215 |
|
216 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
|
217 |
+
|
218 |
+
if 'text' not in dataset.column_names:
|
219 |
+
raise ValueError("Dataset transformation failed: 'text' column missing after mapping.")
|
220 |
+
|
221 |
print("Mapped dataset sample length:", len(dataset[0]['text']))
|
222 |
+
|
223 |
+
try:
|
224 |
+
test_input = tokenizer(
|
225 |
+
["This is a test input."],
|
226 |
+
return_tensors="pt",
|
227 |
+
padding="max_length",
|
228 |
+
truncation=True,
|
229 |
+
max_length=MAX_SEQ_LENGTH
|
230 |
+
)
|
231 |
+
test_output = model(**test_input)
|
232 |
+
print("Model test output shape:", test_output.logits.shape)
|
233 |
+
except RuntimeError as e:
|
234 |
+
print(f"Error processing test batch: {e}")
|
235 |
|
236 |
trainer = trl.SFTTrainer(
|
237 |
model=model,
|
|
|
293 |
model = create_model(tokenizer)
|
294 |
print("Created Model.")
|
295 |
|
296 |
+
print(f"Tokenizer vocabulary size: {len(tokenizer)}")
|
297 |
+
print(f"Special tokens: {tokenizer.special_tokens_map}")
|
298 |
+
|
299 |
print("Resizing Token Embeddings..")
|
300 |
+
try:
|
301 |
+
model.resize_token_embeddings(len(tokenizer))
|
302 |
+
except RuntimeError as e:
|
303 |
+
raise RuntimeError(f"Error resizing token embeddings: {e}")
|
304 |
print("Resized Embeddings.")
|
305 |
|
306 |
print("Training Model..")
|