CarolXia commited on
Commit
fdfe620
·
1 Parent(s): 54f4eaa

Temporarily disable eval_model

Browse files
Files changed (1) hide show
  1. app.py +22 -22
app.py CHANGED
@@ -178,7 +178,7 @@ def distillation_loss(student_logits, teacher_logits, true_labels, temperature,
178
  # hyperparameters
179
  batch_size = 32
180
  lr = 1e-4
181
- num_epochs = 50
182
  temperature = 2.0
183
  alpha = 0.5
184
 
@@ -220,27 +220,27 @@ for epoch in range(num_epochs):
220
 
221
  print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")
222
 
223
- # Evaluate the teacher model
224
- teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
225
- print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
226
-
227
- # Evaluate the student model
228
- student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
229
- print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
230
- print("\n")
231
-
232
- # put student model back into train mode
233
- student_model.train()
234
-
235
- #Compare the models
236
- # create testing data loader
237
- validation_dataloader = DataLoader(tokenized_data['test'], batch_size=8, collate_fn=data_collator)
238
- # Evaluate the teacher model
239
- teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
240
- print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
241
- # Evaluate the student model
242
- student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
243
- print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
244
 
245
 
246
  st.write('Pushing model to huggingface')
 
178
  # hyperparameters
179
  batch_size = 32
180
  lr = 1e-4
181
+ num_epochs = 30
182
  temperature = 2.0
183
  alpha = 0.5
184
 
 
220
 
221
  print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")
222
 
223
+ # # # Evaluate the teacher model
224
+ # # teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
225
+ # # print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
226
+
227
+ # # # Evaluate the student model
228
+ # # student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
229
+ # # print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
230
+ # # print("\n")
231
+
232
+ # # # put student model back into train mode
233
+ # # student_model.train()
234
+
235
+ # #Compare the models
236
+ # # create testing data loader
237
+ # validation_dataloader = DataLoader(tokenized_data['test'], batch_size=8, collate_fn=data_collator)
238
+ # # Evaluate the teacher model
239
+ # teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
240
+ # print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")
241
+ # # Evaluate the student model
242
+ # student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
243
+ # print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
244
 
245
 
246
  st.write('Pushing model to huggingface')