diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 63cf24a14d94..7476839db584 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -238,7 +238,7 @@ async def send_request( data[_STREAM_POSITION_KEY] = { "streams": { - stream.NAME: stream.current_token(local_instance_name) + stream.NAME: stream.minimal_local_current_token() for stream in streams }, "instance_name": local_instance_name, diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c6088a0f99f8..d86b2b2134b2 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.util.id_generators import AbstractStreamIdGenerator logger = logging.getLogger(__name__) @@ -107,22 +108,10 @@ def parse_row(cls, row: StreamRow) -> Any: def __init__( self, local_instance_name: str, - current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - `current_token_function` and `update_function` are callbacks which - should be implemented by subclasses. - - `current_token_function` takes an instance name, which is a writer to - the stream, and returns the position in the stream of the writer (as - viewed from the current process). On the writer process this is where - the writer has successfully written up to, whereas on other processes - this is the position which we have received updates up to over - replication. (Note that most streams have a single writer and so their - implementations ignore the instance name passed in). - `update_function` is called to get updates for this stream between a pair of stream tokens. See the `UpdateFunction` type definition for more info. @@ -133,12 +122,28 @@ def __init__( update_function: callback go get stream updates, as above """ self.local_instance_name = local_instance_name - self.current_token = current_token_function self.update_function = update_function # The token from which we last asked for updates self.last_token = self.current_token(self.local_instance_name) + def current_token(self, instance_name: str) -> Token: + """This takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). + """ + # We can't make this an abstract class as it makes mypy unhappy. + raise NotImplementedError() + + def minimal_local_current_token(self) -> Token: + """Tries to return a minimal current token for the local instance, + i.e. for writers this would be the last successful write. + + If local instance is not a writer (or has written yet) then falls back + to returning the normal "current token". + """ + raise NotImplementedError() + def discard_updates_and_advance(self) -> None: """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. @@ -190,6 +195,25 @@ async def get_updates_since( return updates, upto_token, limited +class _StreamFromIdGen(Stream): + """Helper class for simple streams that use a stream ID generator""" + + def __init__( + self, + local_instance_name: str, + update_function: UpdateFunction, + stream_id_gen: "AbstractStreamIdGenerator", + ): + self._stream_id_gen = stream_id_gen + super().__init__(local_instance_name, update_function) + + def current_token(self, instance_name: str) -> Token: + return self._stream_id_gen.get_current_token_for_writer(instance_name) + + def minimal_local_current_token(self) -> Token: + return self._stream_id_gen.get_minimal_local_current_token() + + def current_token_without_instance( current_token: Callable[[], int] ) -> Callable[[str], int]: @@ -242,17 +266,21 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - self._current_token, self.store.get_all_new_backfill_event_rows, ) - def _current_token(self, instance_name: str) -> int: + def current_token(self, instance_name: str) -> Token: # The backfill stream over replication operates on *positive* numbers, # which means we need to negate it. return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + def minimal_local_current_token(self) -> Token: + # The backfill stream over replication operates on *positive* numbers, + # which means we need to negate it. + return -self.store._backfill_id_gen.get_minimal_local_current_token() -class PresenceStream(Stream): + +class PresenceStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceStreamRow: user_id: str @@ -283,9 +311,7 @@ def __init__(self, hs: "HomeServer"): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_current_presence_token), - update_function, + hs.get_instance_name(), update_function, store._presence_id_gen ) @@ -305,13 +331,18 @@ class PresenceFederationStreamRow: ROW_TYPE = PresenceFederationStreamRow def __init__(self, hs: "HomeServer"): - federation_queue = hs.get_presence_handler().get_federation_queue() + self._federation_queue = hs.get_presence_handler().get_federation_queue() super().__init__( hs.get_instance_name(), - federation_queue.get_current_token, - federation_queue.get_replication_rows, + self._federation_queue.get_replication_rows, ) + def current_token(self, instance_name: str) -> Token: + return self._federation_queue.get_current_token(instance_name) + + def minimal_local_current_token(self) -> Token: + return self._federation_queue.get_current_token(self.local_instance_name) + class TypingStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -341,20 +372,25 @@ def __init__(self, hs: "HomeServer"): update_function: Callable[ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] ] = typing_writer_handler.get_all_typing_updates - current_token_function = typing_writer_handler.get_current_token + self.current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) - current_token_function = hs.get_typing_handler().get_current_token + self.current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(current_token_function), update_function, ) + def current_token(self, instance_name: str) -> Token: + return self.current_token_function() + + def minimal_local_current_token(self) -> Token: + return self.current_token_function() -class ReceiptsStream(Stream): + +class ReceiptsStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class ReceiptsStreamRow: room_id: str @@ -371,12 +407,12 @@ def __init__(self, hs: "HomeServer"): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_max_receipt_stream_id), store.get_all_updated_receipts, + store._receipts_id_gen, ) -class PushRulesStream(Stream): +class PushRulesStream(_StreamFromIdGen): """A user has changed their push rules""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -387,20 +423,16 @@ class PushRulesStreamRow: ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - self._current_token, - self.store.get_all_push_rule_updates, + store.get_all_push_rule_updates, + store._push_rules_stream_id_gen, ) - def _current_token(self, instance_name: str) -> int: - push_rules_token = self.store.get_max_push_rules_stream_id() - return push_rules_token - -class PushersStream(Stream): +class PushersStream(_StreamFromIdGen): """A user has added/changed/removed a pusher""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -418,8 +450,8 @@ def __init__(self, hs: "HomeServer"): super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_pushers_stream_token), store.get_all_updated_pushers_rows, + store._push_rules_stream_id_gen, ) @@ -447,15 +479,20 @@ class CachesStreamRow: ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_cache_stream_token_for_writer, - store.get_all_updated_caches, + self.store.get_all_updated_caches, ) + def current_token(self, instance_name: str) -> Token: + return self.store.get_cache_stream_token_for_writer(instance_name) + + def minimal_local_current_token(self) -> Token: + return self.current_token(self.local_instance_name) + -class DeviceListsStream(Stream): +class DeviceListsStream(_StreamFromIdGen): """Either a user has updated their devices or a remote server needs to be told about a device update. """ @@ -473,8 +510,8 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(self.store.get_device_stream_token), self._update_function, + self.store._device_list_id_gen, ) async def _update_function( @@ -525,7 +562,7 @@ async def _update_function( return updates, upper_limit_token, devices_limited or signatures_limited -class ToDeviceStream(Stream): +class ToDeviceStream(_StreamFromIdGen): """New to_device messages for a client""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -539,12 +576,12 @@ def __init__(self, hs: "HomeServer"): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_to_device_stream_token), store.get_all_new_device_messages, + store._device_inbox_id_gen, ) -class AccountDataStream(Stream): +class AccountDataStream(_StreamFromIdGen): """Global or per room account data was changed""" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -560,8 +597,8 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(self.store.get_max_account_data_stream_id), self._update_function, + self.store._account_data_id_gen, ) async def _update_function( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index ad9b76071383..cad6d7e09aac 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -18,10 +18,10 @@ import attr from synapse.replication.tcp.streams._base import ( - Stream, StreamRow, StreamUpdateResult, Token, + _StreamFromIdGen, ) if TYPE_CHECKING: @@ -119,7 +119,7 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeToRow = {Row.TypeId: Row for Row in _EventRows} -class EventsStream(Stream): +class EventsStream(_StreamFromIdGen): """We received a new event, or an event went from being an outlier to not""" NAME = "events" @@ -127,9 +127,7 @@ class EventsStream(Stream): def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main super().__init__( - hs.get_instance_name(), - self._store._stream_id_gen.get_current_token_for_writer, - self._update_function, + hs.get_instance_name(), self._update_function, self._store._stream_id_gen ) async def _update_function( diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 4046bdec6931..7f5af5852c1c 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -18,6 +18,7 @@ from synapse.replication.tcp.streams._base import ( Stream, + Token, current_token_without_instance, make_http_update_function, ) @@ -47,7 +48,7 @@ def __init__(self, hs: "HomeServer"): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = current_token_without_instance( + self.current_token_func = current_token_without_instance( federation_sender.get_current_token ) update_function: Callable[ @@ -57,15 +58,21 @@ def __init__(self, hs: "HomeServer"): elif hs.should_send_federation(): # federation sender: Query master process update_function = make_http_update_function(hs, self.NAME) - current_token = self._stub_current_token + self.current_token_func = self._stub_current_token else: # other worker: stub out the update function (we're not interested in # any updates so when we get a POSITION we do nothing) update_function = self._stub_update_function - current_token = self._stub_current_token + self.current_token_func = self._stub_current_token - super().__init__(hs.get_instance_name(), current_token, update_function) + super().__init__(hs.get_instance_name(), update_function) + + def current_token(self, instance_name: str) -> Token: + return self.current_token_func(instance_name) + + def minimal_local_current_token(self) -> Token: + return self.current_token(self.local_instance_name) @staticmethod def _stub_current_token(instance_name: str) -> int: diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py index a8ce5ffd7289..ad181d7e9331 100644 --- a/synapse/replication/tcp/streams/partial_state.py +++ b/synapse/replication/tcp/streams/partial_state.py @@ -15,7 +15,7 @@ import attr -from synapse.replication.tcp.streams import Stream +from synapse.replication.tcp.streams._base import _StreamFromIdGen if TYPE_CHECKING: from synapse.server import HomeServer @@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow: room_id: str -class UnPartialStatedRoomStream(Stream): +class UnPartialStatedRoomStream(_StreamFromIdGen): """ Stream to notify about rooms becoming un-partial-stated; that is, when the background sync finishes such that we now have full state for @@ -41,8 +41,8 @@ def __init__(self, hs: "HomeServer"): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_un_partial_stated_rooms_token, store.get_un_partial_stated_rooms_from_stream, + store._un_partial_stated_rooms_stream_id_gen, ) @@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow: rejection_status_changed: bool -class UnPartialStatedEventStream(Stream): +class UnPartialStatedEventStream(_StreamFromIdGen): """ Stream to notify about events becoming un-partial-stated. """ @@ -68,6 +68,6 @@ def __init__(self, hs: "HomeServer"): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - store.get_un_partial_stated_events_token, store.get_un_partial_stated_events_from_stream, + store._un_partial_stated_events_stream_id_gen, ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4dc04de35e48..9876ec252f54 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -133,6 +133,15 @@ def get_current_token_for_writer(self, instance_name: str) -> int: """ raise NotImplementedError() + @abc.abstractmethod + def get_minimal_local_current_token(self) -> int: + """Tries to return a minimal current token for the local instance, + i.e. for writers this would be the last successful write. + + If local instance is not a writer (or has written yet) then falls back + to returning the normal "current token". + """ + @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: """ @@ -312,6 +321,9 @@ def get_current_token(self) -> int: def get_current_token_for_writer(self, instance_name: str) -> int: return self.get_current_token() + def get_minimal_local_current_token(self) -> int: + return self.get_current_token() + class MultiWriterIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with multiple writers. @@ -721,6 +733,12 @@ def get_current_token_for_writer(self, instance_name: str) -> int: pos = max(pos, self._persisted_upto_position) return self._return_factor * pos + def get_minimal_local_current_token(self) -> int: + with self._lock: + return self._return_factor * self._current_positions.get( + self._instance_name, self._persisted_upto_position + ) + def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 012babc35b2a..855eab6ac0d6 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -648,6 +648,19 @@ def _insert(txn: Cursor) -> None: with self.assertRaises(IncorrectDatabaseSetup): self._create_id_generator("first") + def test_minimal_local_token(self) -> None: + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) + + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""