Skip to content

Commit

Permalink
Backport support for fallback keys
Browse files Browse the repository at this point in the history
  • Loading branch information
neilalexander committed Dec 17, 2024
1 parent 23e097c commit b302ef0
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 20 deletions.
1 change: 1 addition & 0 deletions syncapi/internal/keychange.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
return queryRes.Error
}
res.DeviceListsOTKCount = queryRes.Count.KeyCount
res.DeviceListsUnusedFallbackAlgorithms = queryRes.UnusedFallbackAlgorithms
return nil
}

Expand Down
16 changes: 9 additions & 7 deletions syncapi/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,14 @@ type ToDeviceResponse struct {

// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
type Response struct {
NextBatch StreamingToken `json:"next_batch"`
AccountData *ClientEvents `json:"account_data,omitempty"`
Presence *ClientEvents `json:"presence,omitempty"`
Rooms *RoomsResponse `json:"rooms,omitempty"`
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
NextBatch StreamingToken `json:"next_batch"`
AccountData *ClientEvents `json:"account_data,omitempty"`
Presence *ClientEvents `json:"presence,omitempty"`
Rooms *RoomsResponse `json:"rooms,omitempty"`
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
DeviceListsUnusedFallbackAlgorithms []string `json:"device_unused_fallback_key_types"`
}

func (r Response) MarshalJSON() ([]byte, error) {
Expand Down Expand Up @@ -419,6 +420,7 @@ func NewResponse() *Response {
res.DeviceLists = &DeviceLists{}
res.ToDevice = &ToDeviceResponse{}
res.DeviceListsOTKCount = map[string]int{}
res.DeviceListsUnusedFallbackAlgorithms = []string{}

return &res
}
Expand Down
4 changes: 3 additions & 1 deletion sytest-whitelist
Original file line number Diff line number Diff line change
Expand Up @@ -793,4 +793,6 @@ remote user can join room with version 11
User can invite remote user to room with version 11
Remote user can backfill in a room with version 11
Can reject invites over federation for rooms with version 11
Can receive redactions from regular users over federation in room version 11
Can receive redactions from regular users over federation in room version 11
Can upload self-signing keys
uploading signed devices gets propagated over federation
36 changes: 28 additions & 8 deletions userapi/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -788,12 +788,30 @@ type OneTimeKeysCount struct {
KeyCount map[string]int
}

// FallbackKeys represents a set of fallback keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type FallbackKeys struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// A map of algorithm:key_id => key JSON
KeyJSON map[string]json.RawMessage
}

// Split a key in KeyJSON into algorithm and key ID
func (k *FallbackKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}

// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
FallbackKeys []FallbackKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
Expand All @@ -810,8 +828,9 @@ type PerformUploadKeysResponse struct {
// A fatal error when processing e.g database failures
Error *KeyError
// A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
FallbackKeysUnusedAlgorithms []string
}

// PerformDeleteKeysRequest asks the keyserver to forget about certain
Expand Down Expand Up @@ -917,8 +936,9 @@ type QueryOneTimeKeysRequest struct {

type QueryOneTimeKeysResponse struct {
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
Count OneTimeKeysCount
Error *KeyError
Count OneTimeKeysCount
UnusedFallbackAlgorithms []string
Error *KeyError
}

type QueryDeviceMessagesRequest struct {
Expand Down
55 changes: 51 additions & 4 deletions userapi/internal/key_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
}
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
a.uploadOneTimeAndFallbackKeys(ctx, req, res)
}
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
}
return nil
}
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
res.FallbackKeysUnusedAlgorithms = algos
return nil
}

Expand Down Expand Up @@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
}
return nil
}
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
}
return nil
}
res.Count = *count
res.UnusedFallbackAlgorithms = algos
return nil
}

Expand Down Expand Up @@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
Expand All @@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
Expand Down Expand Up @@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
}
}

func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
Expand Down Expand Up @@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
// collect counts
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
}

if len(req.FallbackKeys) > 0 {
if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
})
return
}
for _, key := range req.FallbackKeys {
// grab existing keys based on (user/device/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0
for keyIDWithAlgo := range key.KeyJSON {
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
// collect counts
res.FallbackKeysUnusedAlgorithms = unused
}
}
}

func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
Expand Down
9 changes: 9 additions & 0 deletions userapi/storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)

// StoreFallbackKeys persists the given fallback keys.
StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error)

// UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device.
UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)

// DeleteFallbackKeys deletes all fallback keys for the user.
DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error

// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

Expand Down
134 changes: 134 additions & 0 deletions userapi/storage/postgres/fallback_keys_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2017 Vector Creations Ltd
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

package postgres

import (
"context"
"database/sql"
"encoding/json"
"time"

"github.com/element-hq/dendrite/internal"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
)

var fallbackKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
used BOOLEAN NOT NULL,
-- Clobber based on tuple of user/device/algorithm.
CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
`

const upsertFallbackKeysSQL = "" +
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
" ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
" DO UPDATE SET key_id = $3, key_json = $6, used = false"

const selectFallbackUnusedAlgorithmsSQL = "" +
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"

const selectFallbackKeysByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"

const deleteFallbackKeysSQL = "" +
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"

const updateFallbackKeyUsedSQL = "" +
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"

type fallbackKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectUnusedAlgorithmsStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteFallbackKeysStmt *sql.Stmt
updateFallbackKeyUsedStmt *sql.Stmt
}

func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
s := &fallbackKeysStatements{
db: db,
}
_, err := db.Exec(fallbackKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
}.Prepare(db)
}

func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
algos := []string{}
for rows.Next() {
var algorithm string
if err = rows.Scan(&algorithm); err != nil {
return nil, err
}
algos = append(algos, algorithm)
}
return algos, rows.Err()
}

func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
now := time.Now().Unix()
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
}

func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}

func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
5 changes: 5 additions & 0 deletions userapi/storage/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
fk, err := NewPostgresFallbackKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
Expand All @@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp

return &shared.KeyDatabase{
OneTimeKeysTable: otk,
FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
Expand Down
Loading

0 comments on commit b302ef0

Please sign in to comment.