diff --git a/solnlib/splunk_rest_client.py b/solnlib/splunk_rest_client.py index c83c2545..419c4177 100644 --- a/solnlib/splunk_rest_client.py +++ b/solnlib/splunk_rest_client.py @@ -26,6 +26,7 @@ import traceback from io import BytesIO from urllib.parse import quote +from urllib3.util.retry import Retry from splunklib import binding, client @@ -33,6 +34,7 @@ from .splunkenv import get_splunkd_access_info __all__ = ["SplunkRestClient"] +MAX_REQUEST_RETRIES = 5 def _get_proxy_info(context): @@ -98,10 +100,18 @@ def _request_handler(context): else: cert = None + retries = Retry( + total=MAX_REQUEST_RETRIES, + backoff_factor=0.3, + status_forcelist=[500, 502, 503, 504], + allowed_methods=["GET", "POST", "PUT", "DELETE"], + raise_on_status=False, + ) if context.get("pool_connections", 0): logging.info("Use HTTP connection pooling") session = requests.Session() adapter = requests.adapters.HTTPAdapter( + max_retries=retries, pool_connections=context.get("pool_connections", 10), pool_maxsize=context.get("pool_maxsize", 10), ) diff --git a/tests/unit/test_splunk_rest_client.py b/tests/unit/test_splunk_rest_client.py index 4a13e3cc..de43dbdd 100644 --- a/tests/unit/test_splunk_rest_client.py +++ b/tests/unit/test_splunk_rest_client.py @@ -17,8 +17,11 @@ from unittest import mock import pytest +from solnlib.splunk_rest_client import MAX_REQUEST_RETRIES +from requests.exceptions import ConnectionError from solnlib import splunk_rest_client +from solnlib.splunk_rest_client import SplunkRestClient @mock.patch.dict(os.environ, {"SPLUNK_HOME": "/opt/splunk"}, clear=True) @@ -80,3 +83,29 @@ def test_init_with_invalid_port(): host="localhost", port=99999, ) + + +@mock.patch.dict(os.environ, {"SPLUNK_HOME": "/opt/splunk"}, clear=True) +@mock.patch("solnlib.splunk_rest_client.get_splunkd_access_info") +@mock.patch("http.client.HTTPResponse") +@mock.patch("urllib3.HTTPConnectionPool._make_request") +def test_request_retry(http_conn_pool, http_resp, mock_get_splunkd_access_info): + mock_get_splunkd_access_info.return_value = "https", "localhost", 8089 + session_key = "123" + context = {"pool_connections": 5} + rest_client = SplunkRestClient("msg_name_1", session_key, "_", **context) + + mock_resp = http_resp() + mock_resp.status = 200 + mock_resp.reason = "TEST OK" + + side_effects = [ConnectionError(), ConnectionError(), ConnectionError(), mock_resp] + http_conn_pool.side_effect = side_effects + res = rest_client.get("test") + assert http_conn_pool.call_count == len(side_effects) + assert res.reason == mock_resp.reason + + side_effects = [ConnectionError()] * (MAX_REQUEST_RETRIES + 1) + [mock_resp] + http_conn_pool.side_effect = side_effects + with pytest.raises(ConnectionError): + rest_client.get("test")