diff --git a/tests/unit/test_metric.py b/tests/unit/test_metric.py index 649a1f05d..ef76cf54c 100644 --- a/tests/unit/test_metric.py +++ b/tests/unit/test_metric.py @@ -13,7 +13,7 @@ def test_get_available_metrics(): assert all( [ - m.required_columns == {"SINGLE_TURN": {"response", "user_input"}} + m.required_columns["SINGLE_TURN"] == {"response", "user_input"} for m in get_available_metrics(ds) ] ), "All metrics should have required columns ('user_input', 'response')"