diff --git a/railib/api.py b/railib/api.py index c1c1e5b..69eeebe 100644 --- a/railib/api.py +++ b/railib/api.py @@ -331,6 +331,9 @@ def poll_with_specified_overhead( max_tries: int = None, max_delay: int = 120, ): + if overhead_rate < 0: + raise ValueError("overhead_rate must be non-negative") + if start_time is None: start_time = time.time() diff --git a/test/test_unit.py b/test/test_unit.py index c14ab93..bdd9574 100644 --- a/test/test_unit.py +++ b/test/test_unit.py @@ -10,6 +10,8 @@ class TestPolling(unittest.TestCase): + @patch('time.sleep', return_value=None) + @patch('time.time') def test_timeout_exception(self): try: api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, timeout=1) @@ -28,14 +30,14 @@ def test_validation(self): api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, max_tries=1) api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1) - def test_initial_delay(self): - start_time = time.time() - with patch('time.sleep') as mock_sleep: - try: - api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=1) - except Exception: - pass - mock_sleep.assert_called_with((time.time() - start_time) * 0.1) + def test_initial_delay(self, mock_time, mock_sleep): + start_time = 100 # Fixed start time + mock_time.return_value = start_time + 0.0001 # Fixed increment + + api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=1) + + expected_sleep_time = (mock_time.return_value - start_time) * 0.1 + mock_sleep.assert_called_with(expected_sleep_time) def test_max_delay_cap(self): with patch('time.sleep') as mock_sleep: