simon_says_v2 / tests /test_simon_says.py
ericmichael's picture
first commit
0079b8d
raw
history blame
892 Bytes
import unittest
from utils.json_loader import JsonDataLoader
from models.simon_says import simon_says
from utils.metrics import accuracy
from utils.openai import cassette_for
class TestSimonSays(unittest.TestCase):
    """Regression test gating the ``simon_says`` model on validation accuracy."""

    # Minimum acceptable accuracy on data/validation.json. Kept as a class
    # attribute so it can be tuned (e.g. in a subclass) without editing the
    # test body.
    ACCURACY_THRESHOLD = 0.9

    def test_accuracy(self):
        """Assert ``simon_says`` reaches ``ACCURACY_THRESHOLD`` on the validation set.

        Each validation input is unpacked as keyword arguments into
        ``simon_says``; the expected output for each item is the target's
        ``"response"`` field.
        """
        # NOTE(review): presumably cassette_for records/replays the OpenAI
        # calls made by simon_says so the test is reproducible offline —
        # confirm against utils.openai.
        with cassette_for("simon_says"):
            loader = JsonDataLoader(filepath="data/validation.json")
            inputs, targets = loader.load_data()
            predictions = [simon_says(**input_) for input_ in inputs]
            response_target = [target["response"] for target in targets]

        # Guard against a degenerate evaluation: an empty validation file or a
        # prediction/target count mismatch would not give a meaningful score.
        self.assertTrue(response_target, "Validation set is empty")
        self.assertEqual(
            len(predictions),
            len(response_target),
            "Number of predictions does not match number of targets",
        )

        accuracy_score = accuracy(predictions, response_target)
        threshold = self.ACCURACY_THRESHOLD
        self.assertGreaterEqual(
            accuracy_score,
            threshold,
            f"Model accuracy {accuracy_score} is below the threshold {threshold}",
        )