Skip to content

Commit

Permalink
Update delay_cancellation to accept any awaitable (matrix-org#12468)
Browse files Browse the repository at this point in the history
This will mainly be useful when dealing with module callbacks, which are
all typed as returning `Awaitable`s instead of coroutines or
`Deferred`s.

Signed-off-by: Sean Quah <[email protected]>
  • Loading branch information
squahtx authored Apr 22, 2022
1 parent b82fff6 commit a50fb41
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 14 deletions.
1 change: 1 addition & 0 deletions changelog.d/12468.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.
3 changes: 1 addition & 2 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from typing_extensions import Literal

from twisted.enterprise import adbapi
from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
Expand Down Expand Up @@ -794,7 +793,7 @@ async def _runInteraction() -> R:
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
return await delay_cancellation(_runInteraction())

async def runWithConnection(
self,
Expand Down
52 changes: 42 additions & 10 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import abc
import asyncio
import collections
import inspect
import itertools
Expand All @@ -25,6 +26,7 @@
Awaitable,
Callable,
Collection,
Coroutine,
Dict,
Generic,
Hashable,
Expand Down Expand Up @@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
return new_deferred


def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
"""Delay cancellation of a `Deferred` until it resolves.
@overload
def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
...


@overload
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
...


@overload
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
...


def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
"""Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
resolve with a `CancelledError` until the original `Deferred` resolves.
resolve with a `CancelledError` until the original awaitable resolves.
Args:
deferred: The `Deferred` to protect against cancellation. May optionally follow
the Synapse logcontext rules.
deferred: The coroutine or `Deferred` to protect against cancellation. May
optionally follow the Synapse logcontext rules.
Returns:
A new `Deferred`, which will contain the result of the original `Deferred`.
The new `Deferred` will not propagate cancellation through to the original.
When cancelled, the new `Deferred` will wait until the original `Deferred`
resolves before failing with a `CancelledError`.
A new `Deferred`, which will contain the result of the original coroutine or
`Deferred`. The new `Deferred` will not propagate cancellation through to the
original coroutine or `Deferred`.
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
When cancelled, the new `Deferred` will wait until the original coroutine or
`Deferred` resolves before failing with a `CancelledError`.
The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
wrapped with `make_deferred_yieldable`.
"""

# First, convert the awaitable into a `Deferred`.
if isinstance(awaitable, defer.Deferred):
deferred = awaitable
elif asyncio.iscoroutine(awaitable):
# Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
# type-checking, but we'd need Twisted >= 21.2.
deferred = defer.ensureDeferred(awaitable)
else:
# We have no idea what to do with this awaitable.
# We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
# `make_awaitable`, and let the caller `await` it normally.
return awaitable

def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
# before the new deferred is cancelled, we `pause` it to stop the cancellation
# propagating. we then `unpause` it once the wrapped deferred completes, to
Expand Down
33 changes: 31 additions & 2 deletions tests/util/test_async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def test_cancellation(self):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""

def test_cancellation(self):
def test_deferred_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
Expand All @@ -403,6 +403,35 @@ def test_cancellation(self):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_coroutine_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()

async def task():
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
# tracebacks will be printed in logs.
raise ValueError("abc")

wrapper_deferred = delay_cancellation(task())

# Cancel the new `Deferred`.
wrapper_deferred.cancel()
self.assertNoResult(wrapper_deferred)
self.assertFalse(
blocking_deferred.called, "Cancellation was propagated too deep"
)
self.assertFalse(completion_deferred.called)

# Unblock the task.
blocking_deferred.callback(None)
self.assertTrue(completion_deferred.called)

# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed.
Expand Down Expand Up @@ -451,7 +480,7 @@ async def inner():
async def outer():
with LoggingContext("c") as c:
try:
await delay_cancellation(defer.ensureDeferred(inner()))
await delay_cancellation(inner())
self.fail("`CancelledError` was not raised")
except CancelledError:
self.assertEqual(c, current_context())
Expand Down

0 comments on commit a50fb41

Please sign in to comment.