Skip to content

Commit

Permalink
Copy and simplify GetBulkStateContent for ACLs
Browse files Browse the repository at this point in the history
  • Loading branch information
S7evinK committed Dec 19, 2024
1 parent e0b1539 commit e2fd591
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 9 deletions.
10 changes: 4 additions & 6 deletions roomserver/acls/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"time"

"github.com/element-hq/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus"
)
Expand All @@ -28,9 +27,8 @@ type ServerACLDatabase interface {
// RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error)

// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// GetBulkStateACLs returns all server ACLs for the given rooms.
GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error)
}

type ServerACLs struct {
Expand Down Expand Up @@ -59,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
aclRegexCache: make(map[string]**regexp.Regexp, 100),
}

// Look up all of the rooms that the current state server knows about.
// Look up all rooms with ACLs.
rooms, err := db.RoomsWithACLs(ctx)
if err != nil {
logrus.WithError(err).Fatalf("Failed to get known rooms")
Expand All @@ -68,7 +66,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
// do then we'll process it into memory so that we have the regexes to
// hand.

events, err := db.GetBulkStateContent(ctx, rooms, []gomatrixserverlib.StateKeyTuple{{EventType: MRoomServerACL, StateKey: ""}}, false)
events, err := db.GetBulkStateACLs(ctx, rooms)
if err != nil {
logrus.WithError(err).Errorf("Failed to get server ACLs for all rooms: %q", err)
}
Expand Down
5 changes: 2 additions & 3 deletions roomserver/acls/acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"testing"

"github.com/element-hq/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -108,11 +107,11 @@ var (

type dummyACLDB struct{}

func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) {
func (d dummyACLDB) RoomsWithACLs(_ context.Context) ([]string, error) {
return []string{"1", "2"}, nil
}

func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
func (d dummyACLDB) GetBulkStateACLs(_ context.Context, _ []string) ([]tables.StrippedEvent, error) {
return []tables.StrippedEvent{
{
RoomID: "1",
Expand Down
2 changes: 2 additions & 0 deletions roomserver/storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ type Database interface {

// RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error)
// GetBulkStateACLs returns all server ACLs for the given rooms.
GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error)
QueryAdminEventReports(ctx context.Context, from uint64, limit uint64, backwards bool, userID string, roomID string) ([]api.QueryAdminEventReportsResponse, int64, error)
QueryAdminEventReport(ctx context.Context, reportID uint64) (api.QueryAdminEventReportResponse, error)
AdminDeleteEventReport(ctx context.Context, reportID uint64) error
Expand Down
57 changes: 57 additions & 0 deletions roomserver/storage/shared/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,63 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID,
return roomIDs, nil
}

// GetBulkStateACLs is a lighter weight form of GetBulkStateContent, which only returns ACL state events.
func (d *Database) GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) {
tuples := []gomatrixserverlib.StateKeyTuple{{EventType: "m.room.server_acl", StateKey: ""}}

var eventNIDs []types.EventNID
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
// TODO: This feels like this is going to be really slow...
for _, roomID := range roomIDs {
roomInfo, err2 := d.roomInfo(ctx, nil, roomID)
if err2 != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load room info for room %s : %w", roomID, err2)
}
// for unknown rooms or rooms which we don't have the current state, skip them.
if roomInfo == nil || roomInfo.IsStub() {
continue
}
// No querier needed, as we don't actually do state resolution
stateRes := state.NewStateResolution(d, roomInfo, nil)
entries, err2 := stateRes.LoadStateAtSnapshotForStringTuples(ctx, roomInfo.StateSnapshotNID(), tuples)
if err2 != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load state for room %s : %w", roomID, err2)
}
for _, entry := range entries {
eventNIDs = append(eventNIDs, entry.EventNID)
eventNIDToVer[entry.EventNID] = roomInfo.RoomVersion
}
}
eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event nids: %w", err)
}
result := make([]tables.StrippedEvent, len(events))
for i := range events {
roomVer := eventNIDToVer[events[i].EventNID]
verImpl, err := gomatrixserverlib.GetRoomVersion(roomVer)
if err != nil {
return nil, err
}
ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false)
if err != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event NID %v : %w", events[i].EventNID, err)
}
result[i] = tables.StrippedEvent{
EventType: ev.Type(),
RoomID: ev.RoomID().String(),
StateKey: *ev.StateKey(),
ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}),
}
}

return result, nil
}

// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
Expand Down

0 comments on commit e2fd591

Please sign in to comment.