diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8598c1f..9db37d1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,8 @@ +3.1.1 +----- + +* Since 3.8, CancelledError is a subclass of BaseException rather than Exception, so we need to catch it explicitly. + 3.1.0 ----- diff --git a/memoize/wrapper.py b/memoize/wrapper.py index 9875072..14379dd 100644 --- a/memoize/wrapper.py +++ b/memoize/wrapper.py @@ -6,7 +6,7 @@ import datetime import functools import logging -from asyncio import Future +from asyncio import Future, CancelledError from typing import Optional, Callable from memoize.configuration import CacheConfiguration, NotConfiguredCacheCalledException, \ @@ -116,7 +116,7 @@ async def refresh(actual_entry: Optional[CacheEntry], key: CacheKey, logger.debug('Timeout for %s: %s', key, e) update_statuses.mark_update_aborted(key, e) raise CachedMethodFailedException('Refresh timed out') from e - except Exception as e: + except (Exception, CancelledError) as e: logger.debug('Error while refreshing cache for %s: %s', key, e) update_statuses.mark_update_aborted(key, e) raise CachedMethodFailedException('Refresh failed to complete') from e diff --git a/setup.py b/setup.py index 74e4a2f..22ab504 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ def prepare_description(): setup( name='py-memoize', - version='3.1.0', + version='3.1.1', author='Michal Zmuda', author_email='zmu.michal@gmail.com', url='https://github.com/DreamLab/memoize', diff --git a/tests/end2end/test_wrapper.py b/tests/end2end/test_wrapper.py index 2441ed9..6a40f7a 100644 --- a/tests/end2end/test_wrapper.py +++ b/tests/end2end/test_wrapper.py @@ -1,5 +1,6 @@ import asyncio import time +from asyncio import CancelledError from datetime import timedelta from unittest.mock import Mock @@ -174,6 +175,37 @@ async def get_value(arg, kwarg=None): assert context.value.__class__ == CachedMethodFailedException assert str(context.value.__cause__) == str(ValueError('stub0')) + async def test_should_return_cancelled_exception_for_all_concurrent_callers(self): + # given + value = 0 + + @memoize() + async def get_value(arg, kwarg=None): + new_task = asyncio.create_task(asyncio.sleep(1)) + new_task.cancel() # this will raise CancelledError + await new_task + + # when + res1 = get_value('test', kwarg='args1') + res2 = get_value('test', kwarg='args1') + res3 = get_value('test', kwarg='args1') + + # then + with pytest.raises(Exception) as context: + await res1 + assert context.value.__class__ == CachedMethodFailedException + assert str(context.value.__cause__) == str(CancelledError()) + + with pytest.raises(Exception) as context: + await res2 + assert context.value.__class__ == CachedMethodFailedException + assert str(context.value.__cause__) == str(CancelledError()) + + with pytest.raises(Exception) as context: + await res3 + assert context.value.__class__ == CachedMethodFailedException + assert str(context.value.__cause__) == str(CancelledError()) + async def test_should_return_timeout_for_all_concurrent_callers(self): # given value = 0