diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 0331c1a2996d..41563371dcd2 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -39,65 +39,6 @@ class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" - @cached() - def _get_server_signature_key( - self, server_name_and_key_id: Tuple[str, str] - ) -> FetchKeyResult: - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_signature_key", - list_name="server_name_and_key_ids", - ) - async def get_server_signature_keys( - self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: - """ - Args: - server_name_and_key_ids: - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - A map from (server_name, key_id) -> FetchKeyResult, or None if the - key is unknown - """ - keys = {} - - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, verify_key, ts_valid_until_ms - FROM server_signature_keys WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - keys[(server_name, key_id)] = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) - async def store_server_keys_response( self, server_name: str, @@ -173,9 +114,6 @@ def store_server_keys_response_txn(txn: LoggingTransaction) -> None: self._invalidate_cache_and_stream( txn, self.get_server_key_json_for_remote, (server_name, key_id) ) - self._invalidate_cache_and_stream( - txn, self._get_server_signature_key, ((server_name, key_id),) - ) await self.db_pool.runInteraction( "store_server_keys_response", store_server_keys_response_txn diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index dd70489645c3..9e450967adcf 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -312,7 +312,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: signedjson.sign.sign_json(json1, "server9", key1) # should succeed on a signed object with a 0 minimum_valid_until_ms - d = self.hs.get_datastores().main.get_server_signature_keys( + d = self.hs.get_datastores().main.get_server_keys_json( [("server9", get_key_id(key1))] ) result = self.get_success(d) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py deleted file mode 100644 index a409d6fc3e5e..000000000000 --- a/tests/storage/test_keys.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import signedjson.key -import signedjson.types -import unpaddedbase64 - -from synapse.storage.keys import FetchKeyResult - -import tests.unittest - - -def decode_verify_key_base64( - key_id: str, key_base64: str -) -> signedjson.types.VerifyKey: - key_bytes = unpaddedbase64.decode_base64(key_base64) - return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) - - -KEY_1 = decode_verify_key_base64( - "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" -) -KEY_2 = decode_verify_key_base64( - "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" -) - - -class KeyStoreTestCase(tests.unittest.HomeserverTestCase): - def test_get_server_signature_keys(self) -> None: - store = self.hs.get_datastores().main - - key_id_1 = "ed25519:key1" - key_id_2 = "ed25519:KEY_ID_2" - self.get_success( - store.store_server_keys_response( - "server1", - from_server="from_server", - ts_added_ms=10, - verify_keys={ - key_id_1: FetchKeyResult(KEY_1, 100), - key_id_2: FetchKeyResult(KEY_2, 200), - }, - response_json={ - "verify_keys": { - key_id_1: { - "key": signedjson.key.encode_verify_key_base64(KEY_1) - }, - key_id_2: { - "key": signedjson.key.encode_verify_key_base64(KEY_2) - }, - } - }, - ), - ) - - res = self.get_success( - store.get_server_signature_keys( - [ - ("server1", key_id_1), - ("server1", key_id_2), - ("server1", "ed25519:key3"), - ] - ) - ) - - self.assertEqual(len(res.keys()), 3) - res1 = res[("server1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.verify_key.version, "key1") - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("server1", key_id_2)] - self.assertEqual(res2.verify_key, KEY_2) - # version comes from the ID it was stored with - self.assertEqual(res2.verify_key.version, "KEY_ID_2") - self.assertEqual(res2.valid_until_ts, 200) - - # non-existent result gives None - self.assertIsNone(res[("server1", "ed25519:key3")]) - - def test_cache(self) -> None: - """Check that updates correctly invalidate the cache.""" - - store = self.hs.get_datastores().main - - key_id_1 = "ed25519:key1" - key_id_2 = "ed25519:key2" - - self.get_success( - store.store_server_keys_response( - "srv1", - from_server="from_server", - ts_added_ms=0, - verify_keys={ - key_id_1: FetchKeyResult(KEY_1, 100), - key_id_2: FetchKeyResult(KEY_2, 200), - }, - response_json={ - "verify_keys": { - key_id_1: { - "key": signedjson.key.encode_verify_key_base64(KEY_1) - }, - key_id_2: { - "key": signedjson.key.encode_verify_key_base64(KEY_2) - }, - } - }, - ), - ) - - res = self.get_success( - store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - ) - self.assertEqual(len(res.keys()), 2) - - res1 = res[("srv1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("srv1", key_id_2)] - self.assertEqual(res2.verify_key, KEY_2) - self.assertEqual(res2.valid_until_ts, 200) - - # we should be able to look up the same thing again without a db hit - res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)])) - self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) - - new_key_2 = signedjson.key.get_verify_key( - signedjson.key.generate_signing_key("key2") - ) - d = store.store_server_keys_response( - "srv1", - from_server="from_server", - ts_added_ms=0, - verify_keys={ - key_id_2: FetchKeyResult(new_key_2, 300), - }, - response_json={ - "verify_keys": { - key_id_2: { - "key": signedjson.key.encode_verify_key_base64(new_key_2) - }, - } - }, - ) - self.get_success(d) - - res = self.get_success( - store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - ) - self.assertEqual(len(res.keys()), 2) - - res1 = res[("srv1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("srv1", key_id_2)] - self.assertEqual(res2.verify_key, new_key_2) - self.assertEqual(res2.valid_until_ts, 300)