From e2fd591d9fea54f81b361b64b0eb8cee0176fbac Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 19 Dec 2024 20:06:39 +0100 Subject: [PATCH] Copy and simplify GetBulkStateContent for ACLs --- roomserver/acls/acls.go | 10 ++--- roomserver/acls/acls_test.go | 5 +-- roomserver/storage/interface.go | 2 + roomserver/storage/shared/storage.go | 57 ++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 4d59c879..01e8820c 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -17,7 +17,6 @@ import ( "time" "github.com/element-hq/dendrite/roomserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -28,9 +27,8 @@ type ServerACLDatabase interface { // RoomsWithACLs returns all room IDs for rooms with ACLs RoomsWithACLs(ctx context.Context) ([]string, error) - // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. - // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. - GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) + // GetBulkStateACLs returns all server ACLs for the given rooms. + GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) } type ServerACLs struct { @@ -59,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { aclRegexCache: make(map[string]**regexp.Regexp, 100), } - // Look up all of the rooms that the current state server knows about. + // Look up all rooms with ACLs. rooms, err := db.RoomsWithACLs(ctx) if err != nil { logrus.WithError(err).Fatalf("Failed to get known rooms") @@ -68,7 +66,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { // do then we'll process it into memory so that we have the regexes to // hand. - events, err := db.GetBulkStateContent(ctx, rooms, []gomatrixserverlib.StateKeyTuple{{EventType: MRoomServerACL, StateKey: ""}}, false) + events, err := db.GetBulkStateACLs(ctx, rooms) if err != nil { logrus.WithError(err).Errorf("Failed to get server ACLs for all rooms: %q", err) } diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index 16f3887d..d5a36a61 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -12,7 +12,6 @@ import ( "testing" "github.com/element-hq/dendrite/roomserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -108,11 +107,11 @@ var ( type dummyACLDB struct{} -func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) { +func (d dummyACLDB) RoomsWithACLs(_ context.Context) ([]string, error) { return []string{"1", "2"}, nil } -func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { +func (d dummyACLDB) GetBulkStateACLs(_ context.Context, _ []string) ([]tables.StrippedEvent, error) { return []tables.StrippedEvent{ { RoomID: "1", diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 49086dba..3bdeeef8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -187,6 +187,8 @@ type Database interface { // RoomsWithACLs returns all room IDs for rooms with ACLs RoomsWithACLs(ctx context.Context) ([]string, error) + // GetBulkStateACLs returns all server ACLs for the given rooms. + GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) QueryAdminEventReports(ctx context.Context, from uint64, limit uint64, backwards bool, userID string, roomID string) ([]api.QueryAdminEventReportsResponse, int64, error) QueryAdminEventReport(ctx context.Context, reportID uint64) (api.QueryAdminEventReportResponse, error) AdminDeleteEventReport(ctx context.Context, reportID uint64) error diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a4eb0eb9..6eb1518b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1437,6 +1437,63 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID, return roomIDs, nil } +// GetBulkStateACLs is a lighter weight form of GetBulkStateContent, which only returns ACL state events. +func (d *Database) GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) { + tuples := []gomatrixserverlib.StateKeyTuple{{EventType: "m.room.server_acl", StateKey: ""}} + + var eventNIDs []types.EventNID + eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) + // TODO: This feels like this is going to be really slow... + for _, roomID := range roomIDs { + roomInfo, err2 := d.roomInfo(ctx, nil, roomID) + if err2 != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load room info for room %s : %w", roomID, err2) + } + // for unknown rooms or rooms which we don't have the current state, skip them. + if roomInfo == nil || roomInfo.IsStub() { + continue + } + // No querier needed, as we don't actually do state resolution + stateRes := state.NewStateResolution(d, roomInfo, nil) + entries, err2 := stateRes.LoadStateAtSnapshotForStringTuples(ctx, roomInfo.StateSnapshotNID(), tuples) + if err2 != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load state for room %s : %w", roomID, err2) + } + for _, entry := range entries { + eventNIDs = append(eventNIDs, entry.EventNID) + eventNIDToVer[entry.EventNID] = roomInfo.RoomVersion + } + } + eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event nids: %w", err) + } + result := make([]tables.StrippedEvent, len(events)) + for i := range events { + roomVer := eventNIDToVer[events[i].EventNID] + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVer) + if err != nil { + return nil, err + } + ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false) + if err != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event NID %v : %w", events[i].EventNID, err) + } + result[i] = tables.StrippedEvent{ + EventType: ev.Type(), + RoomID: ev.RoomID().String(), + StateKey: *ev.StateKey(), + ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}), + } + } + + return result, nil +} + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {