From 8f4390346e527aa7e1060d245103e77be2dd7a30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Oct 2023 12:48:21 +0100 Subject: [PATCH] Ensure stream position never goes backwards, and add test --- synapse/storage/util/id_generators.py | 43 ++++++++++++++-------- tests/storage/test_id_generators.py | 51 +++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 7911d1841f96..043e35ee5e2b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -420,6 +420,11 @@ def __init__( # The maximum stream ID that we have seen been allocated across any writer. self._max_seen_allocated_stream_id = 1 + # The maximum position of the local instance. This can be higher than + # the corresponding position in `current_positions` table when there are + # no active writes in progress. + self._max_position_of_local_instance = self._max_seen_allocated_stream_id + self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. @@ -439,6 +444,16 @@ def __init__( self._current_positions.values(), default=1 ) + # For the case where `stream_positions` is not up to date, + # `_persisted_upto_position` may be higher. + self._max_seen_allocated_stream_id = max( + self._max_seen_allocated_stream_id, self._persisted_upto_position + ) + + # Bump our local maximum position now that we've loaded things from the + # DB. + self._max_position_of_local_instance = self._max_seen_allocated_stream_id + if not writers: # If there have been no explicit writers given then any instance can # write to the stream. 
In which case, let's pre-seed our own @@ -708,6 +723,9 @@ def _mark_id_as_finished(self, next_id: int) -> None: if new_cur: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, new_cur) + # Never let the local maximum move backwards: it may already have been + # advanced past `new_cur` by `_add_persisted_position`. + self._max_position_of_local_instance = max( + curr, new_cur, self._max_position_of_local_instance + ) self._add_persisted_position(next_id) @@ -722,6 +740,9 @@ def get_current_token_for_writer(self, instance_name: str) -> int: # persisted up to position. This stops Synapse from doing a full table # scan when a new writer announces itself over replication. with self._lock: + if self._instance_name == instance_name: + return self._return_factor * self._max_position_of_local_instance + pos = self._current_positions.get( instance_name, self._persisted_upto_position ) @@ -731,20 +752,6 @@ def get_current_token_for_writer(self, instance_name: str) -> int: # possible. pos = max(pos, self._persisted_upto_position) - if ( - self._instance_name == instance_name - and not self._in_flight_fetches - and not self._unfinished_ids - ): - # For our own instance when there's nothing in flight, it's safe - # to advance to the maximum persisted position we've seen (as we - # know that any new tokens we request will be greater). - max_pos_of_all_writers = max( - self._current_positions.values(), - default=self._persisted_upto_position, - ) - pos = max(pos, max_pos_of_all_writers) - return self._return_factor * pos def get_minimal_local_current_token(self) -> int: @@ -821,6 +828,16 @@ def _add_persisted_position(self, new_id: int) -> None: self._persisted_upto_position = max(min_curr, self._persisted_upto_position) + # Advance our local max position. + self._max_position_of_local_instance = max( + self._max_position_of_local_instance, self._persisted_upto_position + ) + + if not self._unfinished_ids and not self._in_flight_fetches: + # If we don't have anything in flight, it's safe to advance to the + # max seen stream ID. 
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id + # We now iterate through the seen positions, discarding those that are # less than the current min positions, and incrementing the min position # if its exactly one greater. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 855eab6ac0d6..e35f13247686 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -661,6 +661,57 @@ def test_minimal_local_token(self) -> None: self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7) + def test_current_token_gap(self) -> None: + """Test that getting the current token for a writer returns the maximal + token when there are no writes. + """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator( + "first", writers=["first", "second", "third"] + ) + second_id_gen = self._create_id_generator( + "second", writers=["first", "second", "third"] + ) + + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(second_id_gen.get_current_token(), 7) + + # Check that the first ID gen advancing causes the second ID gen to + # advance (as it has nothing in flight). + + async def _get_next_async() -> None: + async with first_id_gen.get_next_mult(2): + pass + + self.get_success(_get_next_async()) + second_id_gen.advance("first", 9) + + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9) + self.assertEqual(second_id_gen.get_current_token(), 7) + + # Check that the first ID gen advancing doesn't advance the second ID + # gen when it has stuff in flight. 
+ self.get_success(_get_next_async()) + + ctxmgr = second_id_gen.get_next() + self.get_success(ctxmgr.__aenter__()) + + second_id_gen.advance("first", 11) + + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9) + self.assertEqual(second_id_gen.get_current_token(), 7) + + self.get_success(ctxmgr.__aexit__(None, None, None)) + + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12) + self.assertEqual(second_id_gen.get_current_token(), 7) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""