zaidmehdi commited on
Commit
46e333b
1 Parent(s): 490e341

update tests for dictionary output

Browse files
Files changed (1) hide show
  1. tests/model_tests.py +8 -7
tests/model_tests.py CHANGED
@@ -19,17 +19,18 @@ class TestClassifier(unittest.TestCase):
19
  """Test if the response of the main function is correct"""
20
  text = "حاجة حلوة اكيد"
21
  predictions = classify_arabic_dialect(text)
22
- self.assertEqual(len(predictions), 3)
23
- for i in range(3):
24
- self.assertIn(predictions[i][0], self.dialects)
25
- self.assertGreaterEqual(predictions[i][1], 0)
26
- self.assertLessEqual(predictions[i][1], 1)
27
 
28
  def test_model_output(self):
29
  """Test that the model correctly classifies obvious dialects"""
30
  for country, text, in self.test_set.items():
31
- first_prediction, _, _ = classify_arabic_dialect(text)
32
- self.assertEqual(first_prediction[0], country)
 
33
 
34
 
35
  if __name__ == "__main__":
 
19
  """Test if the response of the main function is correct"""
20
  text = "حاجة حلوة اكيد"
21
  predictions = classify_arabic_dialect(text)
22
+ self.assertEqual(len(predictions), len(self.dialects))
23
+ for key, value in predictions.items():
24
+ self.assertIn(key, self.dialects)
25
+ self.assertGreaterEqual(value, 0)
26
+ self.assertLessEqual(value, 1)
27
 
28
  def test_model_output(self):
29
  """Test that the model correctly classifies obvious dialects"""
30
  for country, text, in self.test_set.items():
31
+ predictions = classify_arabic_dialect(text)
32
+ label = max(predictions, key=predictions.get)
33
+ self.assertEqual(label, country)
34
 
35
 
36
  if __name__ == "__main__":