Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor storing of server keys #16261

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16261.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Simplify server key storage.
35 changes: 7 additions & 28 deletions synapse/crypto/keyring.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I think this is probably clearer...but the historical data in the tables still won't be the same, which is confusing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, though given the keys have expiry times that will quickly be remedied 🤷

Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
get_verify_key,
is_signing_algorithm_supported,
)
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
signature_ids,
verify_signed_json,
)
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
from signedjson.types import VerifyKey
from unpaddedbase64 import decode_base64

Expand Down Expand Up @@ -596,24 +591,12 @@ async def process_v2_response(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)

key_json_bytes = encode_canonical_json(response_json)

await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
consumeErrors=True,
).addErrback(unwrapFirstError)
await self.store.store_server_keys_response(
server_name=server_name,
from_server=from_server,
ts_added_ms=time_added_ms,
verify_keys=verify_keys,
response_json=response_json,
)

return verify_keys
Expand Down Expand Up @@ -775,10 +758,6 @@ async def get_server_verify_key_v2_indirect(

keys.setdefault(server_name, {}).update(processed_response)

await self.store.store_server_signature_keys(
perspective_name, time_now_ms, added_keys
)

return keys

def _validate_perspectives_response(
Expand Down
163 changes: 75 additions & 88 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
import itertools
import json
import logging
from typing import Dict, Iterable, Mapping, Optional, Tuple
from typing import Dict, Iterable, Optional, Tuple

from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter

Expand Down Expand Up @@ -95,103 +98,87 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:

return await self.db_pool.runInteraction("get_server_signature_keys", _txn)

async def store_server_signature_keys(
async def store_server_keys_response(
self,
server_name: str,
from_server: str,
ts_added_ms: int,
verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
verify_keys: Dict[str, FetchKeyResult],
response_json: JsonDict,
) -> None:
"""Stores NACL verification keys for remote servers.
"""Stores the keys for the given server that we got from `from_server`.

Args:
from_server: Where the verification keys were looked up
ts_added_ms: The time to record that the key was added
verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
server_name: The owner of the keys
from_server: Which server we got the keys from
ts_added_ms: When we're adding the keys
verify_keys: The decoded keys
response_json: The full *signed* response JSON that contains the keys.
"""
key_values = []
value_values = []
invalidations = []
for (server_name, key_id), fetch_result in verify_keys.items():
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_signature_key. _get_server_signature_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))

await self.db_pool.simple_upsert_many(
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
desc="store_server_signature_keys",
)
key_json_bytes = encode_canonical_json(response_json)

def store_server_keys_response_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_many_txn(
txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=[(server_name, key_id) for key_id in verify_keys],
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=[
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
for fetch_result in verify_keys.values()
],
)

invalidate = self._get_server_signature_key.invalidate
for i in invalidations:
invalidate((i,))
self.db_pool.simple_upsert_many_txn(
txn,
table="server_keys_json",
key_names=("server_name", "key_id", "from_server"),
key_values=[
(server_name, key_id, from_server) for key_id in verify_keys
],
value_names=(
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
value_values=[
(
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(key_json_bytes),
)
for fetch_result in verify_keys.values()
],
)

async def store_server_keys_json(
self,
server_name: str,
key_id: str,
from_server: str,
ts_now_ms: int,
ts_expires_ms: int,
key_json_bytes: bytes,
) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name: The name of the server.
key_id: The identifier of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
key_json_bytes: The encoded JSON.
"""
await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
},
values={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
Comment on lines -177 to -179
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was silly for us to select and upsert the same values. 🤦

"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
"key_json": db_binary_type(key_json_bytes),
},
desc="store_server_keys_json",
)
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
for key_id in verify_keys:
self._invalidate_cache_and_stream(
txn, self._get_server_keys_json, ((server_name, key_id),)
)
self._invalidate_cache_and_stream(
txn, self.get_server_key_json_for_remote, (server_name, key_id)
)
self._invalidate_cache_and_stream(
txn, self._get_server_signature_key, ((server_name, key_id),)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW get_server_signature_keys (which calls _get_server_signature_key) is only used in tests. Should we just nuke the cache on it?

Copy link
Member Author

@erikjohnston erikjohnston Sep 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just went down a bit of a rabbit hole and concluded we should just rip out the read functions of that table. The tests that use it can either a) use the other table, or b) specifically test that table "works".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me. 👍


# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
await self.db_pool.runInteraction(
"store_server_keys_response", store_server_keys_response_txn
)

@cached()
Expand Down
41 changes: 23 additions & 18 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,23 +189,24 @@ def test_verify_json_for_server(self) -> None:
kr = keyring.Keyring(self.hs)

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_keys_json(
r = self.hs.get_datastores().main.store_server_keys_response(
"server9",
get_key_id(key1),
from_server="test",
ts_now_ms=int(time.time() * 1000),
ts_expires_ms=1000,
ts_added_ms=int(time.time() * 1000),
verify_keys={
get_key_id(key1): FetchKeyResult(
verify_key=get_verify_key(key1), valid_until_ts=1000
)
},
# The entire response gets signed & stored, just include the bits we
# care about.
key_json_bytes=canonicaljson.encode_canonical_json(
{
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
response_json={
"verify_keys": {
get_key_id(key1): {
"key": encode_verify_key_base64(get_verify_key(key1))
}
}
),
},
)
self.get_success(r)

Expand Down Expand Up @@ -293,13 +294,17 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
mock_fetcher.get_keys = AsyncMock(return_value={})

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(
"server9",
int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_signature_keys.
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]

r = self.hs.get_datastores().main.db_pool.simple_upsert(
table="server_signature_keys",
keyvalues={"server_name": "server9", "key_id": get_key_id(key1)},
values={
"from_server": "server9",
"ts_added_ms": int(time.time() * 1000),
"ts_valid_until_ms": None,
"verify_key": memoryview(get_verify_key(key1).encode()),
},
desc="store_server_signature_keys",
)
self.get_success(r)

Expand Down
Loading