diff --git a/src/fastapi_redis_cache/cache.py b/src/fastapi_redis_cache/cache.py index 5257a14..6e614fd 100644 --- a/src/fastapi_redis_cache/cache.py +++ b/src/fastapi_redis_cache/cache.py @@ -3,7 +3,7 @@ from datetime import timedelta from functools import partial, update_wrapper, wraps from http import HTTPStatus -from typing import Union +from typing import Optional, Union from fastapi import Response @@ -19,13 +19,20 @@ ) -def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS): +def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS, web_expire: Optional[Union[int, timedelta]] = None): """Enable caching behavior for the decorated function. Args: expire (Union[int, timedelta], optional): The number of seconds from now when the cached response should expire. Defaults to 31,536,000 seconds (i.e., the number of seconds in one year). + + web_expire (Optional[Union[int, timedelta]]): The number of seconds + from now when cached web responses should expire. This is achived + by setting the ``cache-control`` header's ``max-age`` directive to the + specified number of seconds. If ``web_expire`` is not specified, the + value specified for ``expire`` or the default value for ``expire`` + will be used. """ def outer_wrapper(func): @@ -36,7 +43,9 @@ async def inner_wrapper(*args, **kwargs): func_kwargs = kwargs.copy() request = func_kwargs.pop("request", None) response = func_kwargs.pop("response", None) + create_response_directly = not response + if create_response_directly: response = Response() redis_cache = FastApiRedisCache() @@ -46,6 +55,8 @@ async def inner_wrapper(*args, **kwargs): key = redis_cache.get_cache_key(func, *args, **kwargs) ttl, in_cache = redis_cache.check_cache(key) if in_cache: + if web_expire is not None: + ttl = calculate_ttl(web_expire) redis_cache.set_response_headers(response, True, deserialize_json(in_cache), ttl) if redis_cache.requested_resource_not_modified(request, in_cache): response.status_code = int(HTTPStatus.NOT_MODIFIED) @@ -65,10 +76,11 @@ async def inner_wrapper(*args, **kwargs): else deserialize_json(in_cache) ) response_data = await get_api_response_async(func, *args, **kwargs) - ttl = calculate_ttl(expire) - cached = redis_cache.add_to_cache(key, response_data, ttl) + redis_ttl = calculate_ttl(expire) + web_ttl = calculate_ttl(web_expire) if web_expire is not None else redis_ttl + cached = redis_cache.add_to_cache(key, response_data, redis_ttl) if cached: - redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=ttl) + redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=web_ttl) return ( Response( content=serialize_json(response_data), media_type="application/json", headers=response.headers diff --git a/src/fastapi_redis_cache/version.py b/src/fastapi_redis_cache/version.py index 4bf3b7b..30b867d 100644 --- a/src/fastapi_redis_cache/version.py +++ b/src/fastapi_redis_cache/version.py @@ -1,3 +1,3 @@ # flake8: noqa -__version_info__ = ("0", "2", "5") # pragma: no cover +__version_info__ = ("0", "3", "0") # pragma: no cover __version__ = ".".join(__version_info__) # pragma: no cover diff --git a/tests/main.py b/tests/main.py index be2b693..fac7f1b 100644 --- a/tests/main.py +++ b/tests/main.py @@ -34,10 +34,20 @@ def cache_json_encoder(): @app.get("/cache_one_hour") @cache_one_hour() -def partial_cache_one_hour(response: Response): +def cache_one_hour(response: Response): return {"success": True, "message": "this data should be cached for one hour"} +REDIS_EXPIRE_SECONDS = 10 +WEB_EXPIRE_SECONDS = 5 + + +@app.get("/cache_web_expires_before_redis") +@cache(expire=REDIS_EXPIRE_SECONDS, web_expire=WEB_EXPIRE_SECONDS) +async def cache_web_expires_before_redis(request: Request, response: Response): + return {"success": True, "message": "this data should be web cached for five seconds"} + + @app.get("/cache_invalid_type") @cache_one_minute() def cache_invalid_type(request: Request, response: Response): diff --git a/tests/test_cache.py b/tests/test_cache.py index bdc0ccb..0a81686 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -10,6 +10,7 @@ from fastapi_redis_cache.util import deserialize_json from tests.main import app +from tests.main import REDIS_EXPIRE_SECONDS, WEB_EXPIRE_SECONDS client = TestClient(app) MAX_AGE_REGEX = re.compile(r"max-age=(?P\d+)") @@ -190,7 +191,7 @@ def test_if_none_match(): assert "etag" in response.headers -def test_partial_cache_one_hour(): +def test_cache_one_hour(): # Simple test that verifies that the @cache_for_one_hour partial function version of the @cache decorator # is working correctly. response = client.get("/cache_one_hour") @@ -214,3 +215,98 @@ def test_cache_invalid_type(): assert "cache-control" not in response.headers assert "expires" not in response.headers assert "etag" not in response.headers + + +def test_cache_web_expires_before_redis(): + target_endpoint = "/cache_web_expires_before_redis" + expected_response = {"success": True, "message": "this data should be web cached for five seconds"} + + # Store time when response data was added to cache + added_at_utc = datetime.utcnow() + + # Initial request, X-FastAPI-Cache header field should equal "Miss" + response = client.get(target_endpoint) + assert response.status_code == 200 + assert response.json() == expected_response + assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert "expires" in response.headers + assert "etag" in response.headers + + # Store 'max-age' value of 'cache-control' header field + assert "cache-control" in response.headers + match = MAX_AGE_REGEX.search(response.headers.get("cache-control")) + assert match + miss_ttl = int(match.groupdict()["ttl"]) + assert miss_ttl <= WEB_EXPIRE_SECONDS + + # Store eTag value from response header + check_etag = response.headers["etag"] + + # Send request, X-FastAPI-Cache header field should now equal "Hit" + response = client.get(target_endpoint) + assert response.status_code == 200 + assert response.json() == expected_response + assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + + # Verify eTag value matches the value stored from the initial response + assert "etag" in response.headers + assert response.headers["etag"] == check_etag + + # Store 'max-age' value of 'cache-control' header field + assert "cache-control" in response.headers + match = MAX_AGE_REGEX.search(response.headers.get("cache-control")) + assert match + hit_ttl = int(match.groupdict()["ttl"]) + assert hit_ttl <= miss_ttl + + # Store value of 'expires' header field + assert "expires" in response.headers + expire_at_utc = datetime.strptime(response.headers["expires"], HTTP_TIME) + + # Wait until web expiration time has passed + now = datetime.utcnow() + time.sleep((expire_at_utc - now).total_seconds()) + # Wait any additional time neecessary to ensure the web expiration has passed + now = datetime.utcnow() + while expire_at_utc > now: + time.sleep(1) + now = datetime.utcnow() + + # Wait one additional second to ensure the web cache has expired + time.sleep(1) + + # Verify that the time elapsed since the data was added to the cache is greater than the ttl value + second_request_utc = datetime.utcnow() + elapsed = (second_request_utc - added_at_utc).total_seconds() + assert elapsed > hit_ttl + + # Send request, X-FastAPI-Cache header field should equal "Hit" since the Redis cached value has a longer + # lifespan than the web cache value + response = client.get(target_endpoint) + assert response.status_code == 200 + assert response.json() == expected_response + assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + assert "cache-control" in response.headers + assert "expires" in response.headers + + # Check eTag value again. Since data is the same, the value should still match + assert "etag" in response.headers + assert response.headers["etag"] == check_etag + + # Wait until Redis expiration time has passed + elapsed_since_added = (datetime.utcnow() - added_at_utc).total_seconds() + if elapsed_since_added < REDIS_EXPIRE_SECONDS: + time.sleep(REDIS_EXPIRE_SECONDS - elapsed_since_added) + # Wait any additional time neecessary, waiting an additional second to ensure Redis has + # deleted the response data + while (datetime.utcnow() - added_at_utc).total_seconds() < REDIS_EXPIRE_SECONDS: + time.sleep(1) + + # Send request, X-FastAPI-Cache header field should equal "Miss" since the Redis cached value has now expired + response = client.get(target_endpoint) + assert response.status_code == 200 + assert response.json() == expected_response + assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert "cache-control" in response.headers + assert "expires" in response.headers + assert "etag" in response.headers