Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality for specifying a web cache expiration #60

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/fastapi_redis_cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import timedelta
from functools import partial, update_wrapper, wraps
from http import HTTPStatus
from typing import Union
from typing import Optional, Union

from fastapi import Response

Expand All @@ -19,13 +19,20 @@
)


def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS):
def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS, web_expire: Optional[Union[int, timedelta]] = None):
"""Enable caching behavior for the decorated function.

Args:
expire (Union[int, timedelta], optional): The number of seconds
from now when the cached response should expire. Defaults to 31,536,000
seconds (i.e., the number of seconds in one year).

web_expire (Optional[Union[int, timedelta]]): The number of seconds
from now when cached web responses should expire. This is achived
by setting the ``cache-control`` header's ``max-age`` directive to the
specified number of seconds. If ``web_expire`` is not specified, the
value specified for ``expire`` or the default value for ``expire``
will be used.
"""

def outer_wrapper(func):
Expand All @@ -36,7 +43,9 @@ async def inner_wrapper(*args, **kwargs):
func_kwargs = kwargs.copy()
request = func_kwargs.pop("request", None)
response = func_kwargs.pop("response", None)

create_response_directly = not response

if create_response_directly:
response = Response()
redis_cache = FastApiRedisCache()
Expand All @@ -46,6 +55,8 @@ async def inner_wrapper(*args, **kwargs):
key = redis_cache.get_cache_key(func, *args, **kwargs)
ttl, in_cache = redis_cache.check_cache(key)
if in_cache:
if web_expire is not None:
ttl = calculate_ttl(web_expire)
redis_cache.set_response_headers(response, True, deserialize_json(in_cache), ttl)
if redis_cache.requested_resource_not_modified(request, in_cache):
response.status_code = int(HTTPStatus.NOT_MODIFIED)
Expand All @@ -65,10 +76,11 @@ async def inner_wrapper(*args, **kwargs):
else deserialize_json(in_cache)
)
response_data = await get_api_response_async(func, *args, **kwargs)
ttl = calculate_ttl(expire)
cached = redis_cache.add_to_cache(key, response_data, ttl)
redis_ttl = calculate_ttl(expire)
web_ttl = calculate_ttl(web_expire) if web_expire is not None else redis_ttl
cached = redis_cache.add_to_cache(key, response_data, redis_ttl)
if cached:
redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=ttl)
redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=web_ttl)
return (
Response(
content=serialize_json(response_data), media_type="application/json", headers=response.headers
Expand Down
2 changes: 1 addition & 1 deletion src/fastapi_redis_cache/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# flake8: noqa
__version_info__ = ("0", "2", "5") # pragma: no cover
__version_info__ = ("0", "3", "0") # pragma: no cover
__version__ = ".".join(__version_info__) # pragma: no cover
12 changes: 11 additions & 1 deletion tests/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,20 @@ def cache_json_encoder():

@app.get("/cache_one_hour")
@cache_one_hour()
def partial_cache_one_hour(response: Response):
def cache_one_hour(response: Response):
return {"success": True, "message": "this data should be cached for one hour"}


REDIS_EXPIRE_SECONDS = 10
WEB_EXPIRE_SECONDS = 5


@app.get("/cache_web_expires_before_redis")
@cache(expire=REDIS_EXPIRE_SECONDS, web_expire=WEB_EXPIRE_SECONDS)
async def cache_web_expires_before_redis(request: Request, response: Response):
return {"success": True, "message": "this data should be web cached for five seconds"}


@app.get("/cache_invalid_type")
@cache_one_minute()
def cache_invalid_type(request: Request, response: Response):
Expand Down
98 changes: 97 additions & 1 deletion tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from fastapi_redis_cache.util import deserialize_json
from tests.main import app
from tests.main import REDIS_EXPIRE_SECONDS, WEB_EXPIRE_SECONDS

client = TestClient(app)
MAX_AGE_REGEX = re.compile(r"max-age=(?P<ttl>\d+)")
Expand Down Expand Up @@ -190,7 +191,7 @@ def test_if_none_match():
assert "etag" in response.headers


def test_partial_cache_one_hour():
def test_cache_one_hour():
# Simple test that verifies that the @cache_for_one_hour partial function version of the @cache decorator
# is working correctly.
response = client.get("/cache_one_hour")
Expand All @@ -214,3 +215,98 @@ def test_cache_invalid_type():
assert "cache-control" not in response.headers
assert "expires" not in response.headers
assert "etag" not in response.headers


def test_cache_web_expires_before_redis():
target_endpoint = "/cache_web_expires_before_redis"
expected_response = {"success": True, "message": "this data should be web cached for five seconds"}

# Store time when response data was added to cache
added_at_utc = datetime.utcnow()

# Initial request, X-FastAPI-Cache header field should equal "Miss"
response = client.get(target_endpoint)
assert response.status_code == 200
assert response.json() == expected_response
assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss"
assert "expires" in response.headers
assert "etag" in response.headers

# Store 'max-age' value of 'cache-control' header field
assert "cache-control" in response.headers
match = MAX_AGE_REGEX.search(response.headers.get("cache-control"))
assert match
miss_ttl = int(match.groupdict()["ttl"])
assert miss_ttl <= WEB_EXPIRE_SECONDS

# Store eTag value from response header
check_etag = response.headers["etag"]

# Send request, X-FastAPI-Cache header field should now equal "Hit"
response = client.get(target_endpoint)
assert response.status_code == 200
assert response.json() == expected_response
assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit"

# Verify eTag value matches the value stored from the initial response
assert "etag" in response.headers
assert response.headers["etag"] == check_etag

# Store 'max-age' value of 'cache-control' header field
assert "cache-control" in response.headers
match = MAX_AGE_REGEX.search(response.headers.get("cache-control"))
assert match
hit_ttl = int(match.groupdict()["ttl"])
assert hit_ttl <= miss_ttl

# Store value of 'expires' header field
assert "expires" in response.headers
expire_at_utc = datetime.strptime(response.headers["expires"], HTTP_TIME)

# Wait until web expiration time has passed
now = datetime.utcnow()
time.sleep((expire_at_utc - now).total_seconds())
# Wait any additional time neecessary to ensure the web expiration has passed
now = datetime.utcnow()
while expire_at_utc > now:
time.sleep(1)
now = datetime.utcnow()

# Wait one additional second to ensure the web cache has expired
time.sleep(1)

# Verify that the time elapsed since the data was added to the cache is greater than the ttl value
second_request_utc = datetime.utcnow()
elapsed = (second_request_utc - added_at_utc).total_seconds()
assert elapsed > hit_ttl

# Send request, X-FastAPI-Cache header field should equal "Hit" since the Redis cached value has a longer
# lifespan than the web cache value
response = client.get(target_endpoint)
assert response.status_code == 200
assert response.json() == expected_response
assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit"
assert "cache-control" in response.headers
assert "expires" in response.headers

# Check eTag value again. Since data is the same, the value should still match
assert "etag" in response.headers
assert response.headers["etag"] == check_etag

# Wait until Redis expiration time has passed
elapsed_since_added = (datetime.utcnow() - added_at_utc).total_seconds()
if elapsed_since_added < REDIS_EXPIRE_SECONDS:
time.sleep(REDIS_EXPIRE_SECONDS - elapsed_since_added)
# Wait any additional time neecessary, waiting an additional second to ensure Redis has
# deleted the response data
while (datetime.utcnow() - added_at_utc).total_seconds() < REDIS_EXPIRE_SECONDS:
time.sleep(1)

# Send request, X-FastAPI-Cache header field should equal "Miss" since the Redis cached value has now expired
response = client.get(target_endpoint)
assert response.status_code == 200
assert response.json() == expected_response
assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss"
assert "cache-control" in response.headers
assert "expires" in response.headers
assert "etag" in response.headers