Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

LockStore: fix acquiring a lock via LockStore.try_acquire_lock #12832

Merged
merged 3 commits into from
May 30, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12832.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug which allowed multiple async operations to access database locks concurrently. Contributed by @sumnerevans @ Beeper.
19 changes: 18 additions & 1 deletion synapse/storage/databases/main/lock.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary

from twisted.internet.interfaces import IReactorCore
@@ -84,6 +84,8 @@ def __init__(
self._on_shutdown,
)

self._acquiring_locks: Set[Tuple[str, str]] = set()

@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@@ -103,6 +105,21 @@ async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Loc
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
"""
if (lock_name, lock_key) in self._acquiring_locks:
return None
try:
self._acquiring_locks.add((lock_name, lock_key))
return await self._try_acquire_lock(lock_name, lock_key)
finally:
self._acquiring_locks.discard((lock_name, lock_key))

async def _try_acquire_lock(
self, lock_name: str, lock_key: str
) -> Optional["Lock"]:
"""Try to acquire a lock for the given name/key. Will return an async
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
"""

# Check if this process has taken out a lock and if it's still valid.
lock = self._live_tokens.get((lock_name, lock_key))
54 changes: 54 additions & 0 deletions tests/storage/databases/main/test_lock.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred

from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS

@@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs: HomeServer):
self.store = hs.get_datastores().main

def test_acquire_contention(self):
# Track the number of tasks holding the lock.
# Should be at most 1.
in_lock = 0
max_in_lock = 0

release_lock: "Deferred[None]" = Deferred()

async def task():
nonlocal in_lock
nonlocal max_in_lock

lock = await self.store.try_acquire_lock("name", "key")
if not lock:
return

async with lock:
in_lock += 1
max_in_lock = max(max_in_lock, in_lock)

# Block to allow other tasks to attempt to take the lock.
await release_lock

in_lock -= 1

# Start 3 tasks.
task1 = defer.ensureDeferred(task())
task2 = defer.ensureDeferred(task())
task3 = defer.ensureDeferred(task())

# Give the reactor a kick so that the database transaction returns.
self.pump()

release_lock.callback(None)

# Run the tasks to completion.
# To work around `Linearizer`s using a different reactor to sleep when
# contended (#12841), we call `runUntilCurrent` on
# `twisted.internet.reactor`, which is a different reactor to that used
# by the homeserver.
assert isinstance(reactor, ReactorBase)
self.get_success(task1)
reactor.runUntilCurrent()
self.get_success(task2)
reactor.runUntilCurrent()
self.get_success(task3)

# At most one task should have held the lock at a time.
self.assertEqual(max_in_lock, 1)

def test_simple_lock(self):
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.